module Futhark.CodeGen.ImpGen.Multicore.SegScan
( compileSegScan,
)
where
import Control.Monad
import Data.List (zip4)
import qualified Futhark.CodeGen.ImpCode.Multicore as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.IR.MCMem
import Futhark.Util.IntegralExp (quot, rem)
import Prelude hiding (quot, rem)
-- | Compile a multicore 'SegScan'.
--
-- A segment space with exactly one dimension is an ordinary
-- (non-segmented) scan and is handled by 'nonsegmentedScan'; any other
-- space is handled by 'segmentedScan'. The subtask count is only passed
-- on to the non-segmented path.
compileSegScan ::
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.Code
compileSegScan pat space reds kbody nsubtasks
  | [_] <- unSegSpace space =
    nonsegmentedScan pat space reds kbody nsubtasks
  | otherwise =
    segmentedScan pat space reds kbody
-- | The lambda parameters of a scan operator, split at the number of
-- neutral elements. 'xParams' returns the first half, which this module
-- loads from the running accumulator. 'yParams' returns the rest, which
-- this module loads from the next input element.
xParams, yParams :: SegBinOp MCMem -> [LParam MCMem]
xParams = fst . splitScanParams
yParams = snd . splitScanParams

-- | Split a scan operator's lambda parameters after the first
-- @length (segBinOpNeutral scan)@ parameters. Sharing this helper keeps
-- the two halves consistent with each other.
splitScanParams :: SegBinOp MCMem -> ([LParam MCMem], [LParam MCMem])
splitScanParams scan =
  splitAt (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
-- | The body of a scan operator's lambda.
lamBody :: SegBinOp MCMem -> Body MCMem
lamBody = lambdaBody . segBinOpLambda
-- | Allocate one array per return value of every scan operator, in
-- 'DefaultSpace'.
--
-- Each array's shape is the operator's 'segBinOpShape' followed by the
-- return type's own array shape. Its element type is the return type's
-- element type. The given string is used as the name hint for every
-- allocation. The result holds one inner list per operator, in operator
-- order.
resultArrays :: String -> [SegBinOp MCMem] -> MulticoreGen [[VName]]
resultArrays s segops =
  forM segops $ \(SegBinOp _ lam _ shape) ->
    forM (lambdaReturnType lam) $ \t -> do
      let pt = elemType t
          full_shape = shape <> arrayShape t
      sAllocArray s pt full_shape DefaultSpace
-- | Compile a scan over a one-dimensional segment space.
--
-- The returned code always runs stage 1. Stages 2 and 3 run only when
-- the number of subtasks is greater than one. Stages 2 and 3 each get a
-- freshly renamed copy of the operators. NOTE(review): this is presumably
-- so that lambda parameter names are not declared twice in the generated
-- code; confirm against 'renameSegBinOp'.
--
-- The debug message is emitted into the enclosing code, outside the
-- 'collect', so it is not part of the returned code.
nonsegmentedScan ::
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.Code
nonsegmentedScan pat space scan_ops kbody nsubtasks = do
  emit $ Imp.DebugPrint "nonsegmented segScan" Nothing
  collect $ do
    scanStage1 pat space scan_ops kbody

    let nsubtasks' = tvExp nsubtasks
    sWhen (nsubtasks' .>. 1) $ do
      scan_ops2 <- renameSegBinOp scan_ops
      scanStage2 pat nsubtasks space scan_ops2 kbody
      scan_ops3 <- renameSegBinOp scan_ops
      scanStage3 pat space scan_ops3 kbody
-- | Stage 1 of the non-segmented scan. Emits a @scan_stage_1@ 'Imp.ParLoop'
-- indexed by a flat @iter@ variable.
--
-- The loop prebody declares every operator's lambda parameters. It also
-- sets up one accumulator per operand, initialised to the neutral
-- element. Scalar operators accumulate directly in their x parameter;
-- vectorised operators get a @local_acc@ array. The loop body then:
--
--   * unflattens @iter@ into the segment-space indices;
--   * runs the kernel body statements;
--   * writes the map (non-scan) results to their pattern elements;
--   * for every operator and vector index, loads the accumulator into the
--     x parameters and the next kernel result into the y parameters, then
--     runs the operator.
--
-- Each operator result is written both to the output pattern element at
-- the current index and back into the accumulator.
--
-- NOTE(review): how @iter@ is partitioned among subtasks is decided by
-- the 'Imp.ParLoop' backend, not here. Presumably each subtask scans only
-- its own chunk, with stages 2 and 3 propagating the carries.
scanStage1 ::
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen ()
scanStage1 pat space scan_ops kbody = do
  let (all_scan_res, map_res) =
        splitAt (segBinOpResults scan_ops) $ kernelBodyResult kbody
      per_scan_res = segBinOpChunks scan_ops all_scan_res
      per_scan_pes = segBinOpChunks scan_ops $ patternValueElements pat
      (is, ns) = unzip $ unSegSpace space
      ns' = map toInt64Exp ns

  iter <- dPrim "iter" $ IntType Int64

  -- Loop prebody: declare the operator parameters and set up one
  -- neutral-initialised accumulator per operand.
  (local_accs, prebody) <- collect' $ do
    dScope Nothing $
      scopeOfLParams $ concatMap (lambdaParams . segBinOpLambda) scan_ops
    forM scan_ops $ \scan_op -> do
      let shape = segBinOpShape scan_op
          ts = lambdaReturnType $ segBinOpLambda scan_op
      forM (zip3 (xParams scan_op) (segBinOpNeutral scan_op) ts) $
        \(p, ne, t) -> do
          acc <-
            case shapeDims shape of
              -- Scalar operator: accumulate in the x parameter itself.
              [] -> pure $ paramName p
              _ -> do
                let pt = elemType t
                sAllocArray "local_acc" pt (shape <> arrayShape t) DefaultSpace

          -- Neutral-initialise the accumulator. This must come after the
          -- case so that it also covers the scalar accumulator. (Assumes
          -- 'sLoopNest' over an empty shape runs the body once with no
          -- indices.)
          sLoopNest shape $ \vec_is ->
            copyDWIMFix acc vec_is ne []

          pure acc

  body <- collect $ do
    zipWithM_ dPrimV_ is $ unflattenIndex ns' $ tvExp iter
    sComment "stage 1 scan body" $
      compileStms mempty (kernelBodyStms kbody) $ do
        sComment "write mapped values results to memory" $ do
          let map_arrs = drop (segBinOpResults scan_ops) $ patternElements pat
          zipWithM_ (compileThreadResult space) map_arrs map_res

        forM_ (zip4 per_scan_pes scan_ops per_scan_res local_accs) $
          \(pes, scan_op, scan_res, acc) ->
            sLoopNest (segBinOpShape scan_op) $ \vec_is -> do
              -- Load the running accumulator into the x parameters.
              forM_ (zip (xParams scan_op) acc) $ \(p, acc') ->
                copyDWIMFix (paramName p) [] (Var acc') vec_is

              sComment "Read next values" $
                forM_ (zip (yParams scan_op) scan_res) $ \(p, se) ->
                  copyDWIMFix (paramName p) [] (kernelResultSubExp se) vec_is

              -- Apply the operator. Store each result both in the output
              -- and in the accumulator for the next iteration.
              compileStms mempty (bodyStms $ lamBody scan_op) $
                forM_ (zip3 acc pes (bodyResult $ lamBody scan_op)) $
                  \(acc', pe, se) -> do
                    copyDWIMFix (patElemName pe) (map Imp.vi64 is ++ vec_is) se []
                    copyDWIMFix acc' vec_is se []

  free_params <- freeParams (prebody <> body) (segFlat space : [tvVar iter])
  let (body_allocs, body') = extractAllocations body
  emit $
    Imp.Op $
      Imp.ParLoop
        "scan_stage_1"
        (tvVar iter)
        (body_allocs <> prebody)
        body'
        mempty
        free_params
        (segFlat space)
scanStage2 ::
Pattern MCMem ->
TV Int32 ->
SegSpace ->
[SegBinOp MCMem] ->
KernelBody MCMem ->
MulticoreGen ()
scanStage2 :: Pattern MCMem
-> TV Int32
-> SegSpace
-> [SegBinOp MCMem]
-> KernelBody MCMem
-> ImpM MCMem HostEnv Multicore ()
scanStage2 Pattern MCMem
pat TV Int32
nsubtasks SegSpace
space [SegBinOp MCMem]
scan_ops KernelBody MCMem
kbody = do
Code -> ImpM MCMem HostEnv Multicore ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code -> ImpM MCMem HostEnv Multicore ())
-> Code -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"nonsegmentedScan stage 2" Maybe Exp
forall a. Maybe a
Nothing
let ([VName]
is, [SubExp]
ns) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
ns_64 :: [TExp Int64]
ns_64 = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
ns
per_scan_pes :: [[PatElemT LParamMem]]
per_scan_pes = [SegBinOp MCMem] -> [PatElemT LParamMem] -> [[PatElemT LParamMem]]
forall lore a. [SegBinOp lore] -> [a] -> [[a]]
segBinOpChunks [SegBinOp MCMem]
scan_ops ([PatElemT LParamMem] -> [[PatElemT LParamMem]])
-> [PatElemT LParamMem] -> [[PatElemT LParamMem]]
forall a b. (a -> b) -> a -> b
$ PatternT LParamMem -> [PatElemT LParamMem]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements Pattern MCMem
PatternT LParamMem
pat
nsubtasks' :: TExp Int32
nsubtasks' = TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
nsubtasks
Maybe (Exp MCMem) -> Scope MCMem -> ImpM MCMem HostEnv Multicore ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
dScope Maybe (Exp MCMem)
forall a. Maybe a
Nothing (Scope MCMem -> ImpM MCMem HostEnv Multicore ())
-> Scope MCMem -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ [Param LParamMem] -> Scope MCMem
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams ([Param LParamMem] -> Scope MCMem)
-> [Param LParamMem] -> Scope MCMem
forall a b. (a -> b) -> a -> b
$ (SegBinOp MCMem -> [Param LParamMem])
-> [SegBinOp MCMem] -> [Param LParamMem]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (LambdaT MCMem -> [Param LParamMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams (LambdaT MCMem -> [Param LParamMem])
-> (SegBinOp MCMem -> LambdaT MCMem)
-> SegBinOp MCMem
-> [Param LParamMem]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp MCMem -> LambdaT MCMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda) [SegBinOp MCMem]
scan_ops
TV Int64
offset <- String -> TExp Int64 -> ImpM MCMem HostEnv Multicore (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"offset" (TExp Int64
0 :: Imp.TExp Int64)
let offset' :: TExp Int64
offset' = TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
offset
TV Int64
offset_index <- String -> TExp Int64 -> ImpM MCMem HostEnv Multicore (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"offset_index" (TExp Int64
0 :: Imp.TExp Int64)
let offset_index' :: TExp Int64
offset_index' = TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
offset_index
let iter_pr_subtask :: TExp Int64
iter_pr_subtask = [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
ns_64 TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
nsubtasks'
remainder :: TExp Int64
remainder = [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
ns_64 TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
nsubtasks'
[[VName]]
accs <- String -> [SegBinOp MCMem] -> MulticoreGen [[VName]]
resultArrays String
"scan_stage_2_accum" [SegBinOp MCMem]
scan_ops
[(SegBinOp MCMem, [VName])]
-> ((SegBinOp MCMem, [VName]) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp MCMem] -> [[VName]] -> [(SegBinOp MCMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp MCMem]
scan_ops [[VName]]
accs) (((SegBinOp MCMem, [VName]) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ())
-> ((SegBinOp MCMem, [VName]) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp MCMem
scan_op, [VName]
acc) ->
Shape
-> ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegBinOp MCMem -> Shape
forall lore. SegBinOp lore -> Shape
segBinOpShape SegBinOp MCMem
scan_op) (([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ())
-> ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is ->
[(VName, SubExp)]
-> ((VName, SubExp) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
acc ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegBinOp MCMem -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp MCMem
scan_op) (((VName, SubExp) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ())
-> ((VName, SubExp) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \(VName
acc', SubExp
ne) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
acc' [TExp Int64]
vec_is SubExp
ne []
String
-> TExp Int32
-> (TExp Int32 -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" (TExp Int32
nsubtasks' TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) ((TExp Int32 -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ())
-> (TExp Int32 -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
TV Int64
offset TV Int64 -> TExp Int64 -> ImpM MCMem HostEnv Multicore ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TExp Int64
iter_pr_subtask
TExp Bool
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
i TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
remainder) (TV Int64
offset TV Int64 -> TExp Int64 -> ImpM MCMem HostEnv Multicore ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TExp Int64
offset' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1)
TV Int64
offset_index TV Int64 -> TExp Int64 -> ImpM MCMem HostEnv Multicore ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TExp Int64
offset_index' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
offset'
(VName -> TExp Int64 -> ImpM MCMem HostEnv Multicore ())
-> [VName] -> [TExp Int64] -> ImpM MCMem HostEnv Multicore ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> ImpM MCMem HostEnv Multicore ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ [VName]
is ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> [TExp Int64] -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
ns_64 (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
offset_index'
Names
-> Stms MCMem
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody MCMem -> Stms MCMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody MCMem
kbody) (ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$
[([PatElemT LParamMem], SegBinOp MCMem, [VName])]
-> (([PatElemT LParamMem], SegBinOp MCMem, [VName])
-> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[PatElemT LParamMem]]
-> [SegBinOp MCMem]
-> [[VName]]
-> [([PatElemT LParamMem], SegBinOp MCMem, [VName])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElemT LParamMem]]
per_scan_pes [SegBinOp MCMem]
scan_ops [[VName]]
accs) ((([PatElemT LParamMem], SegBinOp MCMem, [VName])
-> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ())
-> (([PatElemT LParamMem], SegBinOp MCMem, [VName])
-> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \([PatElemT LParamMem]
pes, SegBinOp MCMem
scan_op, [VName]
acc) ->
Shape
-> ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegBinOp MCMem -> Shape
forall lore. SegBinOp lore -> Shape
segBinOpShape SegBinOp MCMem
scan_op) (([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ())
-> ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is -> do
String
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Read carry in" (ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$
[(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp MCMem -> [LParam MCMem]
xParams SegBinOp MCMem
scan_op) [VName]
acc) (((Param LParamMem, VName) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ())
-> ((Param LParamMem, VName) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, VName
acc') ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var VName
acc') [TExp Int64]
vec_is
String
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Read next values" (ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$
[(Param LParamMem, PatElemT LParamMem)]
-> ((Param LParamMem, PatElemT LParamMem)
-> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [PatElemT LParamMem] -> [(Param LParamMem, PatElemT LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp MCMem -> [LParam MCMem]
yParams SegBinOp MCMem
scan_op) [PatElemT LParamMem]
pes) (((Param LParamMem, PatElemT LParamMem)
-> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ())
-> ((Param LParamMem, PatElemT LParamMem)
-> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, PatElemT LParamMem
pe) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe) ((TExp Int64
offset_index' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1) TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
vec_is)
Names
-> Stms MCMem
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body MCMem -> Stms MCMem
forall lore. BodyT lore -> Stms lore
bodyStms (Body MCMem -> Stms MCMem) -> Body MCMem -> Stms MCMem
forall a b. (a -> b) -> a -> b
$ SegBinOp MCMem -> Body MCMem
lamBody SegBinOp MCMem
scan_op) (ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$
[(VName, PatElemT LParamMem, SubExp)]
-> ((VName, PatElemT LParamMem, SubExp)
-> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName]
-> [PatElemT LParamMem]
-> [SubExp]
-> [(VName, PatElemT LParamMem, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
acc [PatElemT LParamMem]
pes (Body MCMem -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Body MCMem -> [SubExp]) -> Body MCMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegBinOp MCMem -> Body MCMem
lamBody SegBinOp MCMem
scan_op)) (((VName, PatElemT LParamMem, SubExp)
-> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ())
-> ((VName, PatElemT LParamMem, SubExp)
-> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$
\(VName
acc', PatElemT LParamMem
pe, SubExp
se) -> do
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe) ((TExp Int64
offset_index' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1) TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
vec_is) SubExp
se []
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
acc' [TExp Int64]
vec_is SubExp
se []
-- | Third and final pass of the non-segmented scan.
--
-- Emits an 'Imp.ParLoop' named @scan_stage_3@ over the @iter@ variable.
-- The loop's prologue (@prebody@) declares the parameters of every scan
-- operator and loads one accumulator per operator result. The accumulator is
-- loaded with the neutral element when @iter@ is 0. Otherwise it is loaded
-- with the result-array element at index @iter - 1@.
--
-- The loop body then:
--
-- * re-evaluates the kernel body at @iter@;
-- * combines each kernel result with its accumulator using the scan operator;
-- * writes the combined value both to the result array and back into the
--   accumulator.
--
-- Scalar operators reuse the operator's x-parameters as accumulators.
-- Operators with a non-empty 'segBinOpShape' get a fresh @local_acc@ array.
--
-- NOTE(review): the carry is read with the single index @iter - 1@, so this
-- presumably assumes a one-dimensional 'SegSpace'. It also assumes that the
-- prologue runs at the start of each contiguous chunk handed to a worker.
-- That depends on 'Imp.ParLoop' semantics not visible here.
scanStage3 ::
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen ()
scanStage3 pat space scan_ops kbody = do
  iter <- dPrimV "iter" (0 :: Imp.TExp Int64)
  let iter' = tvExp iter

  (local_accs, prebody) <- collect' $ do
    dScope Nothing . scopeOfLParams $
      concatMap (lambdaParams . segBinOpLambda) scan_ops
    zipWithM (initCarry iter') scan_ops per_scan_pes

  body <- collect $ do
    -- Bind the segment-space indices by unflattening the loop iterator.
    zipWithM_ dPrimV_ gtids $ unflattenIndex dims64 iter'
    sComment "stage 3 scan body" $
      compileStms mempty (kernelBodyStms kbody) $
        mapM_ scanStep $ zip4 per_scan_pes scan_ops per_scan_res local_accs

  free_params' <- freeParams (prebody <> body) [segFlat space, tvVar iter]
  let (body_allocs, body') = extractAllocations body
  emit . Imp.Op $
    Imp.ParLoop
      "scan_stage_3"
      (tvVar iter)
      (body_allocs <> prebody)
      body'
      mempty
      free_params'
      (segFlat space)
  where
    (gtids, dims) = unzip $ unSegSpace space
    dims64 = map toInt64Exp dims
    scan_res_all = take (segBinOpResults scan_ops) $ kernelBodyResult kbody
    per_scan_res = segBinOpChunks scan_ops scan_res_all
    per_scan_pes = segBinOpChunks scan_ops $ patternValueElements pat

    -- Pick (or allocate, for vectorised operators) the accumulators of one
    -- operator and seed them: neutral element when the iterator is 0,
    -- otherwise the previously written result at index (iterator - 1).
    initCarry iter_e scan_op pes =
      forM (zip4 (xParams scan_op) pes ts (segBinOpNeutral scan_op)) $
        \(p, pe, t, ne) -> do
          acc <-
            if null (shapeDims shape)
              then pure $ paramName p
              else
                sAllocArray
                  "local_acc"
                  (elemType t)
                  (shape <> arrayShape t)
                  DefaultSpace
          sLoopNest shape $ \vec_is ->
            sIf
              (iter_e .==. 0)
              (copyDWIMFix acc vec_is ne [])
              (copyDWIMFix acc vec_is (Var $ patElemName pe) (iter_e - 1 : vec_is))
          pure acc
      where
        shape = segBinOpShape scan_op
        ts = lambdaReturnType $ segBinOpLambda scan_op

    -- Run one operator on (accumulator, kernel result). The combined value is
    -- stored to the result arrays at the current index and back into the
    -- accumulator.
    scanStep (pes, scan_op, scan_res, acc) =
      sLoopNest (segBinOpShape scan_op) $ \vec_is -> do
        zipWithM_
          (\p acc' -> copyDWIMFix (paramName p) [] (Var acc') vec_is)
          (xParams scan_op)
          acc
        zipWithM_
          (\p res -> copyDWIMFix (paramName p) [] (kernelResultSubExp res) vec_is)
          (yParams scan_op)
          scan_res
        compileStms mempty (bodyStms lam_body) $
          forM_ (zip3 pes (bodyResult lam_body) acc) $ \(pe, se, acc') -> do
            copyDWIMFix (patElemName pe) (map Imp.vi64 gtids ++ vec_is) se []
            copyDWIMFix acc' vec_is se []
      where
        lam_body = lamBody scan_op
-- | Segmented scan.
--
-- First emits a @"segmented segScan"@ debug print. That print goes to the
-- surrounding code stream, not into the returned code, because it happens
-- before 'collect'.
--
-- Then it builds and returns a single 'Imp.ParLoop' named @seg_scan@. The
-- loop iterates over a fresh 64-bit @segment_iter@ variable. Its body is
-- whatever 'compileSegScanBody' generates for that iterator. Allocations are
-- hoisted out of the body with 'extractAllocations', and the body's free
-- variables, minus the iterator and the flat index, become loop parameters.
segmentedScan ::
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen Imp.Code
segmentedScan pat space scan_ops kbody = do
  emit $ Imp.DebugPrint "segmented segScan" Nothing
  collect $ do
    seg_iter <- dPrim "segment_iter" $ IntType Int64
    let iter_name = tvVar seg_iter
        flat_name = segFlat space
    loop_body <- compileSegScanBody (tvExp seg_iter) pat space scan_ops kbody
    params <- freeParams loop_body [flat_name, iter_name]
    let (allocs, loop_body') = extractAllocations loop_body
    emit . Imp.Op $
      Imp.ParLoop "seg_scan" iter_name allocs loop_body' mempty params flat_name
-- | Generate the code that sequentially scans one segment.
--
-- @segment_i@ selects the segment. It is unflattened into the indices of all
-- but the innermost dimension of @space@, and the innermost dimension is then
-- walked by a sequential @for@ loop. The space must therefore have at least
-- two dimensions ('init' and 'last' are applied to it). 'compileSegScan'
-- routes single-dimension spaces to 'nonsegmentedScan' instead.
--
-- Layout assumed for both the kernel body results and the value elements of
-- @pat@: the results of each scan operator in order (one per neutral
-- element), followed by the map-only results. Each operator therefore reads
-- its own slice via 'segBinOpChunks', and the map-only part starts after the
-- total number of scan results.
compileSegScanBody ::
  Imp.TExp Int64 ->
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen Imp.Code
compileSegScanBody segment_i pat space scan_ops kbody = do
  let (is, ns) = unzip $ unSegSpace space
      ns_64 = map toInt64Exp ns
      -- Total number of results produced by all scan operators together.
      num_scan_res = sum $ map (length . segBinOpNeutral) scan_ops
      per_scan_pes = segBinOpChunks scan_ops $ patternValueElements pat
      map_pes = drop num_scan_res $ patternValueElements pat

  collect $ do
    -- Declare every operator's lambda parameters and seed each accumulator
    -- (the x parameters) with the corresponding neutral element.
    forM_ scan_ops $ \scan_op -> do
      dScope Nothing $ scopeOfLParams $ lambdaParams $ segBinOpLambda scan_op
      forM_ (zip (xParams scan_op) (segBinOpNeutral scan_op)) $ \(p, ne) ->
        copyDWIMFix (paramName p) [] ne []

    -- One sequential pass over the innermost dimension. All operators
    -- advance together, so the kernel body is evaluated once per element.
    sFor "i" (last ns_64) $ \i -> do
      zipWithM_ dPrimV_ (init is) $ unflattenIndex (init ns_64) segment_i
      dPrimV_ (last is) i
      compileStms mempty (kernelBodyStms kbody) $ do
        let (scan_res, map_res) = splitAt num_scan_res $ kernelBodyResult kbody
            -- Each operator gets only its own slice of the scan results.
            per_scan_res = segBinOpChunks scan_ops scan_res

        sComment "write to-scan values to parameters" $
          forM_ (zip scan_ops per_scan_res) $ \(scan_op, op_res) ->
            forM_ (zip (yParams scan_op) op_res) $ \(p, se) ->
              copyDWIMFix (paramName p) [] (kernelResultSubExp se) []

        sComment "write mapped values results to memory" $
          forM_ (zip map_pes map_res) $ \(pe, se) ->
            copyDWIMFix (patElemName pe) (map Imp.vi64 is) (kernelResultSubExp se) []

        sComment "combine with carry and write to memory" $
          forM_ (zip scan_ops per_scan_pes) $ \(scan_op, scan_pes) -> do
            let lam_body = lambdaBody $ segBinOpLambda scan_op
            compileStms mempty (bodyStms lam_body) $
              forM_ (zip3 (xParams scan_op) scan_pes (bodyResult lam_body)) $
                \(p, pe, se) -> do
                  -- Store the scanned value, then carry it into the accumulator.
                  copyDWIMFix (patElemName pe) (map Imp.vi64 is) se []
                  copyDWIMFix (paramName p) [] se []