module Data.Array.Knead.Expression where
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import qualified LLVM.Core as LLVM
import LLVM.Extra.Multi.Value (PatternTuple, Decomposed, Atom, )
import qualified Control.Monad.HT as Monad
import qualified Data.Tuple.HT as TupleHT
import qualified Data.Tuple as Tuple
import Data.Complex (Complex((:+)))
import Data.Bool8 (Bool8)
import Prelude
hiding (fst, snd, min, max, zip, unzip, zip3, unzip3,
curry, uncurry, pi, maybe)
-- | An LLVM code generator that yields a 'MultiValue.T' of type @a@.
--
-- The generator is universally quantified over the enclosing function's
-- result type @r@, so a single 'Exp' can be run inside any
-- 'LLVM.CodeGenFunction'. 'unExp' exposes that generator.
newtype Exp a =
   Exp {unExp :: forall r. LLVM.CodeGenFunction r (MultiValue.T a)}
-- | Type constructors into which pure functions on 'MultiValue.T' can be
-- lifted.
--
-- @liftN@ turns an @N@-ary function on 'MultiValue.T' values into an
-- @N@-ary function on @val@ values. 'lift0' embeds a single value.
class Value val where
   lift0 ::
      MultiValue.T a ->
      val a
   lift1 ::
      (MultiValue.T a -> MultiValue.T b) ->
      (val a -> val b)
   lift2 ::
      (MultiValue.T a -> MultiValue.T b -> MultiValue.T c) ->
      (val a -> val b -> val c)
   lift3 ::
      (MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
       MultiValue.T d) ->
      (val a -> val b -> val c -> val d)
   lift4 ::
      (MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
       MultiValue.T d -> MultiValue.T e) ->
      (val a -> val b -> val c -> val d -> val e)
-- | On plain 'MultiValue.T' values lifting is trivial: every @liftN@
-- simply applies the given function to its arguments.
instance Value MultiValue.T where
   lift0 x = x
   lift1 f x = f x
   lift2 f x y = f x y
   lift3 f x y z = f x y z
   lift4 f x y z w = f x y z w
-- | Lifting into 'Exp': the argument generators are run in order from left
-- to right, and the pure function is applied to their results.
instance Value Exp where
   lift0 x = Exp (return x)
   lift1 f ea =
      Exp (do
         a <- unExp ea
         return (f a))
   lift2 f ea eb =
      Exp (do
         a <- unExp ea
         b <- unExp eb
         return (f a b))
   lift3 f ea eb ec =
      Exp (do
         a <- unExp ea
         b <- unExp eb
         c <- unExp ec
         return (f a b c))
   lift4 f ea eb ec ed =
      Exp (do
         a <- unExp ea
         b <- unExp eb
         c <- unExp ec
         d <- unExp ed
         return (f a b c d))
-- | Lift a code-generating unary function on 'MultiValue.T' to 'Exp'.
-- The argument's code is generated first and its result is passed to @f@.
liftM ::
   (forall r.
    MultiValue.T a ->
    LLVM.CodeGenFunction r (MultiValue.T b)) ->
   (Exp a -> Exp b)
liftM f ea = Exp (unExp ea >>= f)

-- | Binary variant of 'liftM'. The arguments' code is generated from left
-- to right, then @f@ is invoked on their results.
liftM2 ::
   (forall r.
    MultiValue.T a -> MultiValue.T b ->
    LLVM.CodeGenFunction r (MultiValue.T c)) ->
   (Exp a -> Exp b -> Exp c)
liftM2 f ea eb =
   Exp (do
      a <- unExp ea
      b <- unExp eb
      f a b)

-- | Ternary variant of 'liftM'. The arguments' code is generated from left
-- to right, then @f@ is invoked on their results.
liftM3 ::
   (forall r.
    MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
    LLVM.CodeGenFunction r (MultiValue.T d)) ->
   (Exp a -> Exp b -> Exp c -> Exp d)
liftM3 f ea eb ec =
   Exp (do
      a <- unExp ea
      b <- unExp eb
      c <- unExp ec
      f a b c)
-- | Convert a unary 'Exp' function back into a code generator over
-- 'MultiValue.T'. The argument is embedded with 'lift0', and the generator
-- inside the resulting 'Exp' is returned.
unliftM1 ::
   (Exp a -> Exp b) ->
   MultiValue.T a -> LLVM.CodeGenFunction r (MultiValue.T b)
