{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Array.Knead.Expression where

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import qualified LLVM.Core as LLVM
import LLVM.Extra.Multi.Value (PatternTuple, Decomposed, Atom, )

import qualified Control.Monad.HT as Monad

import qualified Data.Tuple.HT as TupleHT
import qualified Data.Tuple as Tuple
import Data.Complex (Complex((:+)))
import Data.Bool8 (Bool8)

import Prelude
   hiding (fst, snd, min, max, zip, unzip, zip3, unzip3,
           curry, uncurry, pi, maybe)


{- |
A deferred LLVM computation producing a 'MultiValue.T' of type @a@.

The field is polymorphic in the code-generation result type @r@,
so one expression can be emitted into any 'LLVM.CodeGenFunction'.
Because the field is an action rather than a finished value,
sequencing the same 'Exp' twice emits its instructions twice.
-}
newtype Exp a =
   Exp {unExp :: forall r. LLVM.CodeGenFunction r (MultiValue.T a)}


{- |
Containers of multi-values that support lifting of pure functions
on 'MultiValue.T'.
This module provides instances for 'MultiValue.T' itself and for 'Exp'.
-}
class Value v where
   -- | Embed an already available value.
   lift0 :: MultiValue.T x -> v x
   -- | Lift a pure unary function.
   lift1 :: (MultiValue.T x -> MultiValue.T y) -> v x -> v y
   -- | Lift a pure binary function.
   lift2 ::
      (MultiValue.T x -> MultiValue.T y -> MultiValue.T z) ->
      v x -> v y -> v z
   -- | Lift a pure ternary function.
   lift3 ::
      (MultiValue.T x -> MultiValue.T y -> MultiValue.T z ->
       MultiValue.T w) ->
      v x -> v y -> v z -> v w
   -- | Lift a pure quaternary function.
   lift4 ::
      (MultiValue.T x -> MultiValue.T y -> MultiValue.T z ->
       MultiValue.T w -> MultiValue.T u) ->
      v x -> v y -> v z -> v w -> v u

{- |
On plain multi-values lifting is trivial:
a lifted function is the function itself.
-}
instance Value MultiValue.T where
   lift0 x = x
   lift1 f = f
   lift2 f = f
   lift3 f = f
   lift4 f = f

{- |
Lifting on 'Exp' runs the argument generators from left to right
and applies the pure function to their results.
'lift0' wraps a value in a generator that just returns it.
-}
instance Value Exp where
   lift0 x = Exp (return x)
   lift1 f ea = Exp (do
      a <- unExp ea
      return (f a))
   lift2 f ea eb = Exp (do
      a <- unExp ea
      b <- unExp eb
      return (f a b))
   lift3 f ea eb ec = Exp (do
      a <- unExp ea
      b <- unExp eb
      c <- unExp ec
      return (f a b c))
   lift4 f ea eb ec ed = Exp (do
      a <- unExp ea
      b <- unExp eb
      c <- unExp ec
      d <- unExp ed
      return (f a b c d))


{- |
Lift a code-generating unary function to expressions.
The argument is generated first, then @f@ is run on its result.
-}
liftM ::
   (forall r.
    MultiValue.T a ->
    LLVM.CodeGenFunction r (MultiValue.T b)) ->
   (Exp a -> Exp b)
liftM f ea = Exp (unExp ea >>= f)

{- |
Binary version of 'liftM'.
Both arguments are generated from left to right before @f@ runs.
-}
liftM2 ::
   (forall r.
    MultiValue.T a -> MultiValue.T b ->
    LLVM.CodeGenFunction r (MultiValue.T c)) ->
   (Exp a -> Exp b -> Exp c)
liftM2 f ea eb = Exp (do
   a <- unExp ea
   b <- unExp eb
   f a b)

{- |
Ternary version of 'liftM'.
All arguments are generated from left to right before @f@ runs.
-}
liftM3 ::
   (forall r.
    MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
    LLVM.CodeGenFunction r (MultiValue.T d)) ->
   (Exp a -> Exp b -> Exp c -> Exp d)
liftM3 f ea eb ec = Exp (do
   a <- unExp ea
   b <- unExp eb
   c <- unExp ec
   f a b c)


{- |
Turn an 'Exp' function into a code generator on 'MultiValue.T'.
The argument is wrapped with 'lift0' and the resulting expression
is run.
-}
unliftM1 ::
   (Exp a -> Exp b) ->
   MultiValue.T a -> LLVM.CodeGenFunction r (MultiValue.T b)
unliftM1 f x =
   case f (lift0 x) of
      Exp code -> code

-- | Binary version of 'unliftM1'.
unliftM2 ::
   (Exp a -> Exp b -> Exp c) ->
   MultiValue.T a -> MultiValue.T b ->
   LLVM.CodeGenFunction r (MultiValue.T c)
unliftM2 f x y =
   case f (lift0 x) (lift0 y) of
      Exp code -> code

-- | Ternary version of 'unliftM1'.
unliftM3 ::
   (Exp a -> Exp b -> Exp c -> Exp d) ->
   MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
   LLVM.CodeGenFunction r (MultiValue.T d)
unliftM3 f x y z =
   case f (lift0 x) (lift0 y) (lift0 z) of
      Exp code -> code



-- | Combine two values into a pair.
zip :: (Value val) => val a -> val b -> val (a, b)
zip x y = lift2 MultiValue.zip x y

-- | Combine three values into a triple.
zip3 :: (Value val) => val a -> val b -> val c -> val (a, b, c)
zip3 x y z = lift3 MultiValue.zip3 x y z

-- | Combine four values into a quadruple.
zip4 :: (Value val) => val a -> val b -> val c -> val d -> val (a, b, c, d)
zip4 w x y z = lift4 MultiValue.zip4 w x y z

-- | Split a pair into its two components.
unzip :: (Value val) => val (a, b) -> (val a, val b)
unzip ab = (x1, x2)
   where
      x1 = lift1 MultiValue.fst ab
      x2 = lift1 MultiValue.snd ab

-- | Split a triple into its three components.
unzip3 :: (Value val) => val (a, b, c) -> (val a, val b, val c)
unzip3 abc = (x1, x2, x3)
   where
      x1 = lift1 MultiValue.fst3 abc
      x2 = lift1 MultiValue.snd3 abc
      x3 = lift1 MultiValue.thd3 abc

-- | Split a quadruple into its four components.
unzip4 :: (Value val) => val (a, b, c, d) -> (val a, val b, val c, val d)
unzip4 abcd = (x1, x2, x3, x4)
   where
      x1 = lift1 (\(MultiValue.Cons (v,_,_,_)) -> MultiValue.Cons v) abcd
      x2 = lift1 (\(MultiValue.Cons (_,v,_,_)) -> MultiValue.Cons v) abcd
      x3 = lift1 (\(MultiValue.Cons (_,_,v,_)) -> MultiValue.Cons v) abcd
      x4 = lift1 (\(MultiValue.Cons (_,_,_,v)) -> MultiValue.Cons v) abcd


-- | First component of a pair.
fst :: (Value val) => val (a, b) -> val a
fst ab = lift1 MultiValue.fst ab

-- | Second component of a pair.
snd :: (Value val) => val (a, b) -> val b
snd ab = lift1 MultiValue.snd ab

-- | Transform the first component of a pair expression.
mapFst :: (Exp a -> Exp b) -> Exp (a, c) -> Exp (b, c)
mapFst f = liftM (\pair -> MultiValue.mapFstF (unliftM1 f) pair)

-- | Transform the second component of a pair expression.
mapSnd :: (Exp b -> Exp c) -> Exp (a, b) -> Exp (a, c)
mapSnd f = liftM (\pair -> MultiValue.mapSndF (unliftM1 f) pair)

-- | Exchange the components of a pair.
swap :: (Value val) => val (a, b) -> val (b, a)
swap ab = lift1 MultiValue.swap ab

-- | Turn a function on pair expressions into a curried function.
curry :: (Exp (a,b) -> c) -> (Exp a -> Exp b -> c)
curry f a b = f (zip a b)

-- | Turn a curried function into one on pair expressions.
uncurry :: (Exp a -> Exp b -> c) -> (Exp (a,b) -> c)
uncurry f ab =
   let (a, b) = unzip ab
   in  f a b


-- | First component of a triple.
fst3 :: (Value val) => val (a,b,c) -> val a
fst3 abc = lift1 MultiValue.fst3 abc

-- | Second component of a triple.
snd3 :: (Value val) => val (a,b,c) -> val b
snd3 abc = lift1 MultiValue.snd3 abc

-- | Third component of a triple.
thd3 :: (Value val) => val (a,b,c) -> val c
thd3 abc = lift1 MultiValue.thd3 abc

-- | Transform the first component of a triple expression.
mapFst3 :: (Exp a0 -> Exp a1) -> Exp (a0,b,c) -> Exp (a1,b,c)
mapFst3 f = liftM (\triple -> MultiValue.mapFst3F (unliftM1 f) triple)

-- | Transform the second component of a triple expression.
mapSnd3 :: (Exp b0 -> Exp b1) -> Exp (a,b0,c) -> Exp (a,b1,c)
mapSnd3 f = liftM (\triple -> MultiValue.mapSnd3F (unliftM1 f) triple)

-- | Transform the third component of a triple expression.
mapThd3 :: (Exp c0 -> Exp c1) -> Exp (a,b,c0) -> Exp (a,b,c1)
mapThd3 f = liftM (\triple -> MultiValue.mapThd3F (unliftM1 f) triple)


{- |
Apply 'MultiValue.modify' inside a 'Value' container:
the tuple is decomposed according to the pattern,
transformed by the pure function, and recomposed.
-}
modifyMultiValue ::
   (Value val,
    MultiValue.Compose a,
    MultiValue.Decompose pattern,
    MultiValue.PatternTuple pattern ~ tuple) =>
   pattern ->
   (Decomposed MultiValue.T pattern -> a) ->
   val tuple -> val (MultiValue.Composed a)
modifyMultiValue p f x = lift1 (MultiValue.modify p f) x

-- | Two-argument version of 'modifyMultiValue', via 'MultiValue.modify2'.
modifyMultiValue2 ::
   (Value val,
    MultiValue.Compose a,
    MultiValue.Decompose patternA,
    MultiValue.Decompose patternB,
    MultiValue.PatternTuple patternA ~ tupleA,
    MultiValue.PatternTuple patternB ~ tupleB) =>
   patternA ->
   patternB ->
   (Decomposed MultiValue.T patternA ->
    Decomposed MultiValue.T patternB -> a) ->
   val tupleA -> val tupleB -> val (MultiValue.Composed a)
modifyMultiValue2 pa pb f x y = lift2 (MultiValue.modify2 pa pb f) x y

{- |
Like 'modifyMultiValue', but the transformation may generate code.
Implemented with 'MultiValue.modifyF' on expressions.
-}
modifyMultiValueM ::
   (MultiValue.Compose a,
    MultiValue.Decompose pattern,
    MultiValue.PatternTuple pattern ~ tuple) =>
   pattern ->
   (forall r.
    Decomposed MultiValue.T pattern ->
    LLVM.CodeGenFunction r a) ->
   Exp tuple -> Exp (MultiValue.Composed a)
modifyMultiValueM p f x = liftM (\v -> MultiValue.modifyF p f v) x

-- | Two-argument version of 'modifyMultiValueM', via 'MultiValue.modifyF2'.
modifyMultiValueM2 ::
   (MultiValue.Compose a,
    MultiValue.Decompose patternA,
    MultiValue.Decompose patternB,
    MultiValue.PatternTuple patternA ~ tupleA,
    MultiValue.PatternTuple patternB ~ tupleB) =>
   patternA ->
   patternB ->
   (forall r.
    Decomposed MultiValue.T patternA ->
    Decomposed MultiValue.T patternB ->
    LLVM.CodeGenFunction r a) ->
   Exp tupleA -> Exp tupleB -> Exp (MultiValue.Composed a)
modifyMultiValueM2 pa pb f x y =
   liftM2 (\v w -> MultiValue.modifyF2 pa pb f v w) x y


{- |
Structures of 'Exp' values (tuples, complex numbers, ...)
that can be merged into a single 'Exp' of the corresponding value type.
-}
class Compose tuple where
   -- | The value type of the merged expression.
   type Composed tuple
   {- |
   A nested 'zip'.
   -}
   compose :: tuple -> Exp (Composed tuple)

{- |
Patterns that describe how to split an 'Exp' of a tuple type
into a structure of component expressions.
The superclass constraint requires that composing the decomposed
structure yields the pattern's tuple type again.
-}
class
   (Composed (Decomposed Exp p) ~ PatternTuple p) =>
      Decompose p where
   {- |
   Analogous to 'MultiValue.decompose'.
   -}
   decompose :: p -> Exp (PatternTuple p) -> Decomposed Exp p


{- |
Analogous to 'MultiValue.modify':
split an expression according to a pattern with 'decompose',
transform the parts with @f@,
and reassemble the result with 'compose'.
-}
modify ::
   (Compose a, Decompose pattern) =>
   pattern ->
   (Decomposed Exp pattern -> a) ->
   Exp (PatternTuple pattern) -> Exp (Composed a)
modify p f = compose . f . decompose p

{- |
Two-argument version of 'modify':
each expression is decomposed by its own pattern,
the parts are combined by @f@,
and the result is reassembled with 'compose'.
-}
modify2 ::
   (Compose a, Decompose patternA, Decompose patternB) =>
   patternA ->
   patternB ->
   (Decomposed Exp patternA -> Decomposed Exp patternB -> a) ->
   Exp (PatternTuple patternA) ->
   Exp (PatternTuple patternB) -> Exp (Composed a)
modify2 pa pb f a b =
   let partsA = decompose pa a
       partsB = decompose pb b
   in  compose (f partsA partsB)



-- | A single expression is already composed.
instance Compose (Exp a) where
   type Composed (Exp a) = a
   compose e = e

-- | An atom pattern leaves the expression untouched.
instance Decompose (Atom a) where
   decompose _atom e = e



-- | The empty tuple composes to the constant unit expression.
instance Compose () where
   type Composed () = ()
   compose u = cons u

-- | The unit pattern decomposes to the empty tuple,
-- independent of the expression.
instance Decompose () where
   decompose _pattern _expr = ()


-- | A pair of composable parts becomes a zipped pair expression.
instance (Compose a, Compose b) => Compose (a,b) where
   type Composed (a,b) = (Composed a, Composed b)
   compose ~(a, b) = zip (compose a) (compose b)

-- | A pair pattern splits a pair expression with 'unzip'
-- and decomposes each half with its own sub-pattern.
instance (Decompose pa, Decompose pb) => Decompose (pa,pb) where
   decompose (pa,pb) = split . unzip
      where split ~(a, b) = (decompose pa a, decompose pb b)


-- | A triple of composable parts becomes a zipped triple expression.
instance (Compose a, Compose b, Compose c) => Compose (a,b,c) where
   type Composed (a,b,c) = (Composed a, Composed b, Composed c)
   compose ~(a, b, c) = zip3 (compose a) (compose b) (compose c)

-- | A triple pattern splits a triple expression with 'unzip3'
-- and decomposes each component with its own sub-pattern.
instance
   (Decompose pa, Decompose pb, Decompose pc) =>
      Decompose (pa,pb,pc) where
   decompose (pa,pb,pc) = split . unzip3
      where
         split ~(a, b, c) =
            (decompose pa a, decompose pb b, decompose pc c)


-- | A quadruple of composable parts becomes a zipped quadruple expression.
instance (Compose a, Compose b, Compose c, Compose d) => Compose (a,b,c,d) where
   type Composed (a,b,c,d) = (Composed a, Composed b, Composed c, Composed d)
   compose (a,b,c,d) =
      lift4 MultiValue.zip4 (compose a) (compose b) (compose c) (compose d)

-- | A quadruple pattern splits a quadruple expression with 'unzip4'
-- and decomposes each component with its own sub-pattern.
instance
   (Decompose pa, Decompose pb, Decompose pc, Decompose pd) =>
      Decompose (pa,pb,pc,pd) where
   decompose (pa,pb,pc,pd) x =
      let (a, b, c, d) = unzip4 x
      in  (decompose pa a, decompose pb b, decompose pc c, decompose pd d)


-- | A complex number of composable parts becomes a complex expression.
instance (Compose a) => Compose (Complex a) where
   type Composed (Complex a) = Complex (Composed a)
   compose (re:+im) = consComplex (compose re) (compose im)

-- | A complex pattern splits an expression into real and imaginary part
-- and decomposes each part with its own sub-pattern.
instance (Decompose p) => Decompose (Complex p) where
   decompose (pre:+pim) = assemble . deconsComplex
      where assemble ~(re, im) = decompose pre re :+ decompose pim im

{- |
Build a complex expression from a real and an imaginary part.

Such values are of limited use with the standard numeric classes:
the numeric operations on 'Complex' require a 'RealFloat' instance,
which could only be provided with many undefined methods
(also in its superclasses).
Either define your own arithmetic
or use the NumericPrelude type classes.
-}
consComplex :: Exp a -> Exp a -> Exp (Complex a)
consComplex re im = lift2 MultiValue.consComplex re im

-- | Split a complex expression into its real and imaginary part.
deconsComplex :: Exp (Complex a) -> (Exp a, Exp a)
deconsComplex c = (re, im)
   where
      re = lift1 MultiValue.realPart c
      im = lift1 MultiValue.imagPart c



-- | Embed a Haskell constant as an expression.
cons :: (MultiValue.C a) => a -> Exp a
cons x = lift0 (MultiValue.cons x)

-- | The unit expression.
unit :: Exp ()
unit = cons ()

-- | The 'MultiValue.zero' value of a type.
zero :: (MultiValue.C a) => Exp a
zero = lift0 MultiValue.zero

-- | Addition.
add :: (MultiValue.Additive a) => Exp a -> Exp a -> Exp a
add x y = liftM2 MultiValue.add x y

-- | Subtraction.
sub :: (MultiValue.Additive a) => Exp a -> Exp a -> Exp a
sub x y = liftM2 MultiValue.sub x y

-- | Multiplication.
mul :: (MultiValue.PseudoRing a) => Exp a -> Exp a -> Exp a
mul x y = liftM2 MultiValue.mul x y

-- | Square; the argument is generated once and multiplied by itself.
sqr :: (MultiValue.PseudoRing a) => Exp a -> Exp a
sqr e = liftM (\x -> MultiValue.mul x x) e

-- | Square root, via 'MultiValue.sqrt'.
sqrt :: (MultiValue.Algebraic a) => Exp a -> Exp a
sqrt x = liftM MultiValue.sqrt x

-- | Integer division, via 'MultiValue.idiv'.
idiv :: (MultiValue.Integral a) => Exp a -> Exp a -> Exp a
idiv x y = liftM2 MultiValue.idiv x y

-- | Integer remainder, via 'MultiValue.irem'.
irem :: (MultiValue.Integral a) => Exp a -> Exp a -> Exp a
irem x y = liftM2 MultiValue.irem x y

-- | Shift the first operand left by the second, via 'MultiValue.shl'.
shl :: (MultiValue.BitShift a) => Exp a -> Exp a -> Exp a
shl x y = liftM2 MultiValue.shl x y

-- | Shift the first operand right by the second, via 'MultiValue.shr'.
shr :: (MultiValue.BitShift a) => Exp a -> Exp a -> Exp a
shr x y = liftM2 MultiValue.shr x y

-- | Integer constant as an expression.
fromInteger' :: (MultiValue.IntegerConstant a) => Integer -> Exp a
fromInteger' n = lift0 (MultiValue.fromInteger' n)

