module Futhark.CodeGen.ImpGen.Multicore.SegRed
( compileSegRed,
compileSegRed',
)
where
import Control.Monad
import qualified Futhark.CodeGen.ImpCode.Multicore as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.IR.MCMem
import Futhark.Util (chunks)
import Prelude hiding (quot, rem)
-- | A code generator for the body of a segmented operation, written in
-- continuation-passing style.  It is handed a continuation and is
-- expected to invoke it with the reduction operands it produced: each
-- operand is a 'SubExp' paired with the indices used when reading from it.
type DoSegBody =
  ([(SubExp, [Imp.TExp Int64])] -> MulticoreGen ()) -> MulticoreGen ()
-- | Generate code for a reduction whose body is given as a 'KernelBody'.
--
-- The statements of the kernel body are compiled first.  Its results are
-- then split in two: the first @segBinOpResults reds@ results are the
-- operands of the reduction operators, and the remaining ones are
-- map-out results.  The map-out results are written to the trailing
-- pattern elements with 'compileThreadResult'; the reduction operands are
-- handed to the continuation supplied by 'compileSegRed'', each paired
-- with an empty index list.
compileSegRed ::
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.Code
compileSegRed pat space reds kbody nsubtasks =
  compileSegRed' pat space reds nsubtasks $ \red_cont ->
    compileStms mempty (kernelBodyStms kbody) $ do
      -- Number of kernel results (and pattern elements) that belong to
      -- the reduction operators rather than to the map part.
      let num_red_res = segBinOpResults reds
          (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody

      sComment "save map-out results" $ do
        let map_arrs = drop num_red_res $ patternElements pat
        zipWithM_ (compileThreadResult space) map_arrs map_res

      red_cont $ zip (map kernelResultSubExp red_res) $ repeat []
-- | Like 'compileSegRed', but the body is supplied as a 'DoSegBody'
-- action instead of a 'KernelBody'.
--
-- When the segment space has exactly one dimension the reduction is
-- compiled by 'nonsegmentedReduction'; in every other case it is passed
-- to 'segmentedReduction'.  Note that the subtask count is only forwarded
-- to the non-segmented case.
compileSegRed' ::
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  TV Int32 ->
  DoSegBody ->
  MulticoreGen Imp.Code
compileSegRed' pat space reds nsubtasks kbody
  | [_] <- unSegSpace space =
    nonsegmentedReduction pat space reds nsubtasks kbody
  | otherwise =
    segmentedReduction pat space reds kbody
-- | A 'SegBinOp' bundled with the auxiliary arrays used while compiling it.
data SegBinOpSlug = SegBinOpSlug
  { -- | The reduction operator itself.
    slugOp :: SegBinOp MCMem,
    -- | The arrays that receive the intermediate result of each subtask.
    -- Stage one writes to them at the index given by the flat
    -- identifier of the segment space ('segFlat'); stage two reads them
    -- back as the "next" operands.
    slugResArrs :: [VName]
  }
-- | The body of the lambda of the operator held by the slug.
slugBody :: SegBinOpSlug -> Body MCMem
slugBody = lambdaBody . segBinOpLambda . slugOp
-- | All parameters of the lambda of the operator held by the slug
-- (accumulator parameters followed by "next element" parameters; see
-- 'accParams' and 'nextParams').
slugParams :: SegBinOpSlug -> [LParam MCMem]
slugParams = lambdaParams . segBinOpLambda . slugOp
-- | The neutral elements of the operator held by the slug.
slugNeutral :: SegBinOpSlug -> [SubExp]
slugNeutral = segBinOpNeutral . slugOp
-- | The shape of the operator held by the slug, as reported by
-- 'segBinOpShape'.  It is used as the bounds of the loop nests that apply
-- the operator element-wise.
slugShape :: SegBinOpSlug -> Shape
slugShape = segBinOpShape . slugOp
-- | The lambda parameters of a slug, split at the number of neutral
-- elements: 'accParams' are the leading parameters (one per neutral
-- element) and 'nextParams' are all the parameters after those.
accParams, nextParams :: SegBinOpSlug -> [LParam MCMem]
accParams slug = take (length (slugNeutral slug)) $ slugParams slug
nextParams slug = drop (length (slugNeutral slug)) $ slugParams slug
-- | Generate code for a reduction over a one-dimensional segment space.
--
-- The reduction is compiled in two stages.  Stage one
-- ('reductionStage1') leaves one intermediate result per subtask in
-- freshly created result arrays whose size is the subtask count.  Stage
-- two ('reductionStage2') combines those intermediate results into the
-- pattern.
nonsegmentedReduction ::
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  TV Int32 ->
  DoSegBody ->
  MulticoreGen Imp.Code
nonsegmentedReduction pat space reds nsubtasks kbody = collect $ do
  thread_res_arrs <-
    groupResultArrays "reduce_stage_1_tid_res_arr" (tvSize nsubtasks) reds
  let slugs1 = zipWith SegBinOpSlug reds thread_res_arrs
      nsubtasks' = tvExp nsubtasks

  reductionStage1 space slugs1 kbody

  -- Stage two works on renamed copies of the operators but on the same
  -- result arrays.  NOTE(review): presumably the renaming keeps the two
  -- stages from declaring the same lambda names twice -- confirm against
  -- 'renameSegBinOp'.
  reds2 <- renameSegBinOp reds
  let slugs2 = zipWith SegBinOpSlug reds2 thread_res_arrs
  reductionStage2 pat space nsubtasks' slugs2
-- | First stage of the non-segmented reduction.
--
-- Emits a single @segred_stage_1@ 'Imp.ParLoop' built from three pieces
-- of code:
--
-- * a prelude that declares the operator parameters, creates one local
--   accumulator per accumulator parameter and initialises it with the
--   neutral element;
--
-- * a loop body that unflattens the iteration variable into the indices
--   of the segment space, runs the supplied body, and folds every
--   produced operand into the local accumulators using the operator
--   lambda;
--
-- * an epilogue that copies each local accumulator into the result
--   arrays of its slug, at the index given by @segFlat space@.
reductionStage1 ::
  SegSpace ->
  [SegBinOpSlug] ->
  DoSegBody ->
  MulticoreGen ()
reductionStage1 space slugs kbody = do
  let (is, ns) = unzip $ unSegSpace space
      ns' = map toInt64Exp ns
  flat_idx <- dPrim "iter" int64

  -- Create the local accumulators and collect the code that sets them up.
  (slug_local_accs, prebody) <- collect' $ do
    dScope Nothing $ scopeOfLParams $ concatMap slugParams slugs

    forM slugs $ \slug -> do
      let shape = slugShape slug

      forM (zip (accParams slug) (slugNeutral slug)) $ \(p, ne) -> do
        -- Pick the storage for the accumulator: a scalar variable for a
        -- primitive operand of an operator with empty shape, a newly
        -- allocated array for a primitive operand otherwise, and the
        -- parameter itself when the operand is not of primitive type.
        acc <-
          case paramType p of
            Prim pt
              | shape == mempty ->
                tvVar <$> dPrim "local_acc" pt
              | otherwise ->
                sAllocArray "local_acc" pt shape DefaultSpace
            _ ->
              pure $ paramName p

        -- Initialise every element of the accumulator with the neutral
        -- element.
        sLoopNest shape $ \vec_is ->
          copyDWIMFix acc vec_is ne []

        pure acc

  fbody <- collect $ do
    -- Recover the segment space indices from the flat iteration variable.
    zipWithM_ dPrimV_ is $ unflattenIndex ns' $ tvExp flat_idx

    kbody $ \all_red_res -> do
      -- One chunk of operands per operator.
      let all_red_res' = segBinOpChunks (map slugOp slugs) all_red_res

      forM_ (zip3 all_red_res' slugs slug_local_accs) $
        \(red_res, slug, local_accs) ->
          sLoopNest (slugShape slug) $ \vec_is -> do
            let lamtypes = lambdaReturnType $ segBinOpLambda $ slugOp slug

            -- Only accumulators whose lambda return type is primitive
            -- are copied into the parameter; for the others the
            -- accumulator is the parameter itself (see the prelude).
            sComment "Load accum params" $
              forM_ (zip3 (accParams slug) local_accs lamtypes) $
                \(p, local_acc, t) ->
                  when (primType t) $
                    copyDWIMFix (paramName p) [] (Var local_acc) vec_is

            sComment "Load next params" $
              forM_ (zip (nextParams slug) red_res) $
                \(p, (res, res_is)) ->
                  copyDWIMFix (paramName p) [] res (res_is ++ vec_is)

            sComment "Red body" $
              compileStms mempty (bodyStms $ slugBody slug) $
                forM_ (zip local_accs (bodyResult $ slugBody slug)) $
                  \(local_acc, se) ->
                    copyDWIMFix local_acc vec_is se []

  -- Write the local accumulators back to the per-subtask result arrays.
  postbody <-
    collect $
      forM_ (zip slugs slug_local_accs) $ \(slug, local_accs) ->
        forM_ (zip (slugResArrs slug) local_accs) $ \(acc, local_acc) ->
          copyDWIMFix acc [Imp.vi64 $ segFlat space] (Var local_acc) []

  free_params <-
    freeParams
      (prebody <> fbody <> postbody)
      (segFlat space : [tvVar flat_idx])

  -- Allocations found in the loop body are moved in front of the prelude.
  let (body_allocs, fbody') = extractAllocations fbody
  emit $
    Imp.Op $
      Imp.ParLoop
        "segred_stage_1"
        (tvVar flat_idx)
        (body_allocs <> prebody)
        fbody'
        postbody
        free_params
        $ segFlat space
reductionStage2 ::
Pattern MCMem ->
SegSpace ->
Imp.TExp Int32 ->
[SegBinOpSlug] ->
MulticoreGen ()
reductionStage2 :: Pattern MCMem
-> SegSpace -> TExp Int32 -> [SegBinOpSlug] -> MulticoreGen ()
reductionStage2 Pattern MCMem
pat SegSpace
space TExp Int32
nsubtasks [SegBinOpSlug]
slugs = do
let per_red_pes :: [[PatElemT LetDecMem]]
per_red_pes = [SegBinOp MCMem] -> [PatElemT LetDecMem] -> [[PatElemT LetDecMem]]
forall lore a. [SegBinOp lore] -> [a] -> [[a]]
segBinOpChunks ((SegBinOpSlug -> SegBinOp MCMem)
-> [SegBinOpSlug] -> [SegBinOp MCMem]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOpSlug -> SegBinOp MCMem
slugOp [SegBinOpSlug]
slugs) ([PatElemT LetDecMem] -> [[PatElemT LetDecMem]])
-> [PatElemT LetDecMem] -> [[PatElemT LetDecMem]]
forall a b. (a -> b) -> a -> b
$ PatternT LetDecMem -> [PatElemT LetDecMem]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements Pattern MCMem
PatternT LetDecMem
pat
phys_id :: TExp Int64
phys_id = VName -> TExp Int64
Imp.vi64 (SegSpace -> VName
segFlat SegSpace
space)
String -> MulticoreGen () -> MulticoreGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"neutral-initialise the output" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
[(SegBinOp MCMem, [PatElemT LetDecMem])]
-> ((SegBinOp MCMem, [PatElemT LetDecMem]) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp MCMem]
-> [[PatElemT LetDecMem]]
-> [(SegBinOp MCMem, [PatElemT LetDecMem])]
forall a b. [a] -> [b] -> [(a, b)]
zip ((SegBinOpSlug -> SegBinOp MCMem)
-> [SegBinOpSlug] -> [SegBinOp MCMem]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOpSlug -> SegBinOp MCMem
slugOp [SegBinOpSlug]
slugs) [[PatElemT LetDecMem]]
per_red_pes) (((SegBinOp MCMem, [PatElemT LetDecMem]) -> MulticoreGen ())
-> MulticoreGen ())
-> ((SegBinOp MCMem, [PatElemT LetDecMem]) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp MCMem
red, [PatElemT LetDecMem]
red_res) ->
[(PatElemT LetDecMem, SubExp)]
-> ((PatElemT LetDecMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem] -> [SubExp] -> [(PatElemT LetDecMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LetDecMem]
red_res ([SubExp] -> [(PatElemT LetDecMem, SubExp)])
-> [SubExp] -> [(PatElemT LetDecMem, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegBinOp MCMem -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp MCMem
red) (((PatElemT LetDecMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ())
-> ((PatElemT LetDecMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
pe, SubExp
ne) ->
Shape -> ([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegBinOp MCMem -> Shape
forall lore. SegBinOp lore -> Shape
segBinOpShape SegBinOp MCMem
red) (([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ())
-> ([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe) [TExp Int64]
vec_is SubExp
ne []
Maybe (Exp MCMem) -> Scope MCMem -> MulticoreGen ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
dScope Maybe (Exp MCMem)
forall a. Maybe a
Nothing (Scope MCMem -> MulticoreGen ()) -> Scope MCMem -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ [Param LetDecMem] -> Scope MCMem
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams ([Param LetDecMem] -> Scope MCMem)
-> [Param LetDecMem] -> Scope MCMem
forall a b. (a -> b) -> a -> b
$ (SegBinOpSlug -> [Param LetDecMem])
-> [SegBinOpSlug] -> [Param LetDecMem]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap SegBinOpSlug -> [LParam MCMem]
SegBinOpSlug -> [Param LetDecMem]
slugParams [SegBinOpSlug]
slugs
String
-> TExp Int32 -> (TExp Int32 -> MulticoreGen ()) -> MulticoreGen ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" TExp Int32
nsubtasks ((TExp Int32 -> MulticoreGen ()) -> MulticoreGen ())
-> (TExp Int32 -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i' -> do
VName -> PrimType -> TV Int32
forall t. VName -> PrimType -> TV t
mkTV (SegSpace -> VName
segFlat SegSpace
space) PrimType
int64 TV Int32 -> TExp Int32 -> MulticoreGen ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TExp Int32
i'
String -> MulticoreGen () -> MulticoreGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Apply main thread reduction" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
[(SegBinOpSlug, [PatElemT LetDecMem])]
-> ((SegBinOpSlug, [PatElemT LetDecMem]) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOpSlug]
-> [[PatElemT LetDecMem]] -> [(SegBinOpSlug, [PatElemT LetDecMem])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOpSlug]
slugs [[PatElemT LetDecMem]]
per_red_pes) (((SegBinOpSlug, [PatElemT LetDecMem]) -> MulticoreGen ())
-> MulticoreGen ())
-> ((SegBinOpSlug, [PatElemT LetDecMem]) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOpSlug
slug, [PatElemT LetDecMem]
red_res) ->
Shape -> ([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ())
-> ([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is -> do
String -> MulticoreGen () -> MulticoreGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"load acc params" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
[(Param LetDecMem, PatElemT LetDecMem)]
-> ((Param LetDecMem, PatElemT LetDecMem) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [PatElemT LetDecMem] -> [(Param LetDecMem, PatElemT LetDecMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam MCMem]
accParams SegBinOpSlug
slug) [PatElemT LetDecMem]
red_res) (((Param LetDecMem, PatElemT LetDecMem) -> MulticoreGen ())
-> MulticoreGen ())
-> ((Param LetDecMem, PatElemT LetDecMem) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, PatElemT LetDecMem
pe) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe) [TExp Int64]
vec_is
String -> MulticoreGen () -> MulticoreGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"load next params" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
[(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> MulticoreGen ()) -> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam MCMem]
nextParams SegBinOpSlug
slug) (SegBinOpSlug -> [VName]
slugResArrs SegBinOpSlug
slug)) (((Param LetDecMem, VName) -> MulticoreGen ()) -> MulticoreGen ())
-> ((Param LetDecMem, VName) -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, VName
acc) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] (VName -> SubExp
Var VName
acc) (TExp Int64
phys_id TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
vec_is)
String -> MulticoreGen () -> MulticoreGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"red body" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
Names -> Stms MCMem -> MulticoreGen () -> MulticoreGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body MCMem -> Stms MCMem
forall lore. BodyT lore -> Stms lore
bodyStms (Body MCMem -> Stms MCMem) -> Body MCMem -> Stms MCMem
forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body MCMem
slugBody SegBinOpSlug
slug) (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElemT LetDecMem, SubExp)]
-> ((PatElemT LetDecMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem] -> [SubExp] -> [(PatElemT LetDecMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LetDecMem]
red_res (Body MCMem -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Body MCMem -> [SubExp]) -> Body MCMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body MCMem
slugBody SegBinOpSlug
slug)) (((PatElemT LetDecMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ())
-> ((PatElemT LetDecMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
\(PatElemT LetDecMem
pe, SubExp
se') -> VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe) [TExp Int64]
vec_is SubExp
se' []
-- | Generate code for a segmented reduction.
--
-- A 64-bit variable named @"segment_iter"@ is declared and handed to
-- 'compileSegRedBody', which produces the code for handling one value
-- of that variable.  The resulting code is then wrapped in a single
-- @Imp.ParLoop@ labelled @"segmented_segred"@.  Everything emitted here
-- is gathered with 'collect' and returned as the result.
segmentedReduction ::
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  DoSegBody ->
  MulticoreGen Imp.Code
segmentedReduction pat space reds kbody =
  collect $ do
    n_par_segments <- dPrim "segment_iter" $ IntType Int64
    body <- compileSegRedBody n_par_segments pat space reds kbody
    -- NOTE(review): the name list is presumably the set of variables that
    -- must not be reported as free parameters of the body -- confirm
    -- against the definition of 'freeParams'.
    free_params <- freeParams body [segFlat space, tvVar n_par_segments]
    -- Allocations are split off from the body and handed to the loop as a
    -- separate piece of code.
    let (body_allocs, body') = extractAllocations body
    emit $
      Imp.Op $
        Imp.ParLoop
          "segmented_segred"
          (tvVar n_par_segments)
          body_allocs
          body'
          mempty
          free_params
          (segFlat space)
-- | Generate the code that sequentially reduces a single segment.
--
-- The value of @n_segments@ is treated as a flat index over all but the
-- last dimension of @space@: it is unflattened into the outer index
-- variables, while the last dimension is traversed by a sequential
-- @sFor@ loop.  For every reduction operator in @reds@ the generated
-- code
--
--   1. writes the neutral element into the corresponding pattern
--      elements of @pat@,
--   2. on each iteration runs @kbody@, loads the current accumulator and
--      the freshly produced value into the operator's lambda parameters,
--      compiles the lambda body, and writes its result back to the
--      pattern elements.
--
-- The generated code is gathered with 'collect' and returned.
compileSegRedBody ::
  TV Int64 ->
  Pattern MCMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  DoSegBody ->
  MulticoreGen Imp.Code
compileSegRedBody n_segments pat space reds kbody = do
  let (is, ns) = unzip $ unSegSpace space
      ns_64 = map toInt64Exp ns
      -- NOTE(review): 'last' and 'init' are partial; this assumes the
      -- space has at least one dimension -- confirm against callers.
      inner_bound = last ns_64
      n_segments' = tvExp n_segments
      -- Index expressions for every dimension except the innermost one.
      segment_is = map Imp.vi64 (init is)
      -- Pattern elements grouped per reduction operator.
      per_red_pes = segBinOpChunks reds $ patternValueElements pat

  collect $ do
    -- Declare all index variables of the space, starting at the flat
    -- position @n_segments' * inner_bound@.
    flat_idx <- dPrimVE "flat_idx" $ n_segments' * inner_bound
    zipWithM_ dPrimV_ is $ unflattenIndex ns_64 flat_idx

    sComment "neutral-initialise the accumulators" $
      forM_ (zip per_red_pes reds) $ \(pes, red) ->
        forM_ (zip pes (segBinOpNeutral red)) $ \(pe, ne) ->
          sLoopNest (segBinOpShape red) $ \vec_is ->
            copyDWIMFix (patElemName pe) (segment_is ++ vec_is) ne []

    sComment "main body" $ do
      -- Bring the lambda parameters of every operator into scope before
      -- they are written to inside the loop.
      dScope Nothing $
        scopeOfLParams $
          concatMap (lambdaParams . segBinOpLambda) reds

      sFor "i" inner_bound $ \i -> do
        -- Re-assign the outer indices from the segment number and bind
        -- the innermost index to the loop counter.
        zipWithM_
          (<--)
          (map (`mkTV` int64) $ init is)
          (unflattenIndex (init ns_64) (sExt64 n_segments'))
        dPrimV_ (last is) i

        kbody $ \all_red_res -> do
          -- One chunk of kernel results per operator, sized by the number
          -- of neutral elements of that operator.
          let red_res' = chunks (map (length . segBinOpNeutral) reds) all_red_res

          forM_ (zip3 per_red_pes reds red_res') $ \(pes, red, res') -> do
            let lam = segBinOpLambda red
                lbody = lambdaBody lam
                -- The first parameters receive the accumulator, the
                -- remaining ones the new value.
                (acc_params, next_params) =
                  splitAt (length (segBinOpNeutral red)) (lambdaParams lam)

            sLoopNest (segBinOpShape red) $ \vec_is -> do
              let acc_is = segment_is ++ vec_is

              sComment "load accum" $
                forM_ (zip acc_params pes) $ \(p, pe) ->
                  copyDWIMFix (paramName p) [] (Var $ patElemName pe) acc_is

              sComment "load new val" $
                forM_ (zip next_params res') $ \(p, (res, res_is)) ->
                  copyDWIMFix (paramName p) [] res (res_is ++ vec_is)

              sComment "apply reduction" $
                compileStms mempty (bodyStms lbody) $
                  sComment "write back to res" $
                    forM_ (zip pes (bodyResult lbody)) $ \(pe, se') ->
                      copyDWIMFix (patElemName pe) acc_is se' []