module Data.Array.Knead.Simple.Private where
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp(Exp), )
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.Maybe as Maybe
import qualified LLVM.Core as LLVM
import qualified Control.Category as Cat
import qualified Control.Monad.HT as Monad
import Control.Monad ((<=<), )
import Prelude hiding (id, map, zipWith, replicate, )
-- | Shorthand for an LLVM multi-value ('MultiValue.T') of a Haskell type.
-- Written eta-reduced, so 'Val' can also be used unapplied.
type Val = MultiValue.T
-- | A code generator, running in LLVM function context @r@,
-- that yields a 'Val' of type @a@.
type Code r a = LLVM.CodeGenFunction r (Val a)
-- | An array represented by its extent together with a function that emits
-- the code computing the element at a given index.
--
-- No element storage is involved: elements are produced only when the
-- generator is invoked.
data Array sh a =
   Array
      (Exp sh)
      -- ^ extent of the array
      (forall s. Val (Shape.Index sh) -> Code s a)
      -- ^ element generator, usable in any LLVM function context
-- | The extent expression of an array.
shape :: Array sh a -> Exp sh
shape arr =
   case arr of
      Array extent _ -> extent
-- | Look up the element at a symbolic index.
--
-- The code of the index expression is emitted first, and its result is passed
-- to the array's element generator. No bounds check is performed here.
(!) ::
   (Shape.C sh, Shape.Index sh ~ ix) =>
   Array sh a -> Exp ix -> Exp a
(!) (Array _ elementAt) (Exp indexCode) = Exp (indexCode >>= elementAt)
-- | Extract the element of an array whose shape type is a 'Shape.Scalar'
-- instance, by evaluating the element generator at 'Shape.zeroIndex'.
the :: (Shape.Scalar sh) => Array sh a -> Exp a
the (Array extent elementAt) = Exp (elementAt (Shape.zeroIndex extent))
-- | Wrap an expression as an array with the 'Shape.scalar' extent.
fromScalar :: (Shape.Scalar sh) => Exp a -> Array sh a
fromScalar value = fill Shape.scalar value
-- | An array of the given extent in which every element is the given
-- expression. The index passed to the element generator is ignored.
fill :: Exp sh -> Exp a -> Array sh a
fill extent (Exp elementCode) = Array extent (const elementCode)
-- | Array-like types into which functions on plain 'Array' values can be lifted.
--
-- The combinators in this module are written against 'Array' and use these
-- methods, so they work for every instance.
class C array where
   -- | Embed a plain 'Array'.
   lift0 :: Array sh a -> array sh a
   -- | Lift a unary 'Array' transformation.
   lift1 ::
      (Array sha a -> Array shb b) ->
      array sha a -> array shb b
   -- | Lift a binary 'Array' transformation.
   lift2 ::
      (Array sha a -> Array shb b -> Array shc c) ->
      array sha a -> array shb b -> array shc c
-- | The trivial instance: every lifting returns its argument unchanged.
instance C Array where
   lift0 arr = arr
   lift1 f = f
   lift2 f = f
-- | Read a source array at positions taken from an index array.
--
-- The result has the index array's extent. At index @i@, the element is the
-- source element at the index stored in the index array at @i@. The source
-- array's extent is not consulted, so no bounds check happens here.
gather ::
   (C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    MultiValue.C a) =>
   array sh1 ix0 ->
   array sh0 a ->
   array sh1 a
gather =
   lift2 $ \(Array targetShape sourceIndexAt) (Array _sourceShape sourceAt) ->
      Array targetShape (\ix -> sourceIndexAt ix >>= sourceAt)
-- | Combine two arrays element-wise after remapping their indices.
--
-- The result has extent @sh@. Its element at index @ix@ is
-- @f (arrA ! projectIndex0 ix) (arrB ! projectIndex1 ix)@. The code for the
-- first array's element is emitted before the code for the second's. The
-- extents of both input arrays are ignored.
backpermute2 ::
   (C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Shape.C sh, Shape.Index sh ~ ix) =>
   Exp sh ->
   (Exp ix -> Exp ix0) ->
   (Exp ix -> Exp ix1) ->
   (Exp a -> Exp b -> Exp c) ->
   array sh0 a -> array sh1 b -> array sh c
backpermute2 sh projectIndex0 projectIndex1 f =
   lift2 $ \(Array _shA elementA) (Array _shB elementB) ->
      Array sh (\ix -> do
         a <- elementA =<< Expr.unliftM1 projectIndex0 ix
         b <- elementB =<< Expr.unliftM1 projectIndex1 ix
         Expr.unliftM2 f a b)
-- | The identity array: over extent @sh@, the element at every index is that
-- index itself.
id ::
   (C array, Shape.C sh, Shape.Index sh ~ ix) =>
   Exp sh -> array sh ix