unliftM1 f ix =
   case f (lift0 ix) of
      Exp code -> code

-- | Binary variant of 'unliftM1'.
unliftM2 ::
   (Exp a -> Exp b -> Exp c) ->
   MultiValue.T a -> MultiValue.T b ->
   LLVM.CodeGenFunction r (MultiValue.T c)
unliftM2 f ix jx =
   case f (lift0 ix) (lift0 jx) of
      Exp code -> code

-- | Ternary variant of 'unliftM1'.
unliftM3 ::
   (Exp a -> Exp b -> Exp c -> Exp d) ->
   MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
   LLVM.CodeGenFunction r (MultiValue.T d)
unliftM3 f ix jx kx =
   case f (lift0 ix) (lift0 jx) (lift0 kx) of
      Exp code -> code
-- | Combine two values into a pair value.
zip :: (Value val) => val a -> val b -> val (a, b)
zip a b = lift2 MultiValue.zip a b

-- | Combine three values into a triple value.
zip3 :: (Value val) => val a -> val b -> val c -> val (a, b, c)
zip3 a b c = lift3 MultiValue.zip3 a b c

-- | Combine four values into a quadruple value.
zip4 :: (Value val) => val a -> val b -> val c -> val d -> val (a, b, c, d)
zip4 a b c d = lift4 MultiValue.zip4 a b c d
-- | Split a pair value into its two components.
unzip :: (Value val) => val (a, b) -> (val a, val b)
unzip ab = (xa, xb)
   where
      xa = lift1 MultiValue.fst ab
      xb = lift1 MultiValue.snd ab

-- | Split a triple value into its three components.
unzip3 :: (Value val) => val (a, b, c) -> (val a, val b, val c)
unzip3 abc = (xa, xb, xc)
   where
      xa = lift1 MultiValue.fst3 abc
      xb = lift1 MultiValue.snd3 abc
      xc = lift1 MultiValue.thd3 abc

-- | Split a quadruple value into its four components by projecting the
-- underlying representation tuple.
unzip4 :: (Value val) => val (a, b, c, d) -> (val a, val b, val c, val d)
unzip4 abcd = (xa, xb, xc, xd)
   where
      xa = lift1 (\(MultiValue.Cons (x,_,_,_)) -> MultiValue.Cons x) abcd
      xb = lift1 (\(MultiValue.Cons (_,x,_,_)) -> MultiValue.Cons x) abcd
      xc = lift1 (\(MultiValue.Cons (_,_,x,_)) -> MultiValue.Cons x) abcd
      xd = lift1 (\(MultiValue.Cons (_,_,_,x)) -> MultiValue.Cons x) abcd
-- | First component of a pair value.
fst :: (Value val) => val (a, b) -> val a
fst ab = lift1 MultiValue.fst ab

-- | Second component of a pair value.
snd :: (Value val) => val (a, b) -> val b
snd ab = lift1 MultiValue.snd ab

-- | Apply an 'Exp' function to the first component of a pair.
mapFst :: (Exp a -> Exp b) -> Exp (a, c) -> Exp (b, c)
mapFst f ac = liftM (MultiValue.mapFstF (unliftM1 f)) ac

-- | Apply an 'Exp' function to the second component of a pair.
mapSnd :: (Exp b -> Exp c) -> Exp (a, b) -> Exp (a, c)
mapSnd f ab = liftM (MultiValue.mapSndF (unliftM1 f)) ab

-- | Exchange the components of a pair value.
swap :: (Value val) => val (a, b) -> val (b, a)
swap ab = lift1 MultiValue.swap ab

-- | Turn a function on pair expressions into a two-argument function.
curry :: (Exp (a,b) -> c) -> (Exp a -> Exp b -> c)
curry f a b = f (zip a b)

-- | Turn a two-argument function into a function on pair expressions.
-- The lazy @let@ binding avoids forcing the split pair early.
uncurry :: (Exp a -> Exp b -> c) -> (Exp (a,b) -> c)
uncurry f ab =
   let (a, b) = unzip ab
   in f a b

-- | First component of a triple value.
fst3 :: (Value val) => val (a,b,c) -> val a
fst3 abc = lift1 MultiValue.fst3 abc

-- | Second component of a triple value.
snd3 :: (Value val) => val (a,b,c) -> val b
snd3 abc = lift1 MultiValue.snd3 abc

-- | Third component of a triple value.
thd3 :: (Value val) => val (a,b,c) -> val c
thd3 abc = lift1 MultiValue.thd3 abc

-- | Apply an 'Exp' function to the first component of a triple.
mapFst3 :: (Exp a0 -> Exp a1) -> Exp (a0,b,c) -> Exp (a1,b,c)
mapFst3 f abc = liftM (MultiValue.mapFst3F (unliftM1 f)) abc

-- | Apply an 'Exp' function to the second component of a triple.
mapSnd3 :: (Exp b0 -> Exp b1) -> Exp (a,b0,c) -> Exp (a,b1,c)
mapSnd3 f abc = liftM (MultiValue.mapSnd3F (unliftM1 f)) abc

-- | Apply an 'Exp' function to the third component of a triple.
mapThd3 :: (Exp c0 -> Exp c1) -> Exp (a,b,c0) -> Exp (a,b,c1)
mapThd3 f abc = liftM (MultiValue.mapThd3F (unliftM1 f)) abc
-- | Transform a composite value through a pattern. The value is decomposed
-- according to @p@, @f@ is applied, and the result is recomposed, all via
-- 'MultiValue.modify'.
modifyMultiValue ::
   (Value val,
    MultiValue.Compose a,
    MultiValue.Decompose pattern,
    MultiValue.PatternTuple pattern ~ tuple) =>
   pattern ->
   (Decomposed MultiValue.T pattern -> a) ->
   val tuple -> val (MultiValue.Composed a)
modifyMultiValue p f x = lift1 (MultiValue.modify p f) x

-- | Two-argument variant of 'modifyMultiValue', built on
-- 'MultiValue.modify2'.
modifyMultiValue2 ::
   (Value val,
    MultiValue.Compose a,
    MultiValue.Decompose patternA,
    MultiValue.Decompose patternB,
    MultiValue.PatternTuple patternA ~ tupleA,
    MultiValue.PatternTuple patternB ~ tupleB) =>
   patternA ->
   patternB ->
   (Decomposed MultiValue.T patternA ->
    Decomposed MultiValue.T patternB -> a) ->
   val tupleA -> val tupleB -> val (MultiValue.Composed a)
modifyMultiValue2 pa pb f x y = lift2 (MultiValue.modify2 pa pb f) x y
-- | Code-generating variant of 'modifyMultiValue'. The transformation @f@
-- may emit LLVM code; it is applied via 'MultiValue.modifyF'.
modifyMultiValueM ::
   (MultiValue.Compose a,
    MultiValue.Decompose pattern,
    MultiValue.PatternTuple pattern ~ tuple) =>
   pattern ->
   (forall r.
    Decomposed MultiValue.T pattern ->
    LLVM.CodeGenFunction r a) ->
   Exp tuple -> Exp (MultiValue.Composed a)
modifyMultiValueM p f x = liftM (MultiValue.modifyF p f) x

-- | Two-argument variant of 'modifyMultiValueM', built on
-- 'MultiValue.modifyF2'.
modifyMultiValueM2 ::
   (MultiValue.Compose a,
    MultiValue.Decompose patternA,
    MultiValue.Decompose patternB,
    MultiValue.PatternTuple patternA ~ tupleA,
    MultiValue.PatternTuple patternB ~ tupleB) =>
   patternA ->
   patternB ->
   (forall r.
    Decomposed MultiValue.T patternA ->
    Decomposed MultiValue.T patternB ->
    LLVM.CodeGenFunction r a) ->
   Exp tupleA -> Exp tupleB -> Exp (MultiValue.Composed a)
modifyMultiValueM2 pa pb f x y = liftM2 (MultiValue.modifyF2 pa pb f) x y
-- | Nested structures of 'Exp' values (such as tuples or 'Complex') that can
-- be merged into one 'Exp' of the corresponding composite value type.
class Compose multituple where
   -- | Value type of the merged expression.
   type Composed multituple
   compose :: multituple -> Exp (Composed multituple)

-- | Patterns describing how to split an 'Exp' of a composite type into a
-- nested structure of 'Exp' values. The superclass equation guarantees
-- that composing a decomposed structure yields an 'Exp' of the pattern's
-- tuple type again.
class
   (Composed (Decomposed Exp pattern) ~ PatternTuple pattern) =>
      Decompose pattern where
   decompose ::
      pattern -> Exp (PatternTuple pattern) -> Decomposed Exp pattern
