{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
module LLVM.Extra.Iterator where

import qualified LLVM.Extra.MaybeContinuation as MaybeCont
import qualified LLVM.Extra.Maybe as Maybe

import qualified LLVM.Extra.ArithmeticPrivate as A
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Control as C
import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (Phi, )
import LLVM.Core
   (CodeGenFunction, Value, value, valueOf,
    CmpRet, CmpResult, IsInteger, IsType, IsConst, )

import Foreign.Ptr (Ptr, )

import qualified Control.Monad.Trans.State as MS
import qualified Control.Applicative as App
import qualified Control.Functor.HT as FuncHT
import Control.Monad (void, (<=<), )
import Control.Applicative (Applicative, liftA2, (<$>), (<$), )

import Data.Tuple.HT (mapFst, mapSnd, )

import Prelude hiding (iterate, takeWhile, take, mapM)


{- |
Simulates a non-strict list whose elements are produced
inside 'CodeGenFunction'.

The internal state type @state@ is hidden;
only its 'Phi' instance is known,
so that the state can be carried through generated loops.
Given the current state, the step function either yields
the next element together with the successor state,
or ends the sequence via 'MaybeCont.nothing'.
The step function must work for every 'Phi' result type @z@
of the surrounding 'MaybeCont.T' computation.
-}
data T r a =
   forall state. (Phi state) =>
   Cons state
      (forall z. (Phi z) => state -> MaybeCont.T r z (a, state))

-- | Run an action on every element of the sequence, in order,
-- and discard the results.
-- The generated loop ends as soon as the sequence is exhausted.
mapM_ :: (a -> CodeGenFunction r ()) -> T r a -> CodeGenFunction r ()
mapM_ action (Cons start next) =
   C.loopWithExit start step return >> return ()
   where
      -- One loop iteration: 'False' stops the loop, 'True' continues it.
      step st =
         MaybeCont.resolve (next st)
            (return (valueOf False, st))
            (\(x, st') -> do
               action x
               return (valueOf True, st'))

-- | Thread an accumulator through all elements of the sequence
-- and return the final accumulator.
-- The generated loop ends as soon as the sequence is exhausted.
mapState_ ::
   (Phi t) =>
   (a -> t -> CodeGenFunction r t) ->
   T r a -> t -> CodeGenFunction r t
mapState_ f (Cons start next) acc0 =
   fmap snd $ C.loopWithExit (start, acc0) step return
   where
      -- One loop iteration over (sequence state, accumulator).
      step (st, acc) =
         MaybeCont.resolve (next st)
            (return (valueOf False, (st, acc)))
            (\(x, st') -> do
               acc' <- f x acc
               return (valueOf True, (st', acc')))

{- |
Like 'mapState_', but the per-element action is a 'MS.StateT' computation,
and so is the result.
-}
mapStateM_ ::
   (Phi t) =>
   (a -> MS.StateT t (CodeGenFunction r) ()) ->
   T r a -> MS.StateT t (CodeGenFunction r) ()
mapStateM_ f xs =
   MS.StateT $ \t0 -> do
      t1 <- mapState_ (\a -> MS.execStateT (f a)) xs t0
      return ((), t1)


{- |
Like 'mapState_', but the per-element action also returns a flag.
'False' stops the loop after the current element.
The accumulator returned by that final call is kept.
-}
mapWhileState_ ::
   (Phi t) =>
   (a -> t -> CodeGenFunction r (Value Bool, t)) ->
   T r a -> t -> CodeGenFunction r t
mapWhileState_ f (Cons start next) acc0 =
   fmap snd $ C.loopWithExit (start, acc0) step return
   where
      -- One loop iteration over (sequence state, accumulator).
      step (st, acc) =
         MaybeCont.resolve (next st)
            (return (valueOf False, (st, acc)))
            (\(x, st') -> do
               (goOn, acc') <- f x acc
               return (goOn, (st', acc')))


-- | The sequence without elements.
empty :: T r a
empty = Cons () (const MaybeCont.nothing)

-- | A sequence with exactly one element.
-- The state is a flag that stays 'True' until the element has been delivered.
singleton :: a -> T r a
singleton a =
   Cons (valueOf True)
      (\fresh -> do
         MaybeCont.guard fresh
         return (a, valueOf False))


-- | Apply a pure function to every element; the state is left untouched.
instance Functor (T r) where
   fmap f (Cons s0 step) =
      Cons s0 (\s -> fmap (\(x, s') -> (f x, s')) (step s))

{- |
@ZipList@ semantics:
elements of both operands are combined pairwise,
and the result ends as soon as one operand ends.
'pure' repeats its value infinitely.

The step of the left operand runs first.
If the left operand ends, the step of the right operand is not run.
-}
instance Applicative (T r) where
   pure x = Cons () (\() -> return (x, ()))
   Cons fState fStep <*> Cons xState xStep =
      Cons (fState, xState)
         (\(fs, xs) ->
            fStep fs >>= \(g, fs') ->
            xStep xs >>= \(x, xs') ->
            return (g x, (fs', xs')))


{- |
Apply a code-generating action to every element.
The action runs as part of each step when the sequence is consumed.

The name @map@ was avoided because this function differs from 'fmap'.
The name @mapM@ does not fit perfectly either,
because the result is not wrapped in the 'CodeGenFunction' monad.
-}
mapM :: (a -> CodeGenFunction r b) -> T r a -> T r b
mapM f (Cons s next) =
   Cons s
      (\s0 -> do
         (a, s1) <- next s0
         b <- MaybeCont.lift (f a)
         return (b, s1))

-- | Apply an action to every element and keep only the present results.
mapMaybe ::
   (Phi b, Class.Undefined b) =>
   (a -> CodeGenFunction r (Maybe.T b)) -> T r a -> T r b
mapMaybe f xs = catMaybes (mapM f xs)

{- |
Drop the absent entries of a sequence of 'Maybe.T' values
and unwrap the present ones.

Each step of the result runs an inner loop.
The loop keeps pulling from the source until it finds a present value
or the source is exhausted.
-}
catMaybes :: (Phi a, Class.Undefined a) => T r (Maybe.T a) -> T r a
catMaybes (Cons start next) =
   Cons start
      (\st0 ->
         MaybeCont.fromMaybe
            (fmap
               (\(mx, st) -> fmap (\x -> (x, st)) mx)
               (C.loopWithExit st0 search (return . snd))))
   where
      -- One iteration of the inner loop:
      -- continue ('True') on an absent value,
      -- stop ('False') on a present value or at the end of the source.
      search st =
         MaybeCont.resolve (next st)
            (return (valueOf False, (Maybe.nothing, st)))
            (\(mx, st') ->
               Maybe.run mx
                  (return (valueOf True, (Maybe.nothing, st')))
                  (\x -> return (valueOf False, (Maybe.just x, st'))))

-- | Deliver elements as long as the predicate holds.
-- The sequence ends at the first element for which the predicate is 'False'.
takeWhile :: (a -> CodeGenFunction r (Value Bool)) -> T r a -> T r a
takeWhile p (Cons s next) =
   Cons s
      (\s0 ->
         next s0 >>= \elt@(a, _) ->
         MaybeCont.lift (p a) >>= MaybeCont.guard >> return elt)

{- |
The infinite sequence @a, f a, f (f a), ...@.

Attention:
Each step calls @f@ on the current value before yielding that value.
So producing @n@ elements calls @f@ @n@ times,
although @n-1@ calls would suffice.
If @f@ reads from or writes to memory,
make sure that this one additional access is legal.
-}
iterate :: (Phi a) => (a -> CodeGenFunction r a) -> a -> T r a
iterate f a =
   Cons a
      (\cur -> do
         nxt <- MaybeCont.lift (f cur)
         return (cur, nxt))

{- |
Run a 'MaybeCont.T' computation and collect its outcome in a 'Maybe.T'.

This is like @MaybeCont.toMaybe'@.
Instead of requiring an @('Class.Undefined' a)@ constraint,
the caller supplies the placeholder value stored in the absent case.
This way, 'T' does not need an 'Class.Undefined' constraint.
On the other hand, an LLVM-undefined value would enable more LLVM optimizations.
-}
maybeFromCont ::
   a -> MaybeCont.T r (Maybe.T a) a -> CodeGenFunction r (Maybe.T a)
maybeFromCont undef (MaybeCont.Cons run) =
   run absent (return . Maybe.just)
   where
      absent = return (Maybe.Cons (valueOf False) undef)

{- |
Building block for the cartesian product.

For every element @a@ of the first sequence, it emits @Maybe.just (a,b)@
for every element @b@ of the second sequence.
When the second sequence is exhausted, it emits one @Maybe.nothing@.
The second sequence is then restarted from its initial state.
The result ends when the first sequence is exhausted.
-}
cartesianAux ::
   (Phi a, Phi b, Class.Undefined a, Class.Undefined b) =>
   T r a -> T r b -> T r (Maybe.T (a,b))
cartesianAux (Cons sa nextA) (Cons sb nextB) =
   -- State: (pending element of the first sequence, if any;
   --         state of the first sequence; state of the second sequence)
   Cons (Maybe.nothing,sa,sb)
      (\(ma0,sa0,sb0) -> do
         -- Reuse the pending element of the first sequence,
         -- or fetch a new one.
         -- If the first sequence is exhausted, the whole result ends here.
         -- (Class.undefTuple,sa0) is only the placeholder for the absent case.
         (a1,sa1) <-
            MaybeCont.fromMaybe $
            Maybe.run ma0
               (maybeFromCont (Class.undefTuple,sa0) $ nextA sa0)
               (\a0 -> return (Maybe.just (a0,sa0)))
         MaybeCont.lift $
            MaybeCont.resolve (nextB sb0)
               -- Second sequence exhausted: emit no pair,
               -- drop the pending element,
               -- and restart the second sequence from its initial state sb.
               (return (Maybe.nothing,(Maybe.nothing,sa1,sb)))
               -- Emit the pair and keep the current element of the first
               -- sequence pending for the next step.
               (\(b1,sb1) ->
                  return (Maybe.just (a1,b1), (Maybe.just a1, sa1, sb1))))


-- * helper functions

-- | All pairs @(a,b)@, with @a@ from the first and @b@ from the second sequence.
-- The first component varies slowest.
-- The second sequence is restarted for every element of the first one.
cartesian ::
   (Phi a, Phi b, Class.Undefined a, Class.Undefined b) =>
   T r a -> T r b -> T r (a,b)
cartesian xs ys = catMaybes (cartesianAux xs ys)

{- |
The sequence @len, len-1, ..., 1@.
It stops at the first value that is not greater than zero,
so it is empty if @len@ is not positive.
-}
countDown ::
   (Num i, IsConst i, IsInteger i, CmpRet i, CmpResult i ~ Bool) =>
   Value i -> T r (Value i)
countDown len =
   takeWhile
      (\k -> A.cmp LLVM.CmpLT (value LLVM.zero) k)
      (iterate A.dec len)

{- |
Truncate a sequence to at most @len@ elements.

The counter is consulted before the sequence itself.
The 'Applicative' instance runs the step of its left operand first,
and skips the right operand once the left one has ended.
Therefore, after @len@ elements have been delivered,
the step of @xs@ is not executed again.
This matters if the steps of @xs@ access memory,
e.g. if @xs@ loads through pointers obtained from 'arrayPtrs'.
-}
take ::
   (Num i, IsConst i, IsInteger i, CmpRet i, CmpResult i ~ Bool) =>
   Value i -> T r a -> T r a
take len xs = liftA2 (flip const) (countDown len) xs

-- | Infinite sequence of pointers, starting at the given one.
-- Each pointer is obtained from its predecessor by 'A.advanceArrayElementPtr'.
arrayPtrs :: (IsType a) => Value (Ptr a) -> T r (Value (Ptr a))
arrayPtrs ptr = iterate A.advanceArrayElementPtr ptr


-- * examples

-- | Run @loopBody@ @len@ times, threading a state through the iterations.
-- The body does not run at all if @len@ is not positive.
fixedLengthLoop ::
   (Phi s,
    Num i, IsConst i, IsInteger i, CmpRet i, CmpResult i ~ Bool) =>
   Value i -> s ->
   (s -> CodeGenFunction r s) ->
   CodeGenFunction r s
fixedLengthLoop len start loopBody =
   mapState_ (\_ -> loopBody) (countDown len) start

-- | Call @loopBody@ with the pointers to the first @len@ array elements
-- starting at @ptr@, threading a state through the calls.
arrayLoop ::
   (Phi a, IsType b,
    Num i, IsConst i, IsInteger i, CmpRet i, CmpResult i ~ Bool) =>
   Value i -> Value (Ptr b) -> a ->
   (Value (Ptr b) -> a -> CodeGenFunction r a) ->
   CodeGenFunction r a
arrayLoop len ptr start loopBody =
   let ptrs = take len (arrayPtrs ptr)
   in  mapState_ loopBody ptrs start

{- |
Call @loopBody@ with the pointers to the first @len@ array elements
starting at @ptr0@. Stop early when @loopBody@ returns 'False'
as first result component.

The returned position is the index of the element
for which @loopBody@ requested the exit,
or @len@ if the whole array was processed
(0 if @len@ is not positive).
The returned state is the one produced by the last executed @loopBody@ call,
or @start@ if the body never ran.
-}
arrayLoopWithExit ::
   (Phi s, IsType a,
    Num i, IsConst i, IsInteger i, CmpRet i, CmpResult i ~ Bool) =>
   Value i -> Value (Ptr a) -> s ->
   (Value (Ptr a) -> s -> CodeGenFunction r (Value Bool, s)) ->
   CodeGenFunction r (Value i, s)
arrayLoopWithExit len ptr0 start loopBody = do
   (remaining, end) <-
      mapWhileState_
         (\(i,ptr) (_remaining,s) -> do
            (proceed, s1) <- loopBody ptr s
            -- Count the element as consumed only if the body continues.
            -- Then, after an early exit, the counter still refers to the
            -- exiting element, and after a complete run it is zero.
            -- Storing 'i' unconditionally would leave 1 after a complete run,
            -- which is indistinguishable from an exit at the last element.
            remaining1 <- C.ifThen proceed i (A.dec i)
            return (proceed, (remaining1, s1)))
         (liftA2 (,) (countDown len) (arrayPtrs ptr0))
         (len,start)
   pos <- A.sub len remaining
   return (pos, end)

-- | Walk two arrays in lockstep.
-- Call @loopBody@ with the pointers to the first @len@ elements of both arrays,
-- threading a state through the calls.
arrayLoop2 ::
   (Phi s, IsType a, IsType b,
    Num i, IsConst i, IsInteger i, CmpRet i, CmpResult i ~ Bool) =>
   Value i -> Value (Ptr a) -> Value (Ptr b) -> s ->
   (Value (Ptr a) -> Value (Ptr b) -> s -> CodeGenFunction r s) ->
   CodeGenFunction r s
arrayLoop2 len ptrA ptrB start loopBody =
   mapState_ (\(pa, pb) -> loopBody pa pb)
      (take len (liftA2 (,) (arrayPtrs ptrA) (arrayPtrs ptrB)))
      start