id extent = lift0 (Array extent pure)
-- | Apply a function to every element. The extent is unchanged.
map ::
   (C array, Shape.C sh) =>
   (Exp a -> Exp b) ->
   array sh a -> array sh b
map f =
   lift1 $ \(Array sh elementAt) ->
      Array sh (\ix -> elementAt ix >>= Expr.unliftM1 f)
-- | Like 'map', but the function also receives the element's index.
mapWithIndex ::
   (C array, Shape.C sh, Shape.Index sh ~ ix) =>
   (Exp ix -> Exp a -> Exp b) ->
   array sh a -> array sh b
mapWithIndex f =
   lift1 $ \(Array sh elementAt) ->
      Array sh (\ix -> do
         element <- elementAt ix
         Expr.unliftM2 f ix element)
-- | Emit code that folds the elements produced by @code@ over every index of
-- the extent @nc@, without a neutral element.
--
-- Indices are visited in the order used by 'Shape.loop'. The first element
-- visited seeds the accumulator. Every later element @x@ updates it to
-- @f acc x@.
--
-- NOTE(review): for an empty extent the accumulator stays 'Maybe.nothing',
-- and the result comes from 'Maybe.fromJust' applied to it. This is
-- presumably an undefined value; confirm callers never fold an empty extent.
fold1Code ::
   (Shape.C sh, Shape.Index sh ~ ix, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Exp sh ->
   (Val ix -> Code r a) ->
   Code r a
fold1Code f (Exp nc) code = do
   n <- nc
   result <- Shape.loop step n Maybe.nothing
   return (Maybe.fromJust result)
  where
   -- One loop iteration: fetch the element, then seed or update the accumulator.
   step ix macc = do
      a <- code ix
      acc <- Maybe.run macc (return a) (\acc0 -> Expr.unliftM2 f acc0 a)
      return (Maybe.just acc)
-- | Reduce the second component @sh1@ of a pair-shaped array with a binary
-- function that has no neutral element. The result keeps the first
-- component @sh0@ as its extent.
fold1 ::
   (C array, Shape.C sh0, Shape.C sh1, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   array (sh0, sh1) a -> array sh0 a
fold1 f =
   lift1 $ \(Array shs elementAt) ->
      case Expr.unzip shs of
         (outer, inner) ->
            Array outer
               (\ix -> fold1Code f inner (MultiValue.curry elementAt ix))
-- | Reduce all elements of an array to one expression with a binary function
-- that has no neutral element.
fold1All ::
   (Shape.C sh, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Array sh a -> Exp a
fold1All f arr =
   case arr of
      Array sh elementAt -> Exp (fold1Code f sh elementAt)
-- | Emit code that scans the elements produced by @code@ in 'Shape.iterator'
-- order. It returns 'MultiValue.just' of an element satisfying @p@, or
-- 'MultiValue.nothing' if none is found.
--
-- Each step reports the negated predicate as the continue-flag for
-- 'Iter.mapWhileState_'. The scan is therefore presumably cut off at the first
-- match, so that element is returned.
findAllCode ::
   (Shape.C sh, Shape.Index sh ~ ix, MultiValue.C a) =>
   (Exp a -> Exp Bool) ->
   Exp sh ->
   (Val ix -> Code r a) ->
   Code r (Maybe a)
findAllCode p (Exp sh) code = do
   n <- sh
   let elements = Iter.mapM code (Shape.iterator n)
       visit a _found = do
          MultiValue.Cons b <- Expr.unliftM1 p a
          notb <- LLVM.inv b
          return (notb, Maybe.fromBool b a)
   finalFound <- Iter.mapWhileState_ visit elements Maybe.nothing
   Maybe.run finalFound
      (return MultiValue.nothing)
      (\found -> return (MultiValue.just found))
-- | Search an array for an element satisfying @p@. The result is 'Just' the
-- element found, or 'Nothing' if no element satisfies @p@.
findAll ::
   (Shape.C sh, MultiValue.C a) =>
   (Exp a -> Exp Bool) ->
   Array sh a -> Exp (Maybe a)
findAll p arr =
   case arr of
      Array sh elementAt -> Exp (findAllCode p sh elementAt)
-- | A class with no methods. In this section it appears only as a constraint
-- in the type of ('$:.').
class Process proc
infixl 3 $:.

-- | Reverse function application, left-associative at precedence 3, so
-- transformations can be chained left to right:
-- @x $:. f $:. g == g (f x)@.
--
-- The 'Process' constraints are not used by the implementation; they only
-- restrict which types may take part in a chain.
($:.) :: (Process proc0, Process proc1) => proc0 -> (proc0 -> proc1) -> proc1
x $:. f = f x