module Data.Array.Knead.Parameterized.Physical (
Phys.Array,
Phys.shape,
Phys.fromList,
feed,
the,
render,
renderShape,
mapAccumL,
foldOuterL,
scatter,
scatterMaybe,
permute,
) where
import qualified Data.Array.Knead.Parameterized.PhysicalHull as PhysHull
import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Private as Core
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Code as Code
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Code (getElementPtr, compile, )
import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Core as LLVM
import Foreign.Marshal.Utils (with, )
import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (Storable, peek, )
import Foreign.ForeignPtr (withForeignPtr, touchForeignPtr, )
import Foreign.Ptr (FunPtr, Ptr, )
import Control.Exception (bracket, )
import Control.Monad.HT ((<=<), )
import Control.Applicative (liftA2, )
import Data.Tuple.HT (mapFst, )
import Data.Word (Word32, )
-- | Turn a parameter-dependent physical array into a symbolic array.
--
-- Code side: the packed parameter is split into the array shape and a
-- pointer to the element buffer, and each element is loaded through
-- 'getElementPtr'.
--
-- Haskell side: the buffer's 'ForeignPtr' is returned as the context and
-- the release action is 'touchForeignPtr'.  The raw pointer leaves the
-- 'withForeignPtr' scope, so the context is what keeps the buffer alive.
-- NOTE(review): this relies on the release action running only after the
-- compiled code is done with the pointer, as 'bracket' arranges in 'the'
-- and 'renderShape'.
feed ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    MultiValueMemory.C a) =>
   Param.T p (Phys.Array sh a) -> Sym.Array p sh a
feed arr =
   Param.withMulti (fmap Phys.shape arr) $ \getShape valueShape ->
   Sym.Array
      (\multiParam ->
         case mapFst valueShape (MultiValue.unzip multiParam) of
            (sh, MultiValue.Cons basePtr) ->
               Core.Array (Expr.lift0 sh)
                  (Memory.load <=< getElementPtr sh basePtr))
      (\p ->
         let fptr = Phys.buffer (Param.get arr p)
         in  withForeignPtr fptr $ \rawPtr ->
                return
                   (fptr,
                    (getShape p, MultiValueMemory.castStructPtr rawPtr)))
      touchForeignPtr
-- | Signature of a @foreign import ccall \"dynamic\"@ stub: it converts a
-- 'FunPtr' (e.g. to freshly compiled code) into a callable Haskell function.
type Importer f = FunPtr f -> f
-- | Dynamic-call stub for the function compiled by 'the'.
-- Its arguments are a pointer to the packed parameter and a pointer to the
-- slot into which the compiled code stores the single element.
foreign import ccall safe "dynamic" callThe ::
   Importer (Ptr param -> Ptr am -> IO ())
-- | Compile a symbolic array of scalar shape into an evaluator for its
-- single element.
--
-- The LLVM code is compiled once, when the outer 'IO' action runs.  Each call
-- of the returned function does the following:
--
-- * creates the parameter context,
-- * runs the compiled code with a temporary result slot,
-- * reads the element back with 'peek',
-- * releases the context via 'bracket'.
the ::
   (Shape.Scalar z, MultiValueMemory.C a, Storable a) =>
   Sym.Array p z a -> IO (p -> IO a)
the (Sym.Array arr create delete) = do
   evalElement <-
      compile "the" $
      Code.createFunction callThe "eval" $ \paramPtr resultPtr -> do
         param <- Memory.load paramPtr
         case arr param of
            Core.Array z code -> do
               -- evaluate at the shape's zero index and store it
               element <- code (Shape.zeroIndex z)
               Memory.store element resultPtr
         LLVM.ret ()
   return $ \p ->
      bracket (create p) (delete . fst) $ \(_ctx, param) ->
      with param $ \paramBuf ->
      alloca $ \resultBuf -> do
         evalElement
            (MultiValueMemory.castStructPtr paramBuf)
            (MultiValueMemory.castStructPtr resultBuf)
         peek resultBuf
-- | Dynamic-call stub for the function compiled by 'renderShape'.
-- Its arguments are a pointer to the packed parameter and a pointer to the
-- slot that receives the shape.  It returns the size computed by
-- 'Shape.sizeCode' as a 'Word32'.
foreign import ccall safe "dynamic" callShaper ::
   Importer (Ptr param -> Ptr shape -> IO Word32)
-- | Compile a function that computes only the shape of a symbolic array.
-- The element code is ignored.
--
-- The LLVM code is compiled once.  For each parameter, the returned function
-- creates the context and runs the compiled code.  It returns the shape
-- together with the size 'Shape.sizeCode' computes for it.  The context is
-- released via 'bracket'.
renderShape ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Array p sh a -> IO (p -> IO (sh, Word32))
renderShape (Sym.Array arr create delete) = do
   computeShape <-
      compile "renderShape" $
      Code.createFunction callShaper "shape" $ \paramPtr resultPtr -> do
         param <- Memory.load paramPtr
         case arr param of
            Core.Array esh _code -> do
               sh <- unExp esh
               MultiValueMemory.store sh resultPtr
               LLVM.ret =<< Shape.sizeCode sh
   return $ \p ->
      bracket (create p) (delete . fst) $ \(_ctx, param) ->
      with param $ \paramBuf ->
      alloca $ \shapeBuf -> do
         size <-
            computeShape
               (MultiValueMemory.castStructPtr paramBuf)
               (MultiValueMemory.castStructPtr shapeBuf)
         resultShape <- peek shapeBuf
         return (resultShape, size)
-- | Compile a symbolic array into a function that materializes it as a
-- physical array.  This delegates to 'PhysHull.render' on the array's hull.
render ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Array p sh a -> IO (p -> IO (Phys.Array sh a))
render symArr = PhysHull.render (Sym.arrayHull symArr)
-- | Compile an accumulating map over arrays of shape @(sh, n)@.  Initial
-- accumulators come from an array of shape @sh@.
--
-- NOTE(review): by analogy with "Data.List".'Data.List.mapAccumL', this
-- presumably threads the accumulator along the inner dimension @n@ for each
-- outer index; the traversal itself lives in 'PhysHull.mapAccumL'.
mapAccumL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
   (Exp acc -> Exp a -> Exp (acc,b)) ->
   Sym.Array p sh acc ->
   Sym.Array p (sh, n) a ->
   IO (p -> IO (Phys.Array (sh,n) b))
mapAccumL f arrInit arrMap =
   PhysHull.mapAccumL (liftA2 (PhysHull.MapAccumL f) initHull mapHull)
   where
      initHull = Sym.arrayHull arrInit
      mapHull = Sym.arrayHull arrMap
-- | Compile a left fold that reduces an array of shape @(n, sh)@ to shape
-- @sh@, starting from the accumulators in an array of shape @sh@.
--
-- NOTE(review): from the types, the fold presumably runs along the outer
-- dimension @n@; the loop itself is implemented by 'PhysHull.foldOuterL'.
foldOuterL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp b -> Exp a) ->
   Sym.Array p sh a ->
   Sym.Array p (n,sh) b ->
   IO (p -> IO (Phys.Array sh a))
foldOuterL f arrInit arrMap =
   PhysHull.foldOuterL (liftA2 (PhysHull.FoldOuterL f) initHull mapHull)
   where
      initHull = Sym.arrayHull arrInit
      mapHull = Sym.arrayHull arrMap
-- | Compile a scatter that writes @(target index, value)@ pairs from an
-- array of shape @sh0@ into a copy of a base array of shape @sh1@.
--
-- NOTE(review): @accum@ presumably merges each value into the element
-- already at its target; the actual update loop is 'PhysHull.scatter'.
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array p sh1 a ->
   Sym.Array p sh0 (ix1, a) -> IO (p -> IO (Phys.Array sh1 a))
scatter accum arrBase arrMap =
   PhysHull.scatter (liftA2 (PhysHull.Scatter accum) baseHull mapHull)
   where
      baseHull = Sym.arrayHull arrBase
      mapHull = Sym.arrayHull arrMap
-- | Like 'scatter', but each source element is optional.
--
-- NOTE(review): 'Nothing' entries are presumably skipped; the handling is
-- done by 'PhysHull.scatterMaybe'.
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array p sh1 a ->
   Sym.Array p sh0 (Maybe (ix1, a)) -> IO (p -> IO (Phys.Array sh1 a))
scatterMaybe accum arrBase arrMap =
   PhysHull.scatterMaybe
      (liftA2 (PhysHull.ScatterMaybe accum) baseHull mapHull)
   where
      baseHull = Sym.arrayHull arrBase
      mapHull = Sym.arrayHull arrMap
-- | Send every element of @input@ to the index computed by @ixmap@ and
-- merge it into @deflt@ using @accum@.
--
-- Implemented on top of 'scatter': each input element at index @ix@ is
-- paired with its target index @ixmap ix@.
permute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array p sh1 a ->
   (Exp ix0 -> Exp ix1) ->
   Sym.Array p sh0 a ->
   IO (p -> IO (Phys.Array sh1 a))
permute accum deflt ixmap input =
   scatter accum deflt (Core.mapWithIndex withTarget input)
   where
      -- tag each element with the destination index it is sent to
      withTarget ix x = Expr.lift2 MultiValue.zip (ixmap ix) x