-- | Rational constant as an expression.
fromRational' :: (MultiValue.RationalConstant a) => Rational -> Exp a
fromRational' q = lift0 (MultiValue.fromRational' q)


-- | Convert a 'Bool8' expression to 'Bool'.
boolPFrom8 :: Exp Bool8 -> Exp Bool
boolPFrom8 b = lift1 MultiValue.boolPFrom8 b

-- | Convert a 'Bool' expression to 'Bool8'.
bool8FromP :: Exp Bool -> Exp Bool8
bool8FromP b = lift1 MultiValue.bool8FromP b

-- | Convert a 'Bool8' expression to a native integer type.
intFromBool8 :: (MultiValue.NativeInteger i ir) => Exp Bool8 -> Exp i
intFromBool8 b = liftM MultiValue.intFromBool8 b

-- | Convert a 'Bool8' expression to a native floating-point type.
floatFromBool8 :: (MultiValue.NativeFloating a ar) => Exp Bool8 -> Exp a
floatFromBool8 b = liftM MultiValue.floatFromBool8 b


-- | Smallest value of a bounded type.
minBound :: (MultiValue.Bounded a) => Exp a
minBound = lift0 MultiValue.minBound

-- | Largest value of a bounded type.
maxBound :: (MultiValue.Bounded a) => Exp a
maxBound = lift0 MultiValue.maxBound