-- | Decompose an expression according to pattern @p@, transform the pieces
-- with @f@, and compose the result back into a single 'Exp'.
modify ::
   (Compose a, Decompose pattern) =>
   pattern ->
   (Decomposed Exp pattern -> a) ->
   Exp (PatternTuple pattern) -> Exp (Composed a)
modify p f x = compose (f (decompose p x))

-- | Two-argument variant of 'modify'. Each argument is decomposed with its
-- own pattern.
modify2 ::
   (Compose a, Decompose patternA, Decompose patternB) =>
   patternA ->
   patternB ->
   (Decomposed Exp patternA -> Decomposed Exp patternB -> a) ->
   Exp (PatternTuple patternA) ->
   Exp (PatternTuple patternB) -> Exp (Composed a)
modify2 pa pb f a b =
   let da = decompose pa a
       db = decompose pb b
   in compose (f da db)
-- | A single expression is already composed.
instance Compose (Exp a) where
   type Composed (Exp a) = a
   compose e = e

-- | An atomic pattern leaves the expression untouched.
instance Decompose (Atom a) where
   decompose _ e = e

-- | The unit structure composes to the unit constant.
instance Compose () where
   type Composed () = ()
   compose u = cons u

-- | The unit pattern discards the expression.
instance Decompose () where
   decompose _p _e = ()
-- | Pairs compose component-wise. The lazy pattern keeps the pair
-- unforced until its components are needed.
instance (Compose a, Compose b) => Compose (a,b) where
   type Composed (a,b) = (Composed a, Composed b)
   compose ~(a, b) = zip (compose a) (compose b)

-- | Pair patterns decompose each component with its own sub-pattern.
instance (Decompose pa, Decompose pb) => Decompose (pa,pb) where
   decompose (pa, pb) =
      \x ->
         let (a, b) = unzip x
         in (decompose pa a, decompose pb b)
-- | Triples compose component-wise. The lazy pattern keeps the triple
-- unforced until its components are needed.
instance (Compose a, Compose b, Compose c) => Compose (a,b,c) where
   type Composed (a,b,c) = (Composed a, Composed b, Composed c)
   compose ~(a, b, c) = zip3 (compose a) (compose b) (compose c)

-- | Triple patterns decompose each component with its own sub-pattern.
instance
   (Decompose pa, Decompose pb, Decompose pc) =>
      Decompose (pa,pb,pc) where
   decompose (pa, pb, pc) =
      \x ->
         let (a, b, c) = unzip3 x
         in (decompose pa a, decompose pb b, decompose pc c)
-- | Quadruples compose component-wise.
instance
   (Compose a, Compose b, Compose c, Compose d) =>
      Compose (a,b,c,d) where
   type Composed (a,b,c,d) =
      (Composed a, Composed b, Composed c, Composed d)
   compose abcd =
      case abcd of
         (a, b, c, d) ->
            zip4 (compose a) (compose b) (compose c) (compose d)

-- | Quadruple patterns decompose each component with its own sub-pattern.
instance
   (Decompose pa, Decompose pb, Decompose pc, Decompose pd) =>
      Decompose (pa,pb,pc,pd) where
   decompose (pa, pb, pc, pd) x =
      (decompose pa a, decompose pb b, decompose pc c, decompose pd d)
      where
         (a, b, c, d) = unzip4 x
-- | Complex numbers of expressions compose into a complex expression.
instance (Compose a) => Compose (Complex a) where
   type Composed (Complex a) = Complex (Composed a)
   compose z =
      case z of
         r :+ i -> consComplex (compose r) (compose i)

-- | Complex patterns decompose the real and imaginary parts separately.
instance (Decompose p) => Decompose (Complex p) where
   decompose (pr :+ pim) =
      \z ->
         let (re, im) = deconsComplex z
         in decompose pr re :+ decompose pim im
-- | Build a complex expression from real and imaginary parts.
consComplex :: Exp a -> Exp a -> Exp (Complex a)
consComplex re im = lift2 MultiValue.consComplex re im

-- | Split a complex expression into real and imaginary parts.
deconsComplex :: Exp (Complex a) -> (Exp a, Exp a)
deconsComplex c = (re, im)
   where
      re = lift1 MultiValue.realPart c
      im = lift1 MultiValue.imagPart c

-- | Embed a Haskell constant as an expression.
cons :: (MultiValue.C a) => a -> Exp a
cons x = lift0 (MultiValue.cons x)

-- | The unit expression.
unit :: Exp ()
unit = cons ()

-- | The zero value of the type, as given by 'MultiValue.zero'.
zero :: (MultiValue.C a) => Exp a
zero = lift0 MultiValue.zero
-- | Addition.
add :: (MultiValue.Additive a) => Exp a -> Exp a -> Exp a
add x y = liftM2 MultiValue.add x y

-- | Subtraction.
sub :: (MultiValue.Additive a) => Exp a -> Exp a -> Exp a
sub x y = liftM2 MultiValue.sub x y

-- | Multiplication.
mul :: (MultiValue.PseudoRing a) => Exp a -> Exp a -> Exp a
mul x y = liftM2 MultiValue.mul x y

-- | Square. The argument's code is generated once, and its result is
-- multiplied by itself.
sqr :: (MultiValue.PseudoRing a) => Exp a -> Exp a
sqr x = liftM (\v -> MultiValue.mul v v) x

-- | Square root.
sqrt :: (MultiValue.Algebraic a) => Exp a -> Exp a
sqrt x = liftM MultiValue.sqrt x

-- | Integer division, as defined by 'MultiValue.idiv'.
idiv :: (MultiValue.Integral a) => Exp a -> Exp a -> Exp a
idiv x y = liftM2 MultiValue.idiv x y

-- | Integer remainder, as defined by 'MultiValue.irem'.
irem :: (MultiValue.Integral a) => Exp a -> Exp a -> Exp a
irem x y = liftM2 MultiValue.irem x y

-- | Shift the first operand left by the second.
shl :: (MultiValue.BitShift a) => Exp a -> Exp a -> Exp a
shl x y = liftM2 MultiValue.shl x y

-- | Shift the first operand right by the second, as defined by
-- 'MultiValue.shr'.
shr :: (MultiValue.BitShift a) => Exp a -> Exp a -> Exp a
shr x y = liftM2 MultiValue.shr x y

-- | Embed an integer literal.
fromInteger' :: (MultiValue.IntegerConstant a) => Integer -> Exp a
fromInteger' n = lift0 (MultiValue.fromInteger' n)

-- | Embed a rational literal.
fromRational' :: (MultiValue.RationalConstant a) => Rational -> Exp a
fromRational' q = lift0 (MultiValue.fromRational' q)
-- | Convert a 'Bool8' expression to a 'Bool' expression.
boolPFrom8 :: Exp Bool8 -> Exp Bool
boolPFrom8 b = lift1 MultiValue.boolPFrom8 b

-- | Convert a 'Bool' expression to a 'Bool8' expression.
bool8FromP :: Exp Bool -> Exp Bool8
bool8FromP b = lift1 MultiValue.bool8FromP b

-- | Convert a 'Bool8' to a native integer via 'MultiValue.intFromBool8'.
intFromBool8 :: (MultiValue.NativeInteger i ir) => Exp Bool8 -> Exp i
intFromBool8 b = liftM MultiValue.intFromBool8 b

-- | Convert a 'Bool8' to a native float via 'MultiValue.floatFromBool8'.
floatFromBool8 :: (MultiValue.NativeFloating a ar) => Exp Bool8 -> Exp a
floatFromBool8 b = liftM MultiValue.floatFromBool8 b

-- | Smallest value of a bounded type.
minBound :: (MultiValue.Bounded a) => Exp a
minBound = lift0 MultiValue.minBound

-- | Largest value of a bounded type.
maxBound :: (MultiValue.Bounded a) => Exp a
maxBound = lift0 MultiValue.maxBound
-- | Compare two expressions using the given LLVM predicate.
cmp ::
   (MultiValue.Comparison a) =>
   LLVM.CmpPredicate -> Exp a -> Exp a -> Exp Bool
cmp ord a b = liftM2 (MultiValue.cmp ord) a b

infix 4 ==*, /=*, <*, <=*, >*, >=*

-- | Element comparisons, each a fixed-predicate instance of 'cmp'.
(==*), (/=*), (<*), (>=*), (>*), (<=*) ::
   (MultiValue.Comparison a) => Exp a -> Exp a -> Exp Bool
a ==* b = cmp LLVM.CmpEQ a b
a /=* b = cmp LLVM.CmpNE a b
a <* b = cmp LLVM.CmpLT a b
a >=* b = cmp LLVM.CmpGE a b
a >* b = cmp LLVM.CmpGT a b
a <=* b = cmp LLVM.CmpLE a b

-- | Minimum and maximum via the arithmetic 'A.min' and 'A.max'.
min, max :: (MultiValue.Real a) => Exp a -> Exp a -> Exp a
min a b = liftM2 A.min a b
max a b = liftM2 A.max a b
-- | Boolean constants.
true, false :: Exp Bool
true = lift0 (MultiValue.cons True)
false = lift0 (MultiValue.cons False)

infixr 3 &&*
-- | Logical conjunction. Both operands' code is generated.
(&&*) :: Exp Bool -> Exp Bool -> Exp Bool
a &&* b = liftM2 MultiValue.and a b

infixr 2 ||*
-- | Logical disjunction. Both operands' code is generated.
(||*) :: Exp Bool -> Exp Bool -> Exp Bool
a ||* b = liftM2 MultiValue.or a b

-- | Logical negation.
not :: Exp Bool -> Exp Bool
not a = liftM MultiValue.inv a

-- | Choose between two already-generated values by a condition, using
-- 'MultiValue.select'.
select :: (MultiValue.Select a) => Exp Bool -> Exp a -> Exp a -> Exp a
select c a b = liftM3 MultiValue.select c a b
-- | Conditional expression. Only the condition's code is generated up
-- front; the two branch generators are handed unevaluated to
-- 'C.ifThenElse'. This differs from 'select', which generates all three
-- operands. NOTE(review): presumably this means only the taken branch runs
-- at runtime; confirm with LLVM.Extra.Control.
ifThenElse :: (MultiValue.C a) => Exp Bool -> Exp a -> Exp a -> Exp a
ifThenElse ec ex ey =
   Exp (unExp ec >>= \(MultiValue.Cons c) ->
      C.ifThenElse c (unExp ex) (unExp ey))
-- | Bitwise complement.
complement :: (MultiValue.Logic a) => Exp a -> Exp a
complement x = liftM MultiValue.inv x

infixl 7 .&.*
-- | Bitwise AND.
(.&.*) :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
x .&.* y = liftM2 MultiValue.and x y

infixl 5 .|.*
-- | Bitwise OR.
(.|.*) :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
x .|.* y = liftM2 MultiValue.or x y

infixl 6 `xor`
-- | Bitwise exclusive OR.
xor :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
x `xor` y = liftM2 MultiValue.xor x y

-- | Build an optional value from a presence flag and a payload.
toMaybe :: Exp Bool -> Exp a -> Exp (Maybe a)
toMaybe c x = lift2 MultiValue.toMaybe c x
-- | Eliminate an optional expression. The flag and payload are split with
-- 'MultiValue.splitMaybe', and 'C.ifThenElse' chooses between applying @j@
-- to the payload and the default @n@.
maybe :: (MultiValue.C b) => Exp b -> (Exp a -> Exp b) -> Exp (Maybe a) -> Exp b
maybe n j =
   liftM (\m ->
      let (MultiValue.Cons b, a) = MultiValue.splitMaybe m
      in C.ifThenElse b (unliftM1 j a) (unExp n))
-- | Standard numeric operators on expressions. Each method generates code
-- for its operands and applies the corresponding 'MultiValue' operation;
-- integer literals are embedded with 'MultiValue.fromInteger''.
instance
   (MultiValue.PseudoRing a, MultiValue.Real a, MultiValue.IntegerConstant a) =>
      Num (Exp a) where
   fromInteger n = lift0 (MultiValue.fromInteger' n)
   (+) = liftM2 MultiValue.add
   (-) = liftM2 MultiValue.sub
   negate = liftM MultiValue.neg
   (*) = liftM2 MultiValue.mul
   abs = liftM MultiValue.abs
   signum = liftM MultiValue.signum
-- | Division and rational literals on expressions. Division uses
-- 'MultiValue.fdiv'; literals are embedded with 'MultiValue.fromRational''.
instance
   (MultiValue.Field a, MultiValue.Real a, MultiValue.RationalConstant a) =>
      Fractional (Exp a) where
   fromRational q = lift0 (MultiValue.fromRational' q)
   a / b = liftM2 MultiValue.fdiv a b