{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.Data.Fold (
Fold(..), runFold,
) where
import Data.Array.Accelerate.Classes.Floating as A
import Data.Array.Accelerate.Classes.Fractional as A
import Data.Array.Accelerate.Classes.Num as A
import Data.Array.Accelerate.Data.Monoid
import Data.Array.Accelerate.Language as A
import Data.Array.Accelerate.Lift
import Data.Array.Accelerate.Smart ( Acc, Exp, constant )
import Data.Array.Accelerate.Sugar.Array
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Shape
import Prelude hiding ( sum, product, length )
import Control.Applicative as P
import qualified Prelude as P
-- | A fold over input elements of type @i@ producing a result of type @o@.
--
-- A 'Fold' packages an existentially hidden accumulator type @w@ (an 'Elt'
-- whose embedded expressions form a 'Monoid') together with two functions:
--
--   * the /tally/, which maps a single input element to an accumulator
--     value; and
--
--   * the /summarise/ step, which turns a combined accumulator into the
--     final result.
--
-- Accumulator values are combined with 'mappend', starting from 'mempty'
-- (see 'runFold').
data Fold i o where
  Fold :: (Elt w, Monoid (Exp w))
       => (i -> Exp w)      -- tally: one input element to an accumulator value
       -> (Exp w -> o)      -- summarise: combined accumulator to the result
       -> Fold i o
-- | Apply a 'Fold' along the innermost dimension of an array, reducing its
-- rank by one.
--
-- The pipeline has three steps:
--
--   * every element is mapped through the fold's tally function;
--
--   * the resulting accumulators are reduced along the innermost dimension
--     with 'mappend', seeded with 'mempty';
--
--   * each reduced accumulator is passed through the summarise function.
--
-- NOTE(review): 'A.fold' is a parallel reduction, so the 'Monoid' instance
-- of the accumulator is presumably required to be associative, with
-- 'mempty' a true identity — confirm against the accelerate 'fold' docs.
runFold
    :: (Shape sh, Elt i, Elt o)
    => Fold (Exp i) (Exp o)
    -> Acc (Array (sh :. Int) i)
    -> Acc (Array sh o)
runFold (Fold tally summarise) is
  = A.map summarise
  $ A.fold mappend mempty
  $ A.map tally is
-- | Post-compose a function with the fold's summarise step. The tally and
-- the hidden accumulator type are left unchanged.
instance P.Functor (Fold i) where
  fmap k (Fold tally summarise) = Fold tally (k . summarise)
-- | Combine folds so that they share a single traversal.
--
-- 'pure' ignores its input and uses the unit accumulator @()@, returning
-- the given value unchanged.
--
-- '<*>' pairs the two folds' accumulators in a tuple, so that both are
-- tallied and reduced together. At the end the left fold's summary (a
-- function) is applied to the right fold's summary.
instance P.Applicative (Fold i) where
  pure o = Fold (\_ -> constant ()) (\_ -> o)

  Fold tF sF <*> Fold tX sX = Fold tally summarise
    where
      -- Tally the input with both folds and pack the results into a pair.
      -- No local signature: the two accumulator types are distinct and only
      -- in scope from the GADT pattern match above.
      tally i     = lift (tF i, tX i)

      -- Split the reduced pair back into its components and summarise each.
      summarise t = let (mF, mX) = unlift t
                    in  sF mF (sX mX)
-- | Numeric operations on folds with numeric results, applied pointwise to
-- the fold results.
--
-- * Binary operators combine their operand folds with 'liftA2'.
-- * Unary operators 'fmap' over the result.
-- * Integer literals become constant folds via 'pure'.
instance A.Num b => P.Num (Fold a (Exp b)) where
  (+)           = liftA2 (+)
  (-)           = liftA2 (-)
  (*)           = liftA2 (*)
  negate        = fmap negate
  abs           = fmap abs
  signum        = fmap signum
  fromInteger n = pure (A.fromInteger n)
-- | Fractional operations on folds with fractional results, applied
-- pointwise to the fold results.
--
-- * Division combines the two operand folds with 'liftA2'.
-- * 'recip' maps over the result.
-- * Rational literals become constant folds via 'pure'.
instance A.Fractional b => P.Fractional (Fold a (Exp b)) where
  (/)            = liftA2 (/)
  recip          = fmap recip
  fromRational r = pure (A.fromRational r)
-- | Floating-point operations on folds with floating results, applied
-- pointwise to the fold results.
--
-- * 'pi' is a constant fold.
-- * Unary functions 'fmap' over the result.
-- * The binary operations ('**', 'logBase') combine their operand folds
--   with 'liftA2'.
instance A.Floating b => P.Floating (Fold a (Exp b)) where
  pi      = pure pi
  sin     = fmap sin
  cos     = fmap cos
  tan     = fmap tan
  asin    = fmap asin
  acos    = fmap acos
  atan    = fmap atan
  sinh    = fmap sinh
  cosh    = fmap cosh
  tanh    = fmap tanh
  asinh   = fmap asinh
  acosh   = fmap acosh
  atanh   = fmap atanh
  exp     = fmap exp
  sqrt    = fmap sqrt
  log     = fmap log
  (**)    = liftA2 (**)
  logBase = liftA2 logBase