-- | Compare two expressions with the given LLVM predicate.
cmp ::
   (MultiValue.Comparison a) =>
   LLVM.CmpPredicate -> Exp a -> Exp a -> Exp Bool
cmp predicate x y = liftM2 (MultiValue.cmp predicate) x y

infix 4 ==*, /=*, <*, <=*, >*, >=*

-- | Comparison operators on expressions, each defined via 'cmp'.
(==*), (/=*), (<*), (>=*), (>*), (<=*) ::
   (MultiValue.Comparison a) => Exp a -> Exp a -> Exp Bool
x ==* y = cmp LLVM.CmpEQ x y
x /=* y = cmp LLVM.CmpNE x y
x <*  y = cmp LLVM.CmpLT x y
x >=* y = cmp LLVM.CmpGE x y
x >*  y = cmp LLVM.CmpGT x y
x <=* y = cmp LLVM.CmpLE x y


-- | Smaller of two expressions, via 'A.min'.
min :: (MultiValue.Real a) => Exp a -> Exp a -> Exp a
min x y = liftM2 A.min x y

-- | Larger of two expressions, via 'A.max'.
max :: (MultiValue.Real a) => Exp a -> Exp a -> Exp a
max x y = liftM2 A.max x y


-- | Boolean constant 'True'.
true :: Exp Bool
true = cons True

-- | Boolean constant 'False'.
false :: Exp Bool
false = cons False

infixr 3 &&*
-- | Logical conjunction; both operands are always generated.
(&&*) :: Exp Bool -> Exp Bool -> Exp Bool
x &&* y = liftM2 MultiValue.and x y

infixr 2 ||*
-- | Logical disjunction; both operands are always generated.
(||*) :: Exp Bool -> Exp Bool -> Exp Bool
x ||* y = liftM2 MultiValue.or x y

-- | Logical negation.
not :: Exp Bool -> Exp Bool
not x = liftM MultiValue.inv x

{- |
Choose between two expressions depending on a condition.
Unlike 'ifThenElse', both alternatives are always generated,
and the result is picked with LLVM's @select@ instruction.
-}
select :: (MultiValue.Select a) => Exp Bool -> Exp a -> Exp a -> Exp a
select c x y = liftM3 MultiValue.select c x y

{- |
Choose between two expressions depending on a condition.
The condition is generated first;
each alternative is generated inside its own branch of 'C.ifThenElse'.
-}
ifThenElse :: (MultiValue.C a) => Exp Bool -> Exp a -> Exp a -> Exp a
ifThenElse ec ex ey =
   Exp (unExp ec >>= \(MultiValue.Cons c) ->
        C.ifThenElse c (unExp ex) (unExp ey))


-- | Bitwise complement.
complement :: (MultiValue.Logic a) => Exp a -> Exp a
complement x = liftM MultiValue.inv x

infixl 7 .&.*
-- | Bitwise conjunction.
(.&.*) :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
x .&.* y = liftM2 MultiValue.and x y

infixl 5 .|.*
-- | Bitwise disjunction.
(.|.*) :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
x .|.* y = liftM2 MultiValue.or x y

infixl 6 `xor`
-- | Bitwise exclusive or.
xor :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
x `xor` y = liftM2 MultiValue.xor x y


-- | Combine a flag and a payload into an optional value,
-- via 'MultiValue.toMaybe'.
toMaybe :: Exp Bool -> Exp a -> Exp (Maybe a)
toMaybe flag x = lift2 MultiValue.toMaybe flag x

{- |
Eliminate an optional expression:
the optional value is split into flag and payload with
'MultiValue.splitMaybe';
if the flag is set, @j@ is applied to the payload, otherwise @n@ is used.
The two alternatives are generated in separate branches of 'C.ifThenElse'.
-}
maybe :: (MultiValue.C b) => Exp b -> (Exp a -> Exp b) -> Exp (Maybe a) -> Exp b
maybe n j =
   liftM $ \optional ->
      let (MultiValue.Cons present, payload) = MultiValue.splitMaybe optional
      in  C.ifThenElse present (unliftM1 j payload) (unExp n)


{- |
Standard 'Num' interface for expressions,
so that numeric literals and @+@, @-@, @*@ work directly.
Literals and the ring operations reuse 'fromInteger'', 'add', 'sub' and 'mul';
'negate', 'abs' and 'signum' delegate to the corresponding
'MultiValue' functions.
-}
instance
   (MultiValue.PseudoRing a, MultiValue.Real a, MultiValue.IntegerConstant a) =>
      Num (Exp a) where
   fromInteger = fromInteger'
   (+) = add
   (-) = sub
   (*) = mul
   negate x = liftM MultiValue.neg x
   abs x = liftM MultiValue.abs x
   signum x = liftM MultiValue.signum x

{- |
Standard 'Fractional' interface for expressions:
rational literals reuse 'fromRational'', division uses 'MultiValue.fdiv'.
-}
instance
   (MultiValue.Field a, MultiValue.Real a, MultiValue.RationalConstant a) =>
      Fractional (Exp a) where
   fromRational = fromRational'
   x / y = liftM2 MultiValue.fdiv x y