module Data.Array.Knead.Index.Linear (
C(switch),
switchInt,
intersect,
value,
constant,
paramWith,
tunnel,
flattenIndex,
peek,
poke,
computeSize,
Struct,
T(..),
Z(Z), z,
(:.)((:.)),
Shape, shape,
Index, index,
cons, (#:.),
head,
tail,
switchR,
loadMultiValue,
storeMultiValue,
) where
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Index.Linear.Int as Index
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )
import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import LLVM.Extra.Multi.Value (Atom, )
import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM
import qualified Foreign.Storable as St
import Foreign.Storable.FixedArray (sizeOfArray, )
import Foreign.Marshal.Array (advancePtr, )
import Foreign.Ptr (Ptr, castPtr, )
import Control.Monad (liftM2, )
import Data.Word (Word32, )
import Prelude hiding (min, head, tail, )
-- | Shape and index types built from 'Z' and ':.'.
--
-- 'switch' performs a case analysis on the type: it yields its first
-- argument when the type is 'Z', and its second argument when the type is
-- a cons whose tail is again an instance of 'C' and whose head is an
-- 'Index.Single'.
class C ix where
   switch ::
      f Z ->
      (forall ixTail j. (C ixTail, Index.Single j) => f (ixTail :. j)) ->
      f ix
-- Rank zero: 'switch' selects the 'Z' alternative.
instance C Z where
   switch zeroCase _consCase = zeroCase
-- Cons case: 'switch' selects the alternative for @ix0 :. i@.
instance (C ix0, Index.Single i) => C (ix0 :. i) where
   switch _zeroCase consCase = consCase
-- | Wraps @f (ix :. i)@ so that the head type @i@ is the last type
-- parameter, which is the shape 'Index.switchSingle' is applied to below.
newtype SwitchInt f ix i = SwitchInt {runSwitchInt :: f (ix :. i)}

-- | Variant of 'switch' whose cons alternative only has to handle an
-- 'Index.Int' head.
--
-- NOTE(review): 'Index.switchSingle' is defined elsewhere; from its use
-- here it turns the 'Index.Int' alternative into one for the general
-- 'Index.Single' head that 'switch' requires.
switchInt ::
   (C ix) =>
   f Z ->
   (forall ix0. (C ix0) => f (ix0 :. Index.Int)) ->
   f ix
switchInt zeroCase intCase =
   switch zeroCase
      (runSwitchInt (Index.switchSingle (SwitchInt intCase)))
-- | A binary operation on tagged shape expressions, wrapped so that it can
-- be defined by recursion over the shape type via 'switchInt'.
newtype Op2 tag sh = Op2 {runOp2 :: Exp (T tag sh) -> Exp (T tag sh) -> Exp (T tag sh)}

-- | Dimension-wise minimum of two shapes.
--
-- For rank zero the first argument is returned unchanged. Otherwise the
-- heads are combined with 'Expr.min' and the tails are intersected
-- recursively.
intersect :: C sh => Exp (Shape sh) -> Exp (Shape sh) -> Exp (Shape sh)
intersect =
   runOp2 $
   switchInt
      (Op2 const)
      (Op2 $ \shA shB ->
         switchR (\restA dimA ->
            switchR (\restB dimB ->
               intersect restA restB #:. Expr.min dimA dimB)
            shB)
         shA)
-- | Lift a plain Haskell value into an expression via 'MultiValue.cons'.
-- The leading underscore keeps GHC from warning if it is unused.
_value :: (C sh, MultiValue.C sh) => sh -> Exp sh
_value x = Expr.lift0 (MultiValue.cons x)
-- | Conversion from a concrete tagged shape to @val@, wrapped for
-- recursion via 'switchInt'.
newtype MakeValue val tag sh = MakeValue {runMakeValue :: T tag sh -> val (T tag sh)}

-- | Turn a concrete shape or index into a value of any 'Expr.Value' type.
-- Each 'Index.Int' dimension is lifted with 'MultiValue.cons' and
-- appended with 'cons'.
value :: (C sh, Expr.Value val) => T tag sh -> val (T tag sh)
value =
   runMakeValue $
   switchInt
      (MakeValue $ \(Cons Z) -> z)
      (MakeValue $ \(Cons (rest :. dim)) ->
         cons (value (Cons rest)) (Expr.lift0 (MultiValue.cons dim)))
-- | Pass the pieces of the parameter tunnel for a shape/index parameter to
-- a continuation.
--
-- The continuation receives:
--
-- * the function extracting the stored parameter from @p@, and
--
-- * a function turning the loaded 'MultiValue.T' into @val@ via
--   'Expr.lift0'.
paramWith ::
   (C sh, Expr.Value val) =>
   Param.T p (T tag sh) ->
   (forall parameters.
      (St.Storable parameters,
       MultiValueMemory.C parameters) =>
      (p -> parameters) ->
      (MultiValue.T parameters -> val (T tag sh)) ->
      a) ->
   a
paramWith param k =
   case tunnel param of
      Param.Tunnel getParams toMulti ->
         k getParams (\mv -> Expr.lift0 (toMulti mv))
-- | Build a 'Param.Tunnel' for a shape/index parameter using 'value'.
--
-- Matching on 'StructFieldsProp' brings @LLVM.StructFields (Struct sh)@
-- into scope. NOTE(review): presumably 'Param.tunnel' needs it, via the
-- 'MultiValueMemory.C' instance for @T tag sh@.
tunnel :: (C sh) => Param.T p (T tag sh) -> Param.Tunnel p (T tag sh)
tunnel param =
   case structFieldsPropF param of
      StructFieldsProp -> Param.tunnel value param
-- | Witness for the constraint @LLVM.StructFields (Struct sh)@.
-- Pattern matching on 'StructFieldsProp' brings that constraint into scope.
data StructFieldsProp sh = LLVM.StructFields (Struct sh) => StructFieldsProp
-- | Obtain the 'StructFieldsProp' witness, with the shape type taken from
-- a proxy argument. The argument itself is ignored.
_structFieldsProp :: (C sh) => f sh -> StructFieldsProp sh
_structFieldsProp = const structFieldsRec

-- | Like '_structFieldsProp', but for a proxy whose shape type is nested
-- one type constructor deep.
structFieldsPropF :: (C sh) => f (g sh) -> StructFieldsProp sh
structFieldsPropF = const structFieldsRec

-- | Apply a continuation to the 'StructFieldsProp' witness. The shape type
-- is determined by the continuation's result type.
withStructFieldsPropFF ::
   (C sh) => (StructFieldsProp sh -> f (g (h sh))) -> f (g (h sh))
withStructFieldsPropFF = ($ structFieldsRec)
-- | Construct the 'StructFieldsProp' witness by recursion over the shape.
-- The base case is 'Z'; each 'Index.Int' dimension is added with
-- 'succStructFieldsProp'.
structFieldsRec :: (C sh) => StructFieldsProp sh
structFieldsRec =
   switchInt
      StructFieldsProp
      (succStructFieldsProp structFieldsRec)

-- | Extend the witness by one 'Index.Int' dimension. The field tuple for
-- the extended shape is @(Word32, Struct sh)@.
succStructFieldsProp ::
   StructFieldsProp sh -> StructFieldsProp (sh:.Index.Int)
succStructFieldsProp prop =
   case prop of
      StructFieldsProp -> StructFieldsProp
-- | The empty shape, and the only index of rank zero.
data Z = Z
   deriving (Show, Read, Eq, Ord)
-- Both cons operators associate to the left, so @Z :. a :. b@ means
-- @(Z :. a) :. b@.
infixl 3 :., #:.

-- | Strict cons cell of a shape or index. The left field holds the
-- remaining dimensions; the right field holds the last dimension.
data sh :. i = !sh :. !i
   deriving (Show, Read, Eq, Ord)
-- | Wrapper that attaches a phantom @tag@ to a raw shape/index value.
-- Shapes and indices share the same @sh@ representation and are told
-- apart only by the tag.
newtype T tag sh = Cons {decons :: sh}
-- | Phantom tag marking a 'T' as a shape.
data ShapeTag
-- | Phantom tag marking a 'T' as an index.
data IndexTag

type Shape = T ShapeTag
type Index = T IndexTag

-- | Tag a raw value as a shape.
shape :: sh -> Shape sh
shape sh = Cons sh

-- | Tag a raw value as an index.
index :: ix -> Index ix
index ix = Cons ix
-- | Infix synonym for 'cons'.
(#:.) :: (Expr.Value val) => val (T tag sh) -> val i -> val (T tag (sh:.i))
outer #:. inner = cons outer inner

-- | Append one dimension to a tagged shape/index value. Internally the
-- representation is the pair of the outer part and the new dimension.
cons :: (Expr.Value val) => val (T tag sh) -> val i -> val (T tag (sh:.i))
cons =
   Expr.lift2
      (\(MultiValue.Cons outer) (MultiValue.Cons inner) ->
         MultiValue.Cons (outer, inner))
-- | The rank-zero shape/index value. Its representation is @()@.
z :: (Expr.Value val) => val (T tag Z)
z = Expr.lift0 (MultiValue.Cons ())

-- | The last dimension of a non-empty shape/index value.
head :: (Expr.Value val) => val (T tag (sh:.i)) -> val i
head = Expr.lift1 (\(MultiValue.Cons (_outer, inner)) -> MultiValue.Cons inner)

-- | All but the last dimension of a non-empty shape/index value.
tail :: (Expr.Value val) => val (T tag (sh:.i)) -> val (T tag sh)
tail = Expr.lift1 (\(MultiValue.Cons (outer, _inner)) -> MultiValue.Cons outer)
-- | Split a non-empty shape/index value into its outer part ('tail') and
-- its last dimension ('head'), and pass both to the continuation.
switchR ::
   Expr.Value val =>
   (val (T tag sh) -> val i -> a) -> val (T tag (sh :. i)) -> a
switchR k shix =
   let outer = tail shix
       inner = head shix
   in  k outer inner
-- A rank-zero shape is a scalar. Both the shape and its only index are
-- represented by @()@.
instance (tag ~ ShapeTag, sh ~ Z) => Shape.Scalar (T tag sh) where
   scalar = Expr.lift0 (MultiValue.Cons ())
   zeroIndex _unused = Expr.lift0 (MultiValue.Cons ())
-- | The shape type described by a decomposition pattern.
type family PatternTuple pat
-- | The result type of decomposing with a pattern. Each 'Atom' becomes a
-- whole tagged value, and each cons dimension is decomposed with
-- 'MultiValue.Decomposed'.
type family Decomposed (f :: * -> *) tag pat

type instance PatternTuple (Atom sh) = sh
type instance PatternTuple (outer:.inner) =
   PatternTuple outer :. MultiValue.PatternTuple inner

type instance Decomposed f tag (Atom sh) = f (T tag sh)
type instance Decomposed f tag (outer:.inner) =
   Decomposed f tag outer :. MultiValue.Decomposed f inner
-- | Pattern-directed splitting of a tagged shape expression.
-- 'decompose' uses the pattern to decide which dimensions are split off
-- and which remain as one whole ('Atom') expression.
class
   (Expr.Composed (Decomposed Exp tag pat) ~ T tag (PatternTuple pat)) =>
      Decompose tag pat where
   decompose ::
      T tag pat -> Exp (T tag (PatternTuple pat)) ->
      Decomposed Exp tag pat
-- An 'Atom' pattern keeps the expression as it is.
instance Decompose tag (Atom sh) where
   decompose (Cons _atom) whole = whole

-- A cons pattern splits off the last dimension and recurses on the rest.
instance (Decompose tag sh, Expr.Decompose s) => Decompose tag (sh :. s) where
   decompose (Cons (patOuter :. patInner)) x =
      decompose (Cons patOuter) (tail x) :. Expr.decompose patInner (head x)
-- Hook the local pattern families into the MultiValue ones.
type instance MultiValue.PatternTuple (T tag sh) = T tag (PatternTuple sh)
type instance MultiValue.Decomposed f (T tag sh) = Decomposed f tag sh

-- | The raw shape inside a 'T'.
type family Unwrap wrapped
type instance Unwrap (T tag sh) = sh

-- | The phantom tag of a 'T'.
type family Tag wrapped
type instance Tag (T tag sh) = tag
-- Composing a cons of expressions: the tail must compose to some tagged
-- shape, and the head expression is appended to it as the last dimension.
instance
   (Expr.Compose sh,
    Expr.Composed sh ~ T (Tag (Expr.Composed sh)) (Unwrap (Expr.Composed sh)),
    Expr.Compose s) =>
      Expr.Compose (sh :. s) where
   type Composed (sh :. s) =
      T (Tag (Expr.Composed sh))
        (Unwrap (Expr.Composed sh) :. Expr.Composed s)
   compose (outer :. inner) = Expr.compose outer #:. Expr.compose inner
-- Delegate to the module-local 'Decompose' class.
instance (Decompose tag sh) => Expr.Decompose (T tag sh) where
   decompose pat x = decompose pat x
-- Stored as a flat array of Word32, one element per dimension (see the
-- module-level 'peek' and 'poke' for the element order). 'sizeOf' goes
-- through 'rank', which does not force its argument, so @sizeOf undefined@
-- is safe.
instance (C sh) => St.Storable (T tag sh) where
   sizeOf (Cons sh) = sizeOfArray (rank sh) (0::Word32)
   alignment _ = St.alignment (0::Word32)
   poke ptr (Cons sh) = poke (castPtr ptr) sh
   peek ptr = fmap Cons (peek (castPtr ptr))
-- | Low-level representation of a raw shape: unit for 'Z' and nested pairs
-- @(outer, inner)@ for each cons.
type family Repr (f :: * -> *) sh
type instance Repr f Z = ()
type instance Repr f (outer :. inner) = (Repr f outer, MultiValue.Repr f inner)
-- Multi-values of tagged shapes. 'undef' and 'zero' repeat the
-- corresponding 'Index.Int' value in every dimension via 'constant'.
instance (C sh) => MultiValue.C (T tag sh) where
   type Repr f (T tag sh) = Repr f sh
   cons = value
   undef = constant MultiValue.undef
   zero = constant MultiValue.zero
   addPhis = addPhis
   phis = phis
-- Tagged shapes are knead shapes. Each method delegates to the
-- module-level implementation of the same concept.
instance (tag ~ ShapeTag, C sh) => Shape.C (T tag sh) where
   type Index (T tag sh) = Index sh
   size (Cons sh) = fromIntegral (size sh)
   sizeCode = computeSize
   intersectCode = Expr.unliftM2 intersect
   flattenIndexRec sh ix = do
      -- total size first, then the linear offset; order kept as generated
      total <- computeSize sh
      offset <- flattenIndex sh ix
      return (total, offset)
   loop = loop
-- | LLVM struct field tuple for a raw shape. Only 'Index.Int' dimensions
-- have an instance: each contributes one 'Word32', with the last dimension
-- placed first.
type family Struct sh
type instance Struct Z = ()
type instance Struct (outer :. Index.Int) = (Word32, Struct outer)
-- In memory a tagged shape is an LLVM struct of its 'Struct' fields.
-- Loading and storing reinterpret it as a Word32 array.
instance
   (C sh, LLVM.StructFields (Struct sh)) =>
      MultiValueMemory.C (T tag sh) where
   type Struct (T tag sh) = LLVM.Struct (Struct sh)
   load = loadMultiValue
   store = storeMultiValue
-- | Load a tagged shape from a pointer to its LLVM struct.
-- The pointer is bitcast to a Word32 pointer, which needs the
-- 'LLVM.StructFields' constraint supplied by the 'StructFieldsProp'
-- witness.
loadMultiValue ::
   (C sh) =>
   LLVM.Value (Ptr (LLVM.Struct (Struct sh))) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
loadMultiValue ptr =
   withStructFieldsPropFF $ \StructFieldsProp ->
      castPtrValue ptr >>= load
-- | Store a tagged shape through a pointer to its LLVM struct.
-- The pointer is bitcast to a Word32 pointer; the needed
-- 'LLVM.StructFields' constraint comes from the 'StructFieldsProp' witness.
storeMultiValue ::
   (C sh) =>
   MultiValue.T (T tag sh) ->
   LLVM.Value (Ptr (LLVM.Struct (Struct sh))) -> LLVM.CodeGenFunction r ()
storeMultiValue x ptr =
   case structFieldsPropF x of
      StructFieldsProp -> castPtrValue ptr >>= store x
-- | Code generator for 'flattenIndex', wrapped for recursion via
-- 'switchInt'.
newtype FlattenIndex r sh =
   FlattenIndex {
      runFlattenIndex ::
         MultiValue.T (Shape sh) -> MultiValue.T (Index sh) ->
         LLVM.CodeGenFunction r (LLVM.Value Word32)
   }

-- | Generate code for the row-major linear offset of an index in a shape.
--
-- For rank zero the offset is 0. For shape @sh :. n@ and index @ix :. i@
-- it is @i + n * flattenIndex sh ix@. No bounds check is generated here.
flattenIndex ::
   (C sh) =>
   MultiValue.T (Shape sh) -> MultiValue.T (Index sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Word32)
flattenIndex =
   runFlattenIndex $
   switchInt
      (FlattenIndex $ \_ _ -> return A.zero)
      (FlattenIndex $ \shp ix ->
         switchR (\shOuter (MultiValue.Cons n) ->
            switchR (\ixOuter (MultiValue.Cons i) -> do
               outerOffset <- flattenIndex shOuter ixOuter
               scaled <- A.mul n outerOffset
               A.add i scaled)
            ix)
         shp)
-- | Rank computation, wrapped for recursion via 'switch'.
newtype Rank sh = Rank {runRank :: sh -> Int}

-- | Number of dimensions of a shape/index type.
--
-- The argument is never forced; the lazy pattern keeps that true. Callers
-- such as 'St.sizeOf' may therefore pass 'undefined'.
rank :: (C sh) => sh -> Int
rank =
   runRank $
   switch
      (Rank (\_ -> 0))
      (Rank (\ ~(outer :. _) -> succ (rank outer)))
-- | Reader for a raw shape, wrapped for recursion via 'switchInt'.
newtype Peek sh = Peek {runPeek :: Ptr Word32 -> IO sh}

-- | Read a raw shape from a Word32 array. The last dimension is at the
-- lowest address, followed by the remaining dimensions (the inverse of
-- 'poke').
peek :: (C sh) => Ptr Word32 -> IO sh
peek =
   runPeek $
   switchInt
      (Peek (\_ -> return Z))
      (Peek $ \ptr ->
         St.peek ptr >>= \innermost ->
         peek (advancePtr ptr 1) >>= \rest ->
         return (rest :. Index.Int innermost))
-- | Writer for a raw shape, wrapped for recursion via 'switchInt'.
newtype Poke sh = Poke {runPoke :: Ptr Word32 -> sh -> IO ()}

-- | Write a raw shape to a Word32 array. The last dimension goes to the
-- lowest address, followed by the remaining dimensions.
poke :: (C sh) => Ptr Word32 -> sh -> IO ()
poke =
   runPoke $
   switchInt
      (Poke (\_ _ -> return ()))
      (Poke $ \ptr (rest :. Index.Int innermost) ->
         St.poke ptr innermost >> poke (advancePtr ptr 1) rest)
-- | Reinterpret a pointer to an LLVM struct as a pointer to its first
-- Word32 element, using an LLVM bitcast.
castPtrValue ::
   (LLVM.StructFields sh) =>
   LLVM.Value (Ptr (LLVM.Struct sh)) ->
   LLVM.CodeGenFunction r (LLVM.Value (Ptr Word32))
castPtrValue ptr = LLVM.bitcast ptr
-- | Code generator for 'load', wrapped for recursion via 'switchInt'.
newtype Load r tag sh =
   Load {
      runLoad ::
         LLVM.Value (Ptr Word32) ->
         LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
   }

-- | Generate code that loads a tagged shape from a Word32 array, last
-- dimension first. This mirrors the layout used by 'store'.
load ::
   (C sh) =>
   LLVM.Value (Ptr Word32) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
load =
   runLoad $
   switchInt
      (Load (\_ -> return z))
      (Load $ \ptr -> do
         innermost <- LLVM.load ptr
         rest <- A.advanceArrayElementPtr ptr >>= load
         return (rest #:. MultiValue.Cons innermost))
-- | Code generator for 'store', wrapped for recursion via 'switchInt'.
newtype Store r tag sh =
   Store {
      runStore ::
         MultiValue.T (T tag sh) ->
         LLVM.Value (Ptr Word32) ->
         LLVM.CodeGenFunction r ()
   }

-- | Generate code that stores a tagged shape to a Word32 array. The last
-- dimension is written first, then the remaining dimensions at the
-- following elements.
store ::
   (C sh) =>
   MultiValue.T (T tag sh) ->
   LLVM.Value (Ptr Word32) ->
   LLVM.CodeGenFunction r ()
store =
   runStore $
   switchInt
      (Store (\_ _ -> return ()))
      (Store $ \x ptr ->
         switchR (\rest (MultiValue.Cons innermost) -> do
            LLVM.store innermost ptr
            A.advanceArrayElementPtr ptr >>= store rest)
         x)
-- | Size computation, wrapped for recursion via 'switchInt'.
newtype Size sh =
   Size {
      runSize :: sh -> Word32
   }

-- | Number of elements of a raw shape, i.e. the product of all extents.
-- It is 1 for rank zero. Word32 multiplication wraps on overflow.
size :: (C sh) => sh -> Word32
size =
   runSize $
   switchInt
      (Size (\_ -> 1))
      (Size (\(rest :. Index.Int n) -> n * size rest))
-- | Code generator for 'computeSize', wrapped for recursion via
-- 'switchInt'.
newtype ComputeSize r sh =
   ComputeSize {
      runComputeSize ::
         MultiValue.T (Shape sh) ->
         LLVM.CodeGenFunction r (LLVM.Value Word32)
   }

-- | Generate code for the number of elements of a shape. This is the code
-- counterpart of 'size': the product of all extents, and 1 for rank zero.
computeSize ::
   (C sh) =>
   MultiValue.T (Shape sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Word32)
computeSize =
   runComputeSize $
   switchInt
      (ComputeSize (\_ -> return A.one))
      (ComputeSize $ \shp ->
         switchR (\rest (MultiValue.Cons n) -> computeSize rest >>= A.mul n) shp)
-- | Constant-fill builder, wrapped for recursion via 'switchInt'.
newtype
   Constant val tag sh =
      Constant {getConstant :: val Index.Int -> val (T tag sh)}

-- | A shape/index value whose every dimension is the given value.
constant :: (C sh, Expr.Value val) => val Index.Int -> val (T tag sh)
constant =
   getConstant $
   switchInt
      (Constant (\_ -> z))
      (Constant (\x -> cons (constant x) x))
-- | Code generator for 'addPhis', wrapped for recursion via 'switchInt'.
newtype AddPhis r tag sh =
   AddPhis {
      runAddPhis ::
         LLVM.BasicBlock ->
         MultiValue.T (T tag sh) ->
         MultiValue.T (T tag sh) ->
         LLVM.CodeGenFunction r ()
   }

-- | Add incoming values from a basic block to the phis of a shape,
-- dimension by dimension. Within each level the last dimension is handled
-- before the remaining dimensions.
addPhis ::
   (C sh) =>
   LLVM.BasicBlock ->
   MultiValue.T (T tag sh) ->
   MultiValue.T (T tag sh) ->
   LLVM.CodeGenFunction r ()
addPhis =
   runAddPhis $
   switchInt
      (AddPhis (\_ _ _ -> return ()))
      (AddPhis $ \bb x y ->
         switchR (\xRest xInner ->
            switchR (\yRest yInner -> do
               MultiValue.addPhis bb xInner yInner
               addPhis bb xRest yRest)
            y)
         x)
-- | Code generator for 'phis', wrapped for recursion via 'switchInt'.
newtype Phis r tag sh =
   Phis {
      runPhis ::
         LLVM.BasicBlock ->
         MultiValue.T (T tag sh) ->
         LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
   }

-- | Create phis for every dimension of a shape. The phis for the remaining
-- dimensions are created before the one for the last dimension.
phis ::
   (C sh) =>
   LLVM.BasicBlock ->
   MultiValue.T (T tag sh) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
phis =
   runPhis $
   switchInt
      (Phis (\_ -> return))
      (Phis $ \bb x ->
         switchR (\rest inner -> do
            restPhi <- phis bb rest
            innerPhi <- MultiValue.phis bb inner
            return (restPhi #:. innerPhi))
         x)
-- | Code generator for 'loop', wrapped for recursion via 'switchInt'.
newtype Loop r state sh =
   Loop {
      runLoop ::
         (MultiValue.T (Index sh) ->
          state ->
          LLVM.CodeGenFunction r state) ->
         MultiValue.T (Shape sh) ->
         state ->
         LLVM.CodeGenFunction r state
   }

-- | Generate nested loops over all indices of a shape, threading a state
-- through the body.
--
-- The last dimension is iterated by 'C.fixedLengthLoop' inside the loops
-- over the remaining dimensions, so it varies fastest. For rank zero the
-- body runs once with the index 'z'.
loop ::
   (C sh, Loop.Phi state) =>
   (MultiValue.T (Index sh) ->
    state ->
    LLVM.CodeGenFunction r state) ->
   MultiValue.T (Shape sh) ->
   state ->
   LLVM.CodeGenFunction r state
loop =
   runLoop $
   switchInt
      (Loop (\body _ -> body z))
      (Loop $ \body shp ->
         switchR (\outerShape (MultiValue.Cons n) ->
            loop
               (\outerIx state0 ->
                  fmap fst $
                  C.fixedLengthLoop n (state0, A.zero) $ \(state, k) ->
                     liftM2 (,)
                        (body (outerIx #:. MultiValue.Cons k) state)
                        (A.inc k))
               outerShape)
         shp)