{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.Kernels.SegHist (compileSegHist) where
import Control.Monad.Except
import Data.List (foldl', genericLength, zip4, zip6)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import Futhark.CodeGen.ImpGen.Kernels.SegRed (compileSegRed')
import Futhark.Construct (fullSliceNum)
import Futhark.IR.KernelsMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Pass.ExplicitAllocations ()
import Futhark.Util (chunks, mapAccumLM, maxinum, splitFromEnd, takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)
-- | The subhistogram array set up for one destination of a histogram
-- operator, together with the action that backs it with memory.
data SubhistosInfo = SubhistosInfo
  { -- | Name of the subhistogram array.
    subhistosArray :: VName,
    -- | Host-side action that provides the memory for 'subhistosArray'
    -- (in this module it either aliases the destination's memory or
    -- allocates and initialises fresh device memory).
    subhistosAlloc :: CallKernelGen ()
  }
-- | Per-operator bookkeeping used while compiling a segmented histogram.
data SegHistSlug = SegHistSlug
  { -- | The histogram operator this slug describes.
    slugOp :: HistOp KernelsMem,
    -- | Variable holding the number of subhistograms for 'slugOp'.
    slugNumSubhistos :: TV Int64,
    -- | One entry per destination of 'slugOp'.
    slugSubhistos :: [SubhistosInfo],
    -- | How to atomically combine a value into a histogram bucket.
    slugAtomicUpdate :: AtomicUpdate KernelsMem KernelEnv
  }
-- | Number of bytes needed for one copy of all destination histograms
-- of @op@, not accounting for segments or subhistograms.
--
-- Each return type of the operator lambda is first wrapped in
-- 'histShape' and then given an outermost dimension of size
-- 'histWidth'; the byte sizes of the resulting array types are summed.
histoSpaceUsage ::
  HistOp KernelsMem ->
  Imp.Count Imp.Bytes (Imp.TExp Int64)
histoSpaceUsage op =
  sum $
    map (typeSize . (`arrayOfRow` histWidth op) . (`arrayOfShape` histShape op)) $
      lambdaReturnType $ histOp op
-- | Figure out how much memory a histogram operator needs, and declare
-- the subhistogram arrays for each of its destinations.
--
-- Returns a triple of
--
--   * @h@: the size of one copy of all destination histograms
--     (see 'histoSpaceUsage');
--   * @segmented_h@: @h@ multiplied by the product of all but the
--     innermost dimension of @space@;
--   * a 'SegHistSlug' with @op@, a freshly declared @num_subhistos@
--     variable, one 'SubhistosInfo' per destination, and the
--     locking-based atomic update for the operator.
--
-- @num_subhistos@ is only declared here, never assigned, yet the
-- returned 'subhistosAlloc' actions read it; NOTE(review): callers are
-- presumably responsible for assigning it before running them.
computeHistoUsage ::
  SegSpace ->
  HistOp KernelsMem ->
  CallKernelGen
    ( Imp.Count Imp.Bytes (Imp.TExp Int64),
      Imp.Count Imp.Bytes (Imp.TExp Int64),
      SegHistSlug
    )
computeHistoUsage space op = do
  -- All dimensions of the segment space except the innermost one are
  -- segment dimensions.
  let segment_dims = init $ unSegSpace space
      num_segments = length segment_dims

  -- The number of subhistograms is a 'TV Int64' and is used as an array
  -- dimension below, so the variable itself must be 64 bits wide.
  num_subhistos <- dPrim "num_subhistos" int64

  subhisto_infos <- forM (zip (histDest op) (histNeutral op)) $ \(dest, ne) -> do
    dest_t <- lookupType dest
    dest_mem <- entryArrayLocation <$> lookupArray dest

    subhistos_mem <-
      sDeclareMem (baseString dest ++ "_subhistos_mem") (Space "device")

    -- Segment dimensions, then the subhistogram dimension, then the
    -- remaining (non-segment) dimensions of the destination.
    let subhistos_shape =
          Shape (map snd segment_dims ++ [tvSize num_subhistos])
            <> stripDims num_segments (arrayShape dest_t)
        subhistos_membind =
          ArrayIn subhistos_mem $
            IxFun.iota $ map pe64 $ shapeDims subhistos_shape
    subhistos <-
      sArray
        (baseString dest ++ "_subhistos")
        (elemType dest_t)
        subhistos_shape
        subhistos_membind

    return $
      SubhistosInfo subhistos $ do
        -- With a single subhistogram, alias the destination's memory
        -- instead of making a copy.
        let unitHistoCase =
              emit $
                Imp.SetMem subhistos_mem (memLocationName dest_mem) $
                  Space "device"
            -- With several subhistograms, allocate fresh memory, fill it
            -- with the neutral element, and copy the destination into
            -- subhistogram 0 of every segment.
            multiHistoCase = do
              let num_elems =
                    foldl' (*) (sExt64 $ tvExp num_subhistos) $
                      map toInt64Exp $ arrayDims dest_t
                  subhistos_mem_size =
                    Imp.bytes $
                      Imp.unCount (Imp.elements num_elems `Imp.withElemType` elemType dest_t)

              sAlloc_ subhistos_mem subhistos_mem_size $ Space "device"
              sReplicate subhistos ne
              subhistos_t <- lookupType subhistos
              let slice =
                    fullSliceNum (map toInt64Exp $ arrayDims subhistos_t) $
                      map (unitSlice 0 . toInt64Exp . snd) segment_dims
                        ++ [DimFix 0]
              sUpdate subhistos slice $ Var dest

        sIf (tvExp num_subhistos .==. 1) unitHistoCase multiHistoCase

  let h = histoSpaceUsage op
      segmented_h =
        h * product (map (Imp.bytes . toInt64Exp) $ init $ segSpaceDims space)

  atomics <- hostAtomics <$> askEnv

  return
    ( h,
      segmented_h,
      SegHistSlug op num_subhistos subhisto_infos $
        atomicUpdateLocking atomics $ histOp op
    )
-- | Build the global-memory atomic update function for one histogram
-- operator.
--
-- The first component of the result is the 'Locking' in effect
-- afterwards:
--
--   * For primitive ('AtomicPrim') and compare-and-swap ('AtomicCAS')
--     updates, the incoming @l@ is returned unchanged.
--   * For lock-based updates ('AtomicLocking'), an existing @Just@
--     locking is reused.
--   * Otherwise a static device array @hist_locks@ of @num_locks@
--     zero-initialised int32 locks is created. An index is mapped to a
--     lock by flattening it against
--     @histShape ++ [num_subhistos, histWidth]@ and taking the remainder
--     modulo @num_locks@.
--
-- NOTE(review): returning the new locking presumably lets later
-- operators share the same lock array; confirm in caller.
prepareAtomicUpdateGlobal ::
  Maybe Locking ->
  [VName] ->
  SegHistSlug ->
  CallKernelGen
    ( Maybe Locking,
      [Imp.TExp Int64] -> InKernelGen ()
    )
prepareAtomicUpdateGlobal l dests slug =
  case (l, slugAtomicUpdate slug) of
    (_, AtomicPrim f) -> return (l, f (Space "global") dests)
    (_, AtomicCAS f) -> return (l, f (Space "global") dests)
    (Just l', AtomicLocking f) -> return (l, f l' (Space "global") dests)
    (Nothing, AtomicLocking f) -> do
      -- Fixed lock count; NOTE(review): the rationale for this exact
      -- value is not visible in this module.
      let num_locks :: Int
          num_locks = 100151
          dims =
            map toInt64Exp $
              shapeDims (histShape (slugOp slug))
                ++ [ tvSize (slugNumSubhistos slug),
                     histWidth (slugOp slug)
                   ]

      locks <-
        sStaticArray "hist_locks" (Space "device") int32 $
          Imp.ArrayZeros num_locks
      let l' =
            Locking locks 0 1 0 $
              pure . (`rem` fromIntegral num_locks) . flattenIndex dims
      return (Just l', f l' (Space "global") dests)
-- | Whether a histogram computation must process its input in a single
-- pass, or may be split over several passes.
--
-- Constructor order matters for the derived 'Ord' instance:
-- 'MustBeSinglePass' compares less than 'MayBeMultiPass'.
data Passage = MustBeSinglePass | MayBeMultiPass
  deriving (Eq, Ord)
-- | Decide which 'Passage' a kernel body permits.
--
-- The body is alias-analysed (with an empty alias table). If it consumes
-- nothing, the computation 'MayBeMultiPass'; if it consumes anything,
-- it 'MustBeSinglePass'. NOTE(review): presumably this is because
-- re-running a consuming body in a later pass would not be safe.
bodyPassage :: KernelBody KernelsMem -> Passage
bodyPassage kbody
  | mempty == consumedInKernelBody (aliasAnalyseKernelBody mempty kbody) =
    MayBeMultiPass
  | otherwise =
    MustBeSinglePass
prepareIntermediateArraysGlobal ::
Passage ->
Imp.TExp Int32 ->
Imp.TExp Int64 ->
[SegHistSlug] ->
CallKernelGen
( Imp.TExp Int32,
[[Imp.TExp Int64] -> InKernelGen ()]
)
prepareIntermediateArraysGlobal :: Passage
-> TExp Int32
-> TExp Int64
-> [SegHistSlug]
-> CallKernelGen (TExp Int32, [[TExp Int64] -> InKernelGen ()])
prepareIntermediateArraysGlobal Passage
passage TExp Int32
hist_T TExp Int64
hist_N [SegHistSlug]
slugs = do
TExp Int64
hist_H <- String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_H" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> TExp Int64) -> [SegHistSlug] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> (SegHistSlug -> SubExp) -> SegHistSlug -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth (HistOp KernelsMem -> SubExp)
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs
TExp Double
hist_RF <-
String
-> TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_RF" (TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
[TExp Double] -> TExp Double
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((SegHistSlug -> TExp Double) -> [SegHistSlug] -> [TExp Double]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 (TExp Int64 -> TExp Double)
-> (SegHistSlug -> TExp Int64) -> SegHistSlug -> TExp Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> (SegHistSlug -> SubExp) -> SegHistSlug -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histRaceFactor (HistOp KernelsMem -> SubExp)
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs)
TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ [SegHistSlug] -> TExp Double
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs
TExp Int32
hist_el_size <- String -> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_el_size" (TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ [TExp Int32] -> TExp Int32
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TExp Int32] -> TExp Int32) -> [TExp Int32] -> TExp Int32
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> TExp Int32) -> [SegHistSlug] -> [TExp Int32]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> TExp Int32
slugElAvgSize [SegHistSlug]
slugs
TExp Double
hist_C_max <-
String
-> TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_C_max" (TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
TExp Double -> TExp Double -> TExp Double
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMin64 (TExp Int32 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T) (TExp Double -> TExp Double) -> TExp Double -> TExp Double
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 TExp Int64
hist_H TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Double
hist_k_ct_min
TExp Int32
hist_M_min <-
String -> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_M_min" (TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TExp Int32 -> TExp Int32
forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 TExp Int32
1 (TExp Int32 -> TExp Int32) -> TExp Int32 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Double -> TExp Int64
forall t v. TPrimExp t v -> TPrimExp Int64 v
t64 (TExp Double -> TExp Int64) -> TExp Double -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Double
hist_C_max
let hist_L2_def :: Int64
hist_L2_def = Int64
4 Int64 -> Int64 -> Int64
forall a. Num a => a -> a -> a
* Int64
1024 Int64 -> Int64 -> Int64
forall a. Num a => a -> a -> a
* Int64
1024
TV Any
hist_L2 <- String -> PrimType -> ImpM KernelsMem HostEnv HostOp (TV Any)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"L2_size" PrimType
int32
Maybe Name
entry <- ImpM KernelsMem HostEnv HostOp (Maybe Name)
forall lore r op. ImpM lore r op (Maybe Name)
askFunction
HostOp -> CallKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName -> Name -> SizeClass -> HostOp
Imp.GetSize
(TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
hist_L2)
(Maybe Name -> Name -> Name
keyWithEntryPoint Maybe Name
entry (Name -> Name) -> Name -> Name
forall a b. (a -> b) -> a -> b
$ String -> Name
nameFromString (VName -> String
forall a. Pretty a => a -> String
pretty (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
hist_L2)))
(SizeClass -> HostOp) -> SizeClass -> HostOp
forall a b. (a -> b) -> a -> b
$ Name -> Int64 -> SizeClass
Imp.SizeBespoke (String -> Name
nameFromString String
"L2_for_histogram") Int64
hist_L2_def
let hist_L2_ln_sz :: TExp Double
hist_L2_ln_sz = TExp Double
16 TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TExp Double
4
TExp Double
hist_RACE_exp <-
String
-> TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_RACE_exp" (TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
TExp Double -> TExp Double -> TExp Double
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMax64 TExp Double
1 (TExp Double -> TExp Double) -> TExp Double -> TExp Double
forall a b. (a -> b) -> a -> b
$
(TExp Double
hist_k_RF TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TExp Double
hist_RF)
TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ (TExp Double
hist_L2_ln_sz TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Int32 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_el_size)
TV Int32
hist_S <- String -> PrimType -> ImpM KernelsMem HostEnv HostOp (TV Int32)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"hist_S" PrimType
int32
TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf
(TExp Int64
hist_N TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
hist_H)
(TV Int32
hist_S TV Int32 -> TExp Int32 -> CallKernelGen ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- (TExp Int32
1 :: Imp.TExp Int32))
(CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ TV Int32
hist_S
TV Int32 -> TExp Int32 -> CallKernelGen ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- case Passage
passage of
Passage
MayBeMultiPass ->
TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
(TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_M_min TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_el_size)
TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Double -> TExp Int64
forall t v. TPrimExp t v -> TPrimExp Int64 v
t64 (TExp Double
hist_F_L2 TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TPrimExp Any ExpLeaf -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 (TV Any -> TPrimExp Any ExpLeaf
forall t. TV t -> TExp t
tvExp TV Any
hist_L2) TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TExp Double
hist_RACE_exp)
Passage
MustBeSinglePass ->
TExp Int32
1
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Race expansion factor (RACE^exp)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Double -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Double
hist_RACE_exp
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of chunks (S)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_S
[[TExp Int64] -> InKernelGen ()]
histograms <-
(Maybe Locking, [[TExp Int64] -> InKernelGen ()])
-> [[TExp Int64] -> InKernelGen ()]
forall a b. (a, b) -> b
snd
((Maybe Locking, [[TExp Int64] -> InKernelGen ()])
-> [[TExp Int64] -> InKernelGen ()])
-> ImpM
KernelsMem
HostEnv
HostOp
(Maybe Locking, [[TExp Int64] -> InKernelGen ()])
-> ImpM KernelsMem HostEnv HostOp [[TExp Int64] -> InKernelGen ()]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Maybe Locking
-> SegHistSlug
-> CallKernelGen (Maybe Locking, [TExp Int64] -> InKernelGen ()))
-> Maybe Locking
-> [SegHistSlug]
-> ImpM
KernelsMem
HostEnv
HostOp
(Maybe Locking, [[TExp Int64] -> InKernelGen ()])
forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM
(TPrimExp Any ExpLeaf
-> TExp Int32
-> TExp Int32
-> TExp Double
-> Maybe Locking
-> SegHistSlug
-> CallKernelGen (Maybe Locking, [TExp Int64] -> InKernelGen ())
onOp (TV Any -> TPrimExp Any ExpLeaf
forall t. TV t -> TExp t
tvExp TV Any
hist_L2) TExp Int32
hist_M_min (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_S) TExp Double
hist_RACE_exp)
Maybe Locking
forall a. Maybe a
Nothing
[SegHistSlug]
slugs
(TExp Int32, [[TExp Int64] -> InKernelGen ()])
-> CallKernelGen (TExp Int32, [[TExp Int64] -> InKernelGen ()])
forall (m :: * -> *) a. Monad m => a -> m a
return (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_S, [[TExp Int64] -> InKernelGen ()]
histograms)
where
hist_k_ct_min :: TExp Double
hist_k_ct_min = TExp Double
2
hist_k_RF :: TExp Double
hist_k_RF = TExp Double
0.75
hist_F_L2 :: TExp Double
hist_F_L2 = TExp Double
0.4
r64 :: TPrimExp t v -> TPrimExp Double v
r64 = PrimExp v -> TPrimExp Double v
forall v. PrimExp v -> TPrimExp Double v
isF64 (PrimExp v -> TPrimExp Double v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Double v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> FloatType -> ConvOp
SIToFP IntType
Int32 FloatType
Float64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall t v. TPrimExp t v -> PrimExp v
untyped
t64 :: TPrimExp t v -> TPrimExp Int64 v
t64 = PrimExp v -> TPrimExp Int64 v
forall v. PrimExp v -> TPrimExp Int64 v
isInt64 (PrimExp v -> TPrimExp Int64 v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Int64 v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> IntType -> ConvOp
FPToSI FloatType
Float64 IntType
Int64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall t v. TPrimExp t v -> PrimExp v
untyped
slugElAvgSize :: SegHistSlug -> TExp Int32
slugElAvgSize slug :: SegHistSlug
slug@(SegHistSlug HistOp KernelsMem
op TV Int64
_ [SubhistosInfo]
_ AtomicUpdate KernelsMem KernelEnv
do_op) =
case AtomicUpdate KernelsMem KernelEnv
do_op of
AtomicLocking {} ->
SegHistSlug -> TExp Int32
slugElSize SegHistSlug
slug TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` (TExp Int32
1 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ [Type] -> TExp Int32
forall i a. Num i => [a] -> i
genericLength (LambdaT KernelsMem -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType (HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp HistOp KernelsMem
op)))
AtomicUpdate KernelsMem KernelEnv
_ ->
SegHistSlug -> TExp Int32
slugElSize SegHistSlug
slug TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` [Type] -> TExp Int32
forall i a. Num i => [a] -> i
genericLength (LambdaT KernelsMem -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType (HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp HistOp KernelsMem
op))
slugElSize :: SegHistSlug -> TExp Int32
slugElSize (SegHistSlug HistOp KernelsMem
op TV Int64
_ [SubhistosInfo]
_ AtomicUpdate KernelsMem KernelEnv
do_op) =
case AtomicUpdate KernelsMem KernelEnv
do_op of
AtomicLocking {} ->
TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
Count Bytes (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount (Count Bytes (TExp Int64) -> TExp Int64)
-> Count Bytes (TExp Int64) -> TExp Int64
forall a b. (a -> b) -> a -> b
$
[Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64)
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64))
-> [Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64)
forall a b. (a -> b) -> a -> b
$
(Type -> Count Bytes (TExp Int64))
-> [Type] -> [Count Bytes (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> Count Bytes (TExp Int64)
typeSize (Type -> Count Bytes (TExp Int64))
-> (Type -> Type) -> Type -> Count Bytes (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Shape -> Type
`arrayOfShape` HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp KernelsMem
op)) ([Type] -> [Count Bytes (TExp Int64)])
-> [Type] -> [Count Bytes (TExp Int64)]
forall a b. (a -> b) -> a -> b
$
PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int32 Type -> [Type] -> [Type]
forall a. a -> [a] -> [a]
: LambdaT KernelsMem -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType (HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp HistOp KernelsMem
op)
AtomicUpdate KernelsMem KernelEnv
_ ->
TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
Count Bytes (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount (Count Bytes (TExp Int64) -> TExp Int64)
-> Count Bytes (TExp Int64) -> TExp Int64
forall a b. (a -> b) -> a -> b
$
[Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64)
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64))
-> [Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64)
forall a b. (a -> b) -> a -> b
$
(Type -> Count Bytes (TExp Int64))
-> [Type] -> [Count Bytes (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> Count Bytes (TExp Int64)
typeSize (Type -> Count Bytes (TExp Int64))
-> (Type -> Type) -> Type -> Count Bytes (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Shape -> Type
`arrayOfShape` HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp KernelsMem
op)) ([Type] -> [Count Bytes (TExp Int64)])
-> [Type] -> [Count Bytes (TExp Int64)]
forall a b. (a -> b) -> a -> b
$
LambdaT KernelsMem -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType (HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp HistOp KernelsMem
op)
onOp :: TPrimExp Any ExpLeaf
-> TExp Int32
-> TExp Int32
-> TExp Double
-> Maybe Locking
-> SegHistSlug
-> CallKernelGen (Maybe Locking, [TExp Int64] -> InKernelGen ())
onOp TPrimExp Any ExpLeaf
hist_L2 TExp Int32
hist_M_min TExp Int32
hist_S TExp Double
hist_RACE_exp Maybe Locking
l SegHistSlug
slug = do
let SegHistSlug HistOp KernelsMem
op TV Int64
num_subhistos [SubhistosInfo]
subhisto_info AtomicUpdate KernelsMem KernelEnv
do_op = SegHistSlug
slug
hist_H :: TExp Int64
hist_H = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth HistOp KernelsMem
op
TExp Int64
hist_H_chk <- String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_H_chk" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
hist_H TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Chunk size (H_chk)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_H_chk
TExp Double
hist_k_max <-
String
-> TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_k_max" (TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
TExp Double -> TExp Double -> TExp Double
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMin64
(TExp Double
hist_F_L2 TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* (TPrimExp Any ExpLeaf -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 TPrimExp Any ExpLeaf
hist_L2 TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Int32 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 (SegHistSlug -> TExp Int32
slugElSize SegHistSlug
slug)) TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TExp Double
hist_RACE_exp)
(TExp Int64 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 TExp Int64
hist_N)
TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Int32 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T
TExp Int64
hist_u <- String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_u" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
case AtomicUpdate KernelsMem KernelEnv
do_op of
AtomicPrim {} -> TExp Int64
2
AtomicUpdate KernelsMem KernelEnv
_ -> TExp Int64
1
TExp Double
hist_C <-
String
-> TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_C" (TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
TExp Double -> TExp Double -> TExp Double
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMin64 (TExp Int32 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T) (TExp Double -> TExp Double) -> TExp Double -> TExp Double
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 (TExp Int64
hist_u TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H_chk) TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Double
hist_k_max
TExp Int32
hist_M <- String -> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_M" (TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
case SegHistSlug -> AtomicUpdate KernelsMem KernelEnv
slugAtomicUpdate SegHistSlug
slug of
AtomicPrim {} -> TExp Int32
1
AtomicUpdate KernelsMem KernelEnv
_ -> TExp Int32 -> TExp Int32 -> TExp Int32
forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 TExp Int32
hist_M_min (TExp Int32 -> TExp Int32) -> TExp Int32 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Double -> TExp Int64
forall t v. TPrimExp t v -> TPrimExp Int64 v
t64 (TExp Double -> TExp Int64) -> TExp Double -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Double
hist_C
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Elements/thread in L2 cache (k_max)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Double -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Double
hist_k_max
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Multiplication degree (M)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_M
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Cooperation level (C)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Double -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Double
hist_C
TV Int64
num_subhistos TV Int64 -> TExp Int64 -> CallKernelGen ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_M
[VName]
dests <- [(VName, SubhistosInfo)]
-> ((VName, SubhistosInfo) -> ImpM KernelsMem HostEnv HostOp VName)
-> ImpM KernelsMem HostEnv HostOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName] -> [SubhistosInfo] -> [(VName, SubhistosInfo)]
forall a b. [a] -> [b] -> [(a, b)]
zip (HistOp KernelsMem -> [VName]
forall lore. HistOp lore -> [VName]
histDest HistOp KernelsMem
op) [SubhistosInfo]
subhisto_info) (((VName, SubhistosInfo) -> ImpM KernelsMem HostEnv HostOp VName)
-> ImpM KernelsMem HostEnv HostOp [VName])
-> ((VName, SubhistosInfo) -> ImpM KernelsMem HostEnv HostOp VName)
-> ImpM KernelsMem HostEnv HostOp [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubhistosInfo
info) -> do
MemLocation
dest_mem <- ArrayEntry -> MemLocation
entryArrayLocation (ArrayEntry -> MemLocation)
-> ImpM KernelsMem HostEnv HostOp ArrayEntry
-> ImpM KernelsMem HostEnv HostOp MemLocation
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM KernelsMem HostEnv HostOp ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray VName
dest
VName
sub_mem <-
(MemLocation -> VName)
-> ImpM KernelsMem HostEnv HostOp MemLocation
-> ImpM KernelsMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap MemLocation -> VName
memLocationName (ImpM KernelsMem HostEnv HostOp MemLocation
-> ImpM KernelsMem HostEnv HostOp VName)
-> ImpM KernelsMem HostEnv HostOp MemLocation
-> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
ArrayEntry -> MemLocation
entryArrayLocation
(ArrayEntry -> MemLocation)
-> ImpM KernelsMem HostEnv HostOp ArrayEntry
-> ImpM KernelsMem HostEnv HostOp MemLocation
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM KernelsMem HostEnv HostOp ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray (SubhistosInfo -> VName
subhistosArray SubhistosInfo
info)
let unitHistoCase :: CallKernelGen ()
unitHistoCase =
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName -> VName -> Space -> Code HostOp
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
sub_mem (MemLocation -> VName
memLocationName MemLocation
dest_mem) (Space -> Code HostOp) -> Space -> Code HostOp
forall a b. (a -> b) -> a -> b
$
String -> Space
Space String
"device"
multiHistoCase :: CallKernelGen ()
multiHistoCase = SubhistosInfo -> CallKernelGen ()
subhistosAlloc SubhistosInfo
info
TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (TExp Int32
hist_M TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
1) CallKernelGen ()
unitHistoCase CallKernelGen ()
multiHistoCase
VName -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *) a. Monad m => a -> m a
return (VName -> ImpM KernelsMem HostEnv HostOp VName)
-> VName -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ SubhistosInfo -> VName
subhistosArray SubhistosInfo
info
(Maybe Locking
l', [TExp Int64] -> InKernelGen ()
do_op') <- Maybe Locking
-> [VName]
-> SegHistSlug
-> CallKernelGen (Maybe Locking, [TExp Int64] -> InKernelGen ())
prepareAtomicUpdateGlobal Maybe Locking
l [VName]
dests SegHistSlug
slug
(Maybe Locking, [TExp Int64] -> InKernelGen ())
-> CallKernelGen (Maybe Locking, [TExp Int64] -> InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe Locking
l', [TExp Int64] -> InKernelGen ()
do_op')
histKernelGlobalPass ::
[PatElem KernelsMem] ->
Count NumGroups (Imp.TExp Int64) ->
Count GroupSize (Imp.TExp Int64) ->
SegSpace ->
[SegHistSlug] ->
KernelBody KernelsMem ->
[[Imp.TExp Int64] -> InKernelGen ()] ->
Imp.TExp Int32 ->
Imp.TExp Int32 ->
CallKernelGen ()
histKernelGlobalPass :: [PatElem KernelsMem]
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> SegSpace
-> [SegHistSlug]
-> KernelBody KernelsMem
-> [[TExp Int64] -> InKernelGen ()]
-> TExp Int32
-> TExp Int32
-> CallKernelGen ()
histKernelGlobalPass [PatElem KernelsMem]
map_pes Count NumGroups (TExp Int64)
num_groups Count GroupSize (TExp Int64)
group_size SegSpace
space [SegHistSlug]
slugs KernelBody KernelsMem
kbody [[TExp Int64] -> InKernelGen ()]
histograms TExp Int32
hist_S TExp Int32
chk_i = do
let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
space_sizes_64 :: [TExp Int64]
space_sizes_64 = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int64 -> TExp Int64)
-> (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) [SubExp]
space_sizes
total_w_64 :: TExp Int64
total_w_64 = [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
space_sizes_64
[TExp Int64]
hist_H_chks <- [SubExp]
-> (SubExp -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> ImpM KernelsMem HostEnv HostOp [TExp Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ((SegHistSlug -> SubExp) -> [SegHistSlug] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth (HistOp KernelsMem -> SubExp)
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs) ((SubExp -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> ImpM KernelsMem HostEnv HostOp [TExp Int64])
-> (SubExp -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> ImpM KernelsMem HostEnv HostOp [TExp Int64]
forall a b. (a -> b) -> a -> b
$ \SubExp
w ->
String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_H_chk" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
w TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S
String
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"seghist_global" Count NumGroups (TExp Int64)
num_groups Count GroupSize (TExp Int64)
group_size (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
[TExp Int32]
subhisto_inds <- [SegHistSlug]
-> (SegHistSlug -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> ImpM KernelsMem KernelEnv KernelOp [TExp Int32]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [SegHistSlug]
slugs ((SegHistSlug -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> ImpM KernelsMem KernelEnv KernelOp [TExp Int32])
-> (SegHistSlug -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> ImpM KernelsMem KernelEnv KernelOp [TExp Int32]
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug ->
String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"subhisto_ind" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
KernelConstants -> TExp Int32
kernelGlobalThreadId KernelConstants
constants
TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` ( KernelConstants -> TExp Int32
kernelNumThreads KernelConstants
constants
TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp (SegHistSlug -> TV Int64
slugNumSubhistos SegHistSlug
slug))
)
let gtid :: TExp Int64
gtid = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelGlobalThreadId KernelConstants
constants
num_threads :: TExp Int64
num_threads = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelNumThreads KernelConstants
constants
TExp Int64
-> TExp Int64
-> TExp Int64
-> (TExp Int64 -> InKernelGen ())
-> InKernelGen ()
forall t.
IntExp t =>
TExp t
-> TExp t -> TExp t -> (TExp t -> InKernelGen ()) -> InKernelGen ()
kernelLoop TExp Int64
gtid TExp Int64
num_threads TExp Int64
total_w_64 ((TExp Int64 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int64 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
offset -> do
(VName -> TExp Int32 -> InKernelGen ())
-> [VName] -> [TExp Int32] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int32 -> InKernelGen ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ [VName]
space_is ([TExp Int32] -> InKernelGen ()) -> [TExp Int32] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(TExp Int64 -> TExp Int32) -> [TExp Int64] -> [TExp Int32]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 ([TExp Int64] -> [TExp Int32]) -> [TExp Int64] -> [TExp Int32]
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
space_sizes_64 TExp Int64
offset
let input_in_bounds :: TExp Bool
input_in_bounds = TExp Int64
offset TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
total_w_64
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen TExp Bool
input_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Names -> Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody KernelsMem -> Stms KernelsMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody KernelsMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let ([KernelResult]
red_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElemT LetDecMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem KernelsMem]
[PatElemT LetDecMem]
map_pes) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody KernelsMem -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody KernelsMem
kbody
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElemT LetDecMem, KernelResult)]
-> ((PatElemT LetDecMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem]
-> [KernelResult] -> [(PatElemT LetDecMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem KernelsMem]
[PatElemT LetDecMem]
map_pes [KernelResult]
map_res) (((PatElemT LetDecMem, KernelResult) -> InKernelGen ())
-> InKernelGen ())
-> ((PatElemT LetDecMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
pe, KernelResult
res) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix
(PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe)
(((VName, SubExp) -> TExp Int64)
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> TExp Int64
Imp.vi64 (VName -> TExp Int64)
-> ((VName, SubExp) -> VName) -> (VName, SubExp) -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst) ([(VName, SubExp)] -> [TExp Int64])
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space)
(KernelResult -> SubExp
kernelResultSubExp KernelResult
res)
[]
let ([KernelResult]
buckets, [KernelResult]
vs) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegHistSlug] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegHistSlug]
slugs) [KernelResult]
red_res
perOp :: [KernelResult] -> [[KernelResult]]
perOp = [Int] -> [KernelResult] -> [[KernelResult]]
forall a. [Int] -> [a] -> [[a]]
chunks ([Int] -> [KernelResult] -> [[KernelResult]])
-> [Int] -> [KernelResult] -> [[KernelResult]]
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int) -> (SegHistSlug -> [VName]) -> SegHistSlug -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> [VName]
forall lore. HistOp lore -> [VName]
histDest (HistOp KernelsMem -> [VName])
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(HistOp KernelsMem, [TExp Int64] -> InKernelGen (), KernelResult,
[KernelResult], TExp Int32, TExp Int64)]
-> ((HistOp KernelsMem, [TExp Int64] -> InKernelGen (),
KernelResult, [KernelResult], TExp Int32, TExp Int64)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp KernelsMem]
-> [[TExp Int64] -> InKernelGen ()]
-> [KernelResult]
-> [[KernelResult]]
-> [TExp Int32]
-> [TExp Int64]
-> [(HistOp KernelsMem, [TExp Int64] -> InKernelGen (),
KernelResult, [KernelResult], TExp Int32, TExp Int64)]
forall a b c d e f.
[a] -> [b] -> [c] -> [d] -> [e] -> [f] -> [(a, b, c, d, e, f)]
zip6 ((SegHistSlug -> HistOp KernelsMem)
-> [SegHistSlug] -> [HistOp KernelsMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp KernelsMem
slugOp [SegHistSlug]
slugs) [[TExp Int64] -> InKernelGen ()]
histograms [KernelResult]
buckets ([KernelResult] -> [[KernelResult]]
perOp [KernelResult]
vs) [TExp Int32]
subhisto_inds [TExp Int64]
hist_H_chks) (((HistOp KernelsMem, [TExp Int64] -> InKernelGen (), KernelResult,
[KernelResult], TExp Int32, TExp Int64)
-> InKernelGen ())
-> InKernelGen ())
-> ((HistOp KernelsMem, [TExp Int64] -> InKernelGen (),
KernelResult, [KernelResult], TExp Int32, TExp Int64)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\( HistOp SubExp
dest_w SubExp
_ [VName]
_ [SubExp]
_ Shape
shape LambdaT KernelsMem
lam,
[TExp Int64] -> InKernelGen ()
do_op,
KernelResult
bucket,
[KernelResult]
vs',
TExp Int32
subhisto_ind,
TExp Int64
hist_H_chk
) -> do
let chk_beg :: TExp Int64
chk_beg = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H_chk
bucket' :: TExp Int64
bucket' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelResult -> SubExp
kernelResultSubExp KernelResult
bucket
dest_w' :: TExp Int64
dest_w' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
dest_w
bucket_in_bounds :: TExp Bool
bucket_in_bounds =
TExp Int64
chk_beg TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
bucket'
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
bucket' TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. (TExp Int64
chk_beg TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
hist_H_chk)
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
bucket' TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
dest_w'
vs_params :: [Param LetDecMem]
vs_params = Int -> [Param LetDecMem] -> [Param LetDecMem]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
vs') ([Param LetDecMem] -> [Param LetDecMem])
-> [Param LetDecMem] -> [Param LetDecMem]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
lam
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen TExp Bool
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let bucket_is :: [TExp Int64]
bucket_is =
(VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 ([VName] -> [VName]
forall a. [a] -> [a]
init [VName]
space_is)
[TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
subhisto_ind, TExp Int64
bucket']
[LParam KernelsMem] -> InKernelGen ()
forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams ([LParam KernelsMem] -> InKernelGen ())
-> [LParam KernelsMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
lam
Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest Shape
shape (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is -> do
[(Param LetDecMem, KernelResult)]
-> ((Param LetDecMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [KernelResult] -> [(Param LetDecMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
vs_params [KernelResult]
vs') (((Param LetDecMem, KernelResult) -> InKernelGen ())
-> InKernelGen ())
-> ((Param LetDecMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, KernelResult
res) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] (KernelResult -> SubExp
kernelResultSubExp KernelResult
res) [TExp Int64]
is
[TExp Int64] -> InKernelGen ()
do_op ([TExp Int64]
bucket_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is)
-- | Compute histograms using subhistograms that live in global memory.
--
-- The total thread count (@num_groups * group_size@) and the size of the
-- innermost dimension of the segment space are handed to
-- 'prepareIntermediateArraysGlobal', which yields a chunk count @hist_S@
-- together with the per-operator update actions.  One
-- 'histKernelGlobalPass' is then emitted inside a host loop over the
-- chunk index @chk_i@ (bounded by @hist_S@).
histKernelGlobal ::
  [PatElem KernelsMem] ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegHistSlug] ->
  KernelBody KernelsMem ->
  CallKernelGen ()
histKernelGlobal map_pes num_groups group_size space slugs kbody = do
  let groups64 = toInt64Exp <$> num_groups
      gsize64 = toInt64Exp <$> group_size
      -- Only the innermost dimension's size is needed for the
      -- intermediate-array setup.
      inner_size = toInt64Exp $ last $ map snd $ unSegSpace space
      -- NOTE(review): the 64-bit thread count is narrowed to 32 bits
      -- here, as required by 'prepareIntermediateArraysGlobal'.
      total_threads = sExt32 $ unCount groups64 * unCount gsize64

  emit $ Imp.DebugPrint "## Using global memory" Nothing

  (hist_S, histograms) <-
    prepareIntermediateArraysGlobal
      (bodyPassage kbody)
      total_threads
      inner_size
      slugs

  -- One kernel pass per chunk; the chunk index is the last argument.
  sFor "chk_i" hist_S $
    histKernelGlobalPass
      map_pes
      groups64
      gsize64
      space
      slugs
      kbody
      histograms
      hist_S
-- | One entry per histogram operator (slug).
--
-- The first component names the operator's global-memory subhistogram
-- arrays.  The second is an in-kernel action that takes the chunk size
-- (@hist_H_chk@), allocates the local-memory subhistograms, and returns
-- their names together with an update action.  The update action is
-- applied to a list of 64-bit indices, presumably bucket indices.
type InitLocalHistograms =
  [ ( [VName],
      SubExp ->
      InKernelGen
        ( [VName],
          [Imp.TExp Int64] -> InKernelGen ()
        )
    )
  ]
-- | Set up the intermediate subhistograms for the local-memory strategy.
--
-- For every slug, on the host side:
--
-- * its subhistogram count is set to @groups_per_segment * num_segments@,
--   where @num_segments@ is the product of all but the innermost
--   dimension of @space@;
-- * that count is reported with a debug print;
-- * its global-memory subhistograms are allocated.
--
-- The slug's entry pairs the global array names with an in-kernel
-- initialiser.  Given a chunk size @hist_H_chk@, the initialiser
-- allocates @num_subhistos_per_group@ local-memory subhistograms per
-- operand type, whose outer histogram dimension is resized to
-- @hist_H_chk@.  For the locking strategy it also allocates and zeroes a
-- local-memory lock array.
prepareIntermediateArraysLocal ::
  TV Int32 ->
  Count NumGroups (Imp.TExp Int64) ->
  SegSpace ->
  [SegHistSlug] ->
  CallKernelGen InitLocalHistograms
prepareIntermediateArraysLocal num_subhistos_per_group groups_per_segment space slugs = do
  -- Every dimension except the innermost one indexes a segment.
  num_segments <-
    dPrimVE "num_segments" . product $
      map (toInt64Exp . snd) (init (unSegSpace space))
  forM slugs (prepareSlug num_segments)
  where
    local_mem = Space "local"

    prepareSlug num_segments (SegHistSlug op num_subhistos subhisto_info do_op) = do
      -- One global subhistogram per group in every segment.
      num_subhistos <-- sExt64 (unCount groups_per_segment) * num_segments
      emit . Imp.DebugPrint "Number of subhistograms in global memory" $
        Just (untyped (tvExp num_subhistos))

      let -- Build the atomic update used for a chunk of hist_H_chk buckets.
          -- Only the locking strategy needs extra setup: a local-memory
          -- int32 lock array shaped like the local subhistograms, all
          -- entries initialised to 0.
          chunkUpdate hist_H_chk =
            case do_op of
              AtomicPrim f -> pure f
              AtomicCAS f -> pure f
              AtomicLocking f -> do
                let lock_shape =
                      Shape $
                        tvSize num_subhistos_per_group :
                        shapeDims (histShape op) ++ [hist_H_chk]
                locks <- sAllocArray "locks" int32 lock_shape local_mem
                sComment "All locks start out unlocked" $
                  groupCoverSpace (map toInt64Exp (shapeDims lock_shape)) $ \is ->
                    copyDWIMFix locks is (intConst Int32 0) []
                pure . f $ Locking locks 0 1 0 id

          -- Allocate this slug's local-memory subhistograms for a chunk,
          -- then (for locking) its locks, and pair them with the update.
          initLocal hist_H_chk = do
            local_subhistos <- forM (histType op) $ \t -> do
              let shape_local =
                    Shape [tvSize num_subhistos_per_group]
                      <> (arrayShape t `setOuterDim` hist_H_chk)
              sAllocArray "subhistogram_local" (elemType t) shape_local local_mem
            update <- chunkUpdate hist_H_chk
            pure (local_subhistos, update local_mem local_subhistos)

      glob_subhistos <- forM subhisto_info $ \info -> do
        subhistosAlloc info
        pure (subhistosArray info)

      pure (glob_subhistos, initLocal)
histKernelLocalPass ::
TV Int32 ->
Count NumGroups (Imp.TExp Int64) ->
[PatElem KernelsMem] ->
Count NumGroups (Imp.TExp Int64) ->
Count GroupSize (Imp.TExp Int64) ->
SegSpace ->
[SegHistSlug] ->
KernelBody KernelsMem ->
InitLocalHistograms ->
Imp.TExp Int32 ->
Imp.TExp Int32 ->
CallKernelGen ()
histKernelLocalPass :: TV Int32
-> Count NumGroups (TExp Int64)
-> [PatElem KernelsMem]
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> SegSpace
-> [SegHistSlug]
-> KernelBody KernelsMem
-> InitLocalHistograms
-> TExp Int32
-> TExp Int32
-> CallKernelGen ()
histKernelLocalPass
TV Int32
num_subhistos_per_group_var
Count NumGroups (TExp Int64)
groups_per_segment
[PatElem KernelsMem]
map_pes
Count NumGroups (TExp Int64)
num_groups
Count GroupSize (TExp Int64)
group_size
SegSpace
space
[SegHistSlug]
slugs
KernelBody KernelsMem
kbody
InitLocalHistograms
init_histograms
TExp Int32
hist_S
TExp Int32
chk_i = do
let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
segment_is :: [VName]
segment_is = [VName] -> [VName]
forall a. [a] -> [a]
init [VName]
space_is
segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. [a] -> [a]
init [SubExp]
space_sizes
(VName
i_in_segment, SubExp
segment_size) = [(VName, SubExp)] -> (VName, SubExp)
forall a. [a] -> a
last ([(VName, SubExp)] -> (VName, SubExp))
-> [(VName, SubExp)] -> (VName, SubExp)
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
num_subhistos_per_group :: TExp Int32
num_subhistos_per_group = TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
num_subhistos_per_group_var
segment_size' :: TExp Int64
segment_size' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
segment_size
TExp Int64
num_segments <-
String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"num_segments" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
[TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
segment_dims
[TV Int64]
hist_H_chks <- [SubExp]
-> (SubExp -> ImpM KernelsMem HostEnv HostOp (TV Int64))
-> ImpM KernelsMem HostEnv HostOp [TV Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ((SegHistSlug -> SubExp) -> [SegHistSlug] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth (HistOp KernelsMem -> SubExp)
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs) ((SubExp -> ImpM KernelsMem HostEnv HostOp (TV Int64))
-> ImpM KernelsMem HostEnv HostOp [TV Int64])
-> (SubExp -> ImpM KernelsMem HostEnv HostOp (TV Int64))
-> ImpM KernelsMem HostEnv HostOp [TV Int64]
forall a b. (a -> b) -> a -> b
$ \SubExp
w ->
String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"hist_H_chk" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
w TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S
[([TExp Int64], TExp Int64, TExp Int32)]
histo_sizes <- [(SegHistSlug, TV Int64)]
-> ((SegHistSlug, TV Int64)
-> ImpM
KernelsMem HostEnv HostOp ([TExp Int64], TExp Int64, TExp Int32))
-> ImpM
KernelsMem HostEnv HostOp [([TExp Int64], TExp Int64, TExp Int32)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SegHistSlug] -> [TV Int64] -> [(SegHistSlug, TV Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegHistSlug]
slugs [TV Int64]
hist_H_chks) (((SegHistSlug, TV Int64)
-> ImpM
KernelsMem HostEnv HostOp ([TExp Int64], TExp Int64, TExp Int32))
-> ImpM
KernelsMem HostEnv HostOp [([TExp Int64], TExp Int64, TExp Int32)])
-> ((SegHistSlug, TV Int64)
-> ImpM
KernelsMem HostEnv HostOp ([TExp Int64], TExp Int64, TExp Int32))
-> ImpM
KernelsMem HostEnv HostOp [([TExp Int64], TExp Int64, TExp Int32)]
forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, TV Int64
hist_H_chk) -> do
let histo_dims :: [TExp Int64]
histo_dims =
TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
:
(SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape (SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug)))
TExp Int64
histo_size <-
String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"histo_size" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
histo_dims
let group_hists_size :: TExp Int64
group_hists_size =
TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
num_subhistos_per_group TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
histo_size
TExp Int32
init_per_thread <-
String -> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"init_per_thread" (TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Int64
group_hists_size TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size
([TExp Int64], TExp Int64, TExp Int32)
-> ImpM
KernelsMem HostEnv HostOp ([TExp Int64], TExp Int64, TExp Int32)
forall (m :: * -> *) a. Monad m => a -> m a
return ([TExp Int64]
histo_dims, TExp Int64
histo_size, TExp Int32
init_per_thread)
String
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"seghist_local" Count NumGroups (TExp Int64)
num_groups Count GroupSize (TExp Int64)
group_size (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
SegVirt
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt (TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
num_segments) ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
group_id -> do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
TExp Int32
flat_segment_id <- String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"flat_segment_id" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
group_id TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment)
TExp Int32
gid_in_segment <- String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"gid_in_segment" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
group_id TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment)
TExp Int32
pgtid_in_segment <-
String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"pgtid_in_segment" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32
gid_in_segment TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
TExp Int32
threads_per_segment <-
String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"threads_per_segment" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants
(VName -> TExp Int64 -> InKernelGen ())
-> [VName] -> [TExp Int64] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> InKernelGen ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ [VName]
segment_is ([TExp Int64] -> InKernelGen ()) -> [TExp Int64] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ((SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
segment_dims) (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
flat_segment_id
[([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
histograms <- [(([VName],
SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
([VName], [TExp Int64] -> InKernelGen ())),
TV Int64)]
-> ((([VName],
SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
([VName], [TExp Int64] -> InKernelGen ())),
TV Int64)
-> ImpM
KernelsMem
KernelEnv
KernelOp
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()))
-> ImpM
KernelsMem
KernelEnv
KernelOp
[([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (InitLocalHistograms
-> [TV Int64]
-> [(([VName],
SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
([VName], [TExp Int64] -> InKernelGen ())),
TV Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip InitLocalHistograms
init_histograms [TV Int64]
hist_H_chks) (((([VName],
SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
([VName], [TExp Int64] -> InKernelGen ())),
TV Int64)
-> ImpM
KernelsMem
KernelEnv
KernelOp
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()))
-> ImpM
KernelsMem
KernelEnv
KernelOp
[([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())])
-> ((([VName],
SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
([VName], [TExp Int64] -> InKernelGen ())),
TV Int64)
-> ImpM
KernelsMem
KernelEnv
KernelOp
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()))
-> ImpM
KernelsMem
KernelEnv
KernelOp
[([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
forall a b. (a -> b) -> a -> b
$
\(([VName]
glob_subhistos, SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
([VName], [TExp Int64] -> InKernelGen ())
init_local_subhistos), TV Int64
hist_H_chk) -> do
([VName]
local_subhistos, [TExp Int64] -> InKernelGen ()
do_op) <- SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
([VName], [TExp Int64] -> InKernelGen ())
init_local_subhistos (SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
([VName], [TExp Int64] -> InKernelGen ()))
-> SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
([VName], [TExp Int64] -> InKernelGen ())
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
hist_H_chk
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())
-> ImpM
KernelsMem
KernelEnv
KernelOp
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
glob_subhistos [VName]
local_subhistos, TV Int64
hist_H_chk, [TExp Int64] -> InKernelGen ()
do_op)
TExp Int32
thread_local_subhisto_i <-
String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"thread_local_subhisto_i" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int32
num_subhistos_per_group
let onSlugs :: (SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
onSlugs SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ()
f =
[(SegHistSlug,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
([TExp Int64], TExp Int64, TExp Int32))]
-> ((SegHistSlug,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
([TExp Int64], TExp Int64, TExp Int32))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
-> [([TExp Int64], TExp Int64, TExp Int32)]
-> [(SegHistSlug,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
([TExp Int64], TExp Int64, TExp Int32))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegHistSlug]
slugs [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
histograms [([TExp Int64], TExp Int64, TExp Int32)]
histo_sizes) (((SegHistSlug,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
([TExp Int64], TExp Int64, TExp Int32))
-> InKernelGen ())
-> InKernelGen ())
-> ((SegHistSlug,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
([TExp Int64], TExp Int64, TExp Int32))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\(SegHistSlug
slug, ([(VName, VName)]
dests, TV Int64
hist_H_chk, [TExp Int64] -> InKernelGen ()
_), ([TExp Int64]
histo_dims, TExp Int64
histo_size, TExp Int32
init_per_thread)) ->
SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ()
f SegHistSlug
slug [(VName, VName)]
dests (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk) [TExp Int64]
histo_dims TExp Int64
histo_size TExp Int32
init_per_thread
let onAllHistograms :: (VName
-> VName
-> HistOp KernelsMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ())
-> InKernelGen ()
onAllHistograms VName
-> VName
-> HistOp KernelsMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ()
f =
(SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ())
-> (SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests TExp Int64
hist_H_chk [TExp Int64]
histo_dims TExp Int64
histo_size TExp Int32
init_per_thread -> do
let group_hists_size :: TExp Int32
group_hists_size = TExp Int32
num_subhistos_per_group TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size
[((VName, VName), SubExp)]
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, VName)] -> [SubExp] -> [((VName, VName), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(VName, VName)]
dests (HistOp KernelsMem -> [SubExp]
forall lore. HistOp lore -> [SubExp]
histNeutral (HistOp KernelsMem -> [SubExp]) -> HistOp KernelsMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug)) ((((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ())
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\((VName
dest_global, VName
dest_local), SubExp
ne) ->
String
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"local_i" TExp Int32
init_per_thread ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
TExp Int32
j <-
String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"j" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32
i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
TExp Int32
j_offset <-
String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"j_offset" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32
num_subhistos_per_group TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
gid_in_segment TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
j
TExp Int32
local_subhisto_i <- String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"local_subhisto_i" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size
let local_bucket_is :: [TExp Int64]
local_bucket_is = [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
histo_dims (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size
global_bucket_is :: [TExp Int64]
global_bucket_is =
[TExp Int64] -> TExp Int64
forall a. [a] -> a
head [TExp Int64]
local_bucket_is TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H_chk TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
:
[TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
tail [TExp Int64]
local_bucket_is
TExp Int32
global_subhisto_i <- String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"global_subhisto_i" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
j_offset TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
group_hists_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> VName
-> HistOp KernelsMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ()
f
VName
dest_local
VName
dest_global
(SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug)
SubExp
ne
TExp Int32
local_subhisto_i
TExp Int32
global_subhisto_i
[TExp Int64]
local_bucket_is
[TExp Int64]
global_bucket_is
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"initialize histograms in local memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(VName
-> VName
-> HistOp KernelsMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ())
-> InKernelGen ()
onAllHistograms ((VName
-> VName
-> HistOp KernelsMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ())
-> InKernelGen ())
-> (VName
-> VName
-> HistOp KernelsMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
dest_local VName
dest_global HistOp KernelsMem
op SubExp
ne TExp Int32
local_subhisto_i TExp Int32
global_subhisto_i [TExp Int64]
local_bucket_is [TExp Int64]
global_bucket_is ->
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"First subhistogram is initialised from global memory; others with neutral element." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let global_is :: [TExp Int64]
global_is = (VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
segment_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64
0] [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
global_bucket_is
local_is :: [TExp Int64]
local_is = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
local_subhisto_i TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
local_bucket_is
TExp Bool -> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf
(TExp Int32
global_subhisto_i TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0)
(VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
dest_local [TExp Int64]
local_is (VName -> SubExp
Var VName
dest_global) [TExp Int64]
global_is)
( Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp KernelsMem
op) (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
dest_local ([TExp Int64]
local_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is) SubExp
ne []
)
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
TExp Int32
-> TExp Int32
-> TExp Int32
-> (TExp Int32 -> InKernelGen ())
-> InKernelGen ()
forall t.
IntExp t =>
TExp t
-> TExp t -> TExp t -> (TExp t -> InKernelGen ()) -> InKernelGen ()
kernelLoop TExp Int32
pgtid_in_segment TExp Int32
threads_per_segment (TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
segment_size') ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
ie -> do
VName -> TExp Int32 -> InKernelGen ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ VName
i_in_segment TExp Int32
ie
Names -> Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody KernelsMem -> Stms KernelsMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody KernelsMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let ([SubExp]
red_res, [SubExp]
map_res) =
Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElemT LetDecMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem KernelsMem]
[PatElemT LetDecMem]
map_pes) ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$
(KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp ([KernelResult] -> [SubExp]) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ KernelBody KernelsMem -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody KernelsMem
kbody
([SubExp]
buckets, [SubExp]
vs) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegHistSlug] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegHistSlug]
slugs) [SubExp]
red_res
perOp :: [SubExp] -> [[SubExp]]
perOp = [Int] -> [SubExp] -> [[SubExp]]
forall a. [Int] -> [a] -> [[a]]
chunks ([Int] -> [SubExp] -> [[SubExp]])
-> [Int] -> [SubExp] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int) -> (SegHistSlug -> [VName]) -> SegHistSlug -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> [VName]
forall lore. HistOp lore -> [VName]
histDest (HistOp KernelsMem -> [VName])
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Int32
chk_i TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElemT LetDecMem, SubExp)]
-> ((PatElemT LetDecMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem] -> [SubExp] -> [(PatElemT LetDecMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem KernelsMem]
[PatElemT LetDecMem]
map_pes [SubExp]
map_res) (((PatElemT LetDecMem, SubExp) -> InKernelGen ())
-> InKernelGen ())
-> ((PatElemT LetDecMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
pe, SubExp
se) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix
(PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe)
((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
space_is)
SubExp
se
[]
[(HistOp KernelsMem,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
SubExp, [SubExp])]
-> ((HistOp KernelsMem,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
SubExp, [SubExp])
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp KernelsMem]
-> [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
-> [SubExp]
-> [[SubExp]]
-> [(HistOp KernelsMem,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
SubExp, [SubExp])]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 ((SegHistSlug -> HistOp KernelsMem)
-> [SegHistSlug] -> [HistOp KernelsMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp KernelsMem
slugOp [SegHistSlug]
slugs) [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
histograms [SubExp]
buckets ([SubExp] -> [[SubExp]]
perOp [SubExp]
vs)) (((HistOp KernelsMem,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
SubExp, [SubExp])
-> InKernelGen ())
-> InKernelGen ())
-> ((HistOp KernelsMem,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
SubExp, [SubExp])
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\( HistOp SubExp
dest_w SubExp
_ [VName]
_ [SubExp]
_ Shape
shape LambdaT KernelsMem
lam,
([(VName, VName)]
_, TV Int64
hist_H_chk, [TExp Int64] -> InKernelGen ()
do_op),
SubExp
bucket,
[SubExp]
vs'
) -> do
let chk_beg :: TExp Int64
chk_beg = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk
bucket' :: TExp Int64
bucket' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
bucket
dest_w' :: TExp Int64
dest_w' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
dest_w
bucket_in_bounds :: TExp Bool
bucket_in_bounds =
TExp Int64
bucket' TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
dest_w'
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
chk_beg TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
bucket'
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
bucket' TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. (TExp Int64
chk_beg TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk)
bucket_is :: [TExp Int64]
bucket_is = [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
thread_local_subhisto_i, TExp Int64
bucket' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
chk_beg]
vs_params :: [Param LetDecMem]
vs_params = Int -> [Param LetDecMem] -> [Param LetDecMem]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param LetDecMem] -> [Param LetDecMem])
-> [Param LetDecMem] -> [Param LetDecMem]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
lam
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen TExp Bool
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
[LParam KernelsMem] -> InKernelGen ()
forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams ([LParam KernelsMem] -> InKernelGen ())
-> [LParam KernelsMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
lam
Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest Shape
shape (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is -> do
[(Param LetDecMem, SubExp)]
-> ((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [SubExp] -> [(Param LetDecMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
vs_params [SubExp]
vs') (((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, SubExp
v) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] SubExp
v [TExp Int64]
is
[TExp Int64] -> InKernelGen ()
do_op ([TExp Int64]
bucket_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is)
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceGlobal
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Compact the multiple local memory subhistograms to result in global memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ())
-> (SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests TExp Int64
hist_H_chk [TExp Int64]
histo_dims TExp Int64
_histo_size TExp Int32
bins_per_thread -> do
TV Int64
trunc_H <-
String
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"trunc_H" (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 TExp Int64
hist_H_chk (TExp Int64 -> TExp Int64) -> TExp Int64 -> TExp Int64
forall a b. (a -> b) -> a -> b
$
SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth (SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug))
TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* [TExp Int64] -> TExp Int64
forall a. [a] -> a
head [TExp Int64]
histo_dims
let trunc_histo_dims :: [TExp Int64]
trunc_histo_dims =
TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
trunc_H TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
:
(SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape (SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug)))
TExp Int32
trunc_histo_size <- String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"histo_size" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
trunc_histo_dims
String
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"local_i" TExp Int32
bins_per_thread ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
TExp Int32
j <-
String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"j" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32
i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
trunc_histo_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let local_bucket_is :: [TExp Int64]
local_bucket_is = [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
histo_dims (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
j
global_bucket_is :: [TExp Int64]
global_bucket_is =
[TExp Int64] -> TExp Int64
forall a. [a] -> a
head [TExp Int64]
local_bucket_is TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H_chk TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
:
[TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
tail [TExp Int64]
local_bucket_is
[LParam KernelsMem] -> InKernelGen ()
forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams ([LParam KernelsMem] -> InKernelGen ())
-> [LParam KernelsMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams (LambdaT KernelsMem -> [LParam KernelsMem])
-> LambdaT KernelsMem -> [LParam KernelsMem]
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp (HistOp KernelsMem -> LambdaT KernelsMem)
-> HistOp KernelsMem -> LambdaT KernelsMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug
let ([VName]
global_dests, [VName]
local_dests) = [(VName, VName)] -> ([VName], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, VName)]
dests
([Param LetDecMem]
xparams, [Param LetDecMem]
yparams) =
Int -> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
local_dests) ([Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem]))
-> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a b. (a -> b) -> a -> b
$
LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams (LambdaT KernelsMem -> [LParam KernelsMem])
-> LambdaT KernelsMem -> [LParam KernelsMem]
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp (HistOp KernelsMem -> LambdaT KernelsMem)
-> HistOp KernelsMem -> LambdaT KernelsMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Read values from subhistogram 0." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
xparams [VName]
local_dests) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
xp, VName
subhisto) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix
(Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
xp)
[]
(VName -> SubExp
Var VName
subhisto)
(TExp Int64
0 TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
local_bucket_is)
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Accumulate based on values in other subhistograms." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
String
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"subhisto_id" (TExp Int32
num_subhistos_per_group TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
subhisto_id -> do
[(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
yparams [VName]
local_dests) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
yp, VName
subhisto) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix
(Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
yp)
[]
(VName -> SubExp
Var VName
subhisto)
(TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
subhisto_id TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1 TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
local_bucket_is)
[Param LetDecMem] -> Body KernelsMem -> InKernelGen ()
forall dec lore r op. [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' [Param LetDecMem]
xparams (Body KernelsMem -> InKernelGen ())
-> Body KernelsMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> Body KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody (LambdaT KernelsMem -> Body KernelsMem)
-> LambdaT KernelsMem -> Body KernelsMem
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp (HistOp KernelsMem -> LambdaT KernelsMem)
-> HistOp KernelsMem -> LambdaT KernelsMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Put final bucket value in global memory." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let global_is :: [TExp Int64]
global_is =
(VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
segment_is
[TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
group_id TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`rem` Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment]
[TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
global_bucket_is
[(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
xparams [VName]
global_dests) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
xp, VName
global_dest) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
global_dest [TExp Int64]
global_is (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
xp) []
-- | Driver for the local-memory histogram strategy.
--
-- It first emits a debug print of the number of local subhistograms per
-- group. It then sets up the intermediate local histograms with
-- 'prepareIntermediateArraysLocal'. Finally it invokes 'histKernelLocalPass'
-- once for every value of @chk_i@ from 0 up to (but not including) @hist_S@.
--
-- The group count and group size are converted from 'SubExp' to 64-bit
-- expressions before being passed on.
--
-- NOTE(review): @hist_S@ appears to be the number of histogram chunks, since
-- each pass offsets its buckets by @chk_i * hist_H_chk@. Confirm this at the
-- call site.
histKernelLocal ::
  TV Int32 ->
  Count NumGroups (Imp.TExp Int64) ->
  [PatElem KernelsMem] ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  Imp.TExp Int32 ->
  [SegHistSlug] ->
  KernelBody KernelsMem ->
  CallKernelGen ()
histKernelLocal subhistos_var groups_per_segment map_pes num_groups group_size space hist_S slugs kbody = do
  emit . Imp.DebugPrint "Number of local subhistograms per group" . Just . untyped $
    tvExp subhistos_var

  local_histos <-
    prepareIntermediateArraysLocal subhistos_var groups_per_segment space slugs

  -- The loop body is the pass itself, partially applied; sFor supplies chk_i.
  sFor "chk_i" hist_S $
    histKernelLocalPass
      subhistos_var
      groups_per_segment
      map_pes
      (toInt64Exp <$> num_groups)
      (toInt64Exp <$> group_size)
      space
      slugs
      kbody
      local_histos
      hist_S
-- | Upper bound on the number of local-memory passes for a slug.
--
-- The bound depends only on the slug's atomic update strategy:
--
-- * primitive atomics ('AtomicPrim'): 3
-- * compare-and-swap ('AtomicCAS'): 4
-- * lock-based updates ('AtomicLocking'): 6
slugMaxLocalMemPasses :: SegHistSlug -> Int
slugMaxLocalMemPasses = passesFor . slugAtomicUpdate
  where
    passesFor AtomicPrim {} = 3
    passesFor AtomicCAS {} = 4
    passesFor AtomicLocking {} = 6
localMemoryCase ::
[PatElem KernelsMem] ->
Imp.TExp Int32 ->
SegSpace ->
Imp.TExp Int64 ->
Imp.TExp Int64 ->
Imp.TExp Int64 ->
Imp.TExp Int32 ->
[SegHistSlug] ->
KernelBody KernelsMem ->
CallKernelGen (Imp.TExp Bool, CallKernelGen ())
localMemoryCase :: [PatElem KernelsMem]
-> TExp Int32
-> SegSpace
-> TExp Int64
-> TExp Int64
-> TExp Int64
-> TExp Int32
-> [SegHistSlug]
-> KernelBody KernelsMem
-> CallKernelGen (TExp Bool, CallKernelGen ())
localMemoryCase [PatElem KernelsMem]
map_pes TExp Int32
hist_T SegSpace
space TExp Int64
hist_H TExp Int64
hist_el_size TExp Int64
hist_N TExp Int32
_ [SegHistSlug]
slugs KernelBody KernelsMem
kbody = do
let space_sizes :: [SubExp]
space_sizes = SegSpace -> [SubExp]
segSpaceDims SegSpace
space
segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. [a] -> [a]
init [SubExp]
space_sizes
segmented :: Bool
segmented = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [SubExp]
segment_dims
TV Int64
hist_L <- String -> PrimType -> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"hist_L" PrimType
int32
HostOp -> CallKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
hist_L) SizeClass
Imp.SizeLocalMemory
TV Any
max_group_size <- String -> PrimType -> ImpM KernelsMem HostEnv HostOp (TV Any)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"max_group_size" PrimType
int32
HostOp -> CallKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
max_group_size) SizeClass
Imp.SizeGroup
let group_size :: Count GroupSize SubExp
group_size = SubExp -> Count GroupSize SubExp
forall u e. e -> Count u e
Imp.Count (SubExp -> Count GroupSize SubExp)
-> SubExp -> Count GroupSize SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
max_group_size
Count NumGroups SubExp
num_groups <-
(TV Int64 -> Count NumGroups SubExp)
-> ImpM KernelsMem HostEnv HostOp (TV Int64)
-> ImpM KernelsMem HostEnv HostOp (Count NumGroups SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Imp.Count (SubExp -> Count NumGroups SubExp)
-> (TV Int64 -> SubExp) -> TV Int64 -> Count NumGroups SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize) (ImpM KernelsMem HostEnv HostOp (TV Int64)
-> ImpM KernelsMem HostEnv HostOp (Count NumGroups SubExp))
-> ImpM KernelsMem HostEnv HostOp (TV Int64)
-> ImpM KernelsMem HostEnv HostOp (Count NumGroups SubExp)
forall a b. (a -> b) -> a -> b
$
String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"num_groups" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_T TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount Count GroupSize SubExp
group_size)
let num_groups' :: Count NumGroups (TExp Int64)
num_groups' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> Count NumGroups SubExp -> Count NumGroups (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count NumGroups SubExp
num_groups
group_size' :: Count GroupSize (TExp Int64)
group_size' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> Count GroupSize SubExp -> Count GroupSize (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count GroupSize SubExp
group_size
let r64 :: TPrimExp t v -> TPrimExp Double v
r64 = PrimExp v -> TPrimExp Double v
forall v. PrimExp v -> TPrimExp Double v
isF64 (PrimExp v -> TPrimExp Double v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Double v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> FloatType -> ConvOp
SIToFP IntType
Int64 FloatType
Float64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall t v. TPrimExp t v -> PrimExp v
untyped
t64 :: TPrimExp t v -> TPrimExp Int64 v
t64 = PrimExp v -> TPrimExp Int64 v
forall v. PrimExp v -> TPrimExp Int64 v
isInt64 (PrimExp v -> TPrimExp Int64 v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Int64 v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> IntType -> ConvOp
FPToSI FloatType
Float64 IntType
Int64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall t v. TPrimExp t v -> PrimExp v
untyped
TExp Double
hist_m' <-
String
-> TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_m_prime" (TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM KernelsMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64
( TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64
(TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
hist_L TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64
hist_el_size))
(TExp Int64
hist_N TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups'))
)
TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Int64 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 TExp Int64
hist_H
let hist_B :: TExp Int64
hist_B = Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size'
TExp Int64
hist_M0 <-
String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_M0" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMax64 TExp Int64
1 (TExp Int64 -> TExp Int64) -> TExp Int64 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TExp Double -> TExp Int64
forall t v. TPrimExp t v -> TPrimExp Int64 v
t64 TExp Double
hist_m') TExp Int64
hist_B
let q_small :: TExp Int64
q_small = TExp Int64
2
TExp Int64
hist_Nout <- String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_Nout" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
segment_dims
TExp Int64
hist_Nin <- String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_Nin" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. [a] -> a
last [SubExp]
space_sizes
TExp Int64
work_asymp_M_max <-
if Bool
segmented
then do
TExp Int32
hist_T_hist_min <-
String -> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_T_hist_min" (TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
hist_Nin TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
hist_Nout) (TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_T)
TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
hist_Nout
let r :: TExp Int32
r = TExp Int32
hist_T_hist_min TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
hist_B
String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"work_asymp_M_max" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
hist_Nin TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` (TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
r TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H)
else
String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"work_asymp_M_max" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
(TExp Int64
hist_Nout TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_N)
TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` ( (TExp Int64
q_small TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H)
TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` [SegHistSlug] -> TExp Int64
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs
)
TV Int32
hist_M <- String -> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TV Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"hist_M" (TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TV Int32))
-> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 TExp Int64
hist_M0 TExp Int64
work_asymp_M_max
let hist_M_nonzero :: TExp Int32
hist_M_nonzero = TExp Int32 -> TExp Int32 -> TExp Int32
forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 TExp Int32
1 (TExp Int32 -> TExp Int32) -> TExp Int32 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M
TExp Int64
hist_C <-
String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_C" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int64
hist_B TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_M_nonzero
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local hist_M0" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_M0
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local work asymp M max" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
work_asymp_M_max
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local C" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_C
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local B" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_B
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local M" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local memory needed" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64
hist_H TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_el_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M)
TExp Int64
local_mem_needed <-
String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"local_mem_needed" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int64
hist_el_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M)
TExp Int32
hist_S <-
String -> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_S" (TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
(TExp Int64
hist_H TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
local_mem_needed) TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
hist_L
let max_S :: TExp Int32
max_S = case KernelBody KernelsMem -> Passage
bodyPassage KernelBody KernelsMem
kbody of
Passage
MustBeSinglePass -> TExp Int32
1
Passage
MayBeMultiPass -> Int -> TExp Int32
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> TExp Int32) -> Int -> TExp Int32
forall a b. (a -> b) -> a -> b
$ [Int] -> Int
forall a (f :: * -> *). (Num a, Ord a, Foldable f) => f a -> a
maxinum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> Int
slugMaxLocalMemPasses [SegHistSlug]
slugs
Count NumGroups (TExp Int64)
groups_per_segment <-
if Bool
segmented
then
(TExp Int64 -> Count NumGroups (TExp Int64))
-> ImpM KernelsMem HostEnv HostOp (TExp Int64)
-> ImpM KernelsMem HostEnv HostOp (Count NumGroups (TExp Int64))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap TExp Int64 -> Count NumGroups (TExp Int64)
forall u e. e -> Count u e
Count (ImpM KernelsMem HostEnv HostOp (TExp Int64)
-> ImpM KernelsMem HostEnv HostOp (Count NumGroups (TExp Int64)))
-> ImpM KernelsMem HostEnv HostOp (TExp Int64)
-> ImpM KernelsMem HostEnv HostOp (Count NumGroups (TExp Int64))
forall a b. (a -> b) -> a -> b
$
String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"groups_per_segment" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups' TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64
hist_Nout
else Count NumGroups (TExp Int64)
-> ImpM KernelsMem HostEnv HostOp (Count NumGroups (TExp Int64))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Count NumGroups (TExp Int64)
num_groups'
let pick_local :: TExp Bool
pick_local =
TExp Int64
hist_Nin TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>=. TExp Int64
hist_H
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (TExp Int64
local_mem_needed TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
hist_L)
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (TExp Int32
hist_S TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int32
max_S)
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
hist_C TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
hist_B
TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
0
run :: CallKernelGen ()
run = do
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"## Using local memory" Maybe Exp
forall a. Maybe a
Nothing
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Histogram size (H)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_H
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Multiplication degree (M)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Cooperation level (C)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_C
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of chunks (S)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_S
Bool -> CallKernelGen () -> CallKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
segmented (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Groups per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment
TV Int32
-> Count NumGroups (TExp Int64)
-> [PatElem KernelsMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> TExp Int32
-> [SegHistSlug]
-> KernelBody KernelsMem
-> CallKernelGen ()
histKernelLocal
TV Int32
hist_M
Count NumGroups (TExp Int64)
groups_per_segment
[PatElem KernelsMem]
map_pes
Count NumGroups SubExp
num_groups
Count GroupSize SubExp
group_size
SegSpace
space
TExp Int32
hist_S
[SegHistSlug]
slugs
KernelBody KernelsMem
kbody
(TExp Bool, CallKernelGen ())
-> CallKernelGen (TExp Bool, CallKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return (TExp Bool
pick_local, CallKernelGen ()
run)
compileSegHist ::
Pattern KernelsMem ->
Count NumGroups SubExp ->
Count GroupSize SubExp ->
SegSpace ->
[HistOp KernelsMem] ->
KernelBody KernelsMem ->
CallKernelGen ()
compileSegHist :: Pattern KernelsMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [HistOp KernelsMem]
-> KernelBody KernelsMem
-> CallKernelGen ()
compileSegHist (Pattern [PatElem KernelsMem]
_ [PatElem KernelsMem]
pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [HistOp KernelsMem]
ops KernelBody KernelsMem
kbody = do
let num_groups' :: Count NumGroups (TExp Int64)
num_groups' = (SubExp -> TExp Int64)
-> Count NumGroups SubExp -> Count NumGroups (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count NumGroups SubExp
num_groups
group_size' :: Count GroupSize (TExp Int64)
group_size' = (SubExp -> TExp Int64)
-> Count GroupSize SubExp -> Count GroupSize (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count GroupSize SubExp
group_size
dims :: [TExp Int64]
dims = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
num_red_res :: Int
num_red_res = [HistOp KernelsMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp KernelsMem]
ops Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((HistOp KernelsMem -> Int) -> [HistOp KernelsMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (HistOp KernelsMem -> [SubExp]) -> HistOp KernelsMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> [SubExp]
forall lore. HistOp lore -> [SubExp]
histNeutral) [HistOp KernelsMem]
ops)
([PatElemT LetDecMem]
all_red_pes, [PatElemT LetDecMem]
map_pes) = Int
-> [PatElemT LetDecMem]
-> ([PatElemT LetDecMem], [PatElemT LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
num_red_res [PatElem KernelsMem]
[PatElemT LetDecMem]
pes
segment_size :: TExp Int64
segment_size = [TExp Int64] -> TExp Int64
forall a. [a] -> a
last [TExp Int64]
dims
([Count Bytes (TExp Int64)]
op_hs, [Count Bytes (TExp Int64)]
op_seg_hs, [SegHistSlug]
slugs) <- [(Count Bytes (TExp Int64), Count Bytes (TExp Int64), SegHistSlug)]
-> ([Count Bytes (TExp Int64)], [Count Bytes (TExp Int64)],
[SegHistSlug])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Count Bytes (TExp Int64), Count Bytes (TExp Int64),
SegHistSlug)]
-> ([Count Bytes (TExp Int64)], [Count Bytes (TExp Int64)],
[SegHistSlug]))
-> ImpM
KernelsMem
HostEnv
HostOp
[(Count Bytes (TExp Int64), Count Bytes (TExp Int64), SegHistSlug)]
-> ImpM
KernelsMem
HostEnv
HostOp
([Count Bytes (TExp Int64)], [Count Bytes (TExp Int64)],
[SegHistSlug])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (HistOp KernelsMem
-> CallKernelGen
(Count Bytes (TExp Int64), Count Bytes (TExp Int64), SegHistSlug))
-> [HistOp KernelsMem]
-> ImpM
KernelsMem
HostEnv
HostOp
[(Count Bytes (TExp Int64), Count Bytes (TExp Int64), SegHistSlug)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SegSpace
-> HistOp KernelsMem
-> CallKernelGen
(Count Bytes (TExp Int64), Count Bytes (TExp Int64), SegHistSlug)
computeHistoUsage SegSpace
space) [HistOp KernelsMem]
ops
TExp Int64
h <- String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"h" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ Count Bytes (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount (Count Bytes (TExp Int64) -> TExp Int64)
-> Count Bytes (TExp Int64) -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64)
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Count Bytes (TExp Int64)]
op_hs
TExp Int64
seg_h <- String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"seg_h" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ Count Bytes (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount (Count Bytes (TExp Int64) -> TExp Int64)
-> Count Bytes (TExp Int64) -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64)
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Count Bytes (TExp Int64)]
op_seg_hs
TExp Bool -> CallKernelGen () -> CallKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sUnless (TExp Int64
seg_h TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0) (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let hist_B :: TExp Int64
hist_B = Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size'
TExp Int64
hist_H <- String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_H" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (HistOp KernelsMem -> TExp Int64)
-> [HistOp KernelsMem] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> (HistOp KernelsMem -> SubExp) -> HistOp KernelsMem -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth) [HistOp KernelsMem]
ops
let lockSize :: SegHistSlug -> Maybe a
lockSize SegHistSlug
slug = case SegHistSlug -> AtomicUpdate KernelsMem KernelEnv
slugAtomicUpdate SegHistSlug
slug of
AtomicLocking {} -> a -> Maybe a
forall a. a -> Maybe a
Just (a -> Maybe a) -> a -> Maybe a
forall a b. (a -> b) -> a -> b
$ PrimType -> a
forall a. Num a => PrimType -> a
primByteSize PrimType
int32
AtomicUpdate KernelsMem KernelEnv
_ -> Maybe a
forall a. Maybe a
Nothing
TExp Int64
hist_el_size <-
String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_el_size" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
(TExp Int64 -> TExp Int64 -> TExp Int64)
-> TExp Int64 -> [TExp Int64] -> TExp Int64
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
(+) (TExp Int64
h TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64
hist_H) ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$
(SegHistSlug -> Maybe (TExp Int64))
-> [SegHistSlug] -> [TExp Int64]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe SegHistSlug -> Maybe (TExp Int64)
forall a. Num a => SegHistSlug -> Maybe a
lockSize [SegHistSlug]
slugs
TExp Int64
hist_N <- String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_N" TExp Int64
segment_size
TExp Int32
hist_RF <-
String -> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"hist_RF" (TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
[TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((SegHistSlug -> TExp Int64) -> [SegHistSlug] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> (SegHistSlug -> SubExp) -> SegHistSlug -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histRaceFactor (HistOp KernelsMem -> SubExp)
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs)
TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` [SegHistSlug] -> TExp Int64
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs
let hist_T :: TExp Int32
hist_T = TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size'
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"\n# SegHist" Maybe Exp
forall a. Maybe a
Nothing
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of threads (T)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_T
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Desired group size (B)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_B
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Histogram size (H)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_H
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Input elements per histogram (N)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_N
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ ((VName, SubExp) -> TExp Int64)
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> ((VName, SubExp) -> SubExp) -> (VName, SubExp) -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) [(VName, SubExp)]
segment_dims
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Histogram element size (el_size)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_el_size
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Race factor (RF)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_RF
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Memory per set of subhistograms per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
h
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Memory per set of subhistograms times segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
seg_h
(TExp Bool
use_local_memory, CallKernelGen ()
run_in_local_memory) <-
[PatElem KernelsMem]
-> TExp Int32
-> SegSpace
-> TExp Int64
-> TExp Int64
-> TExp Int64
-> TExp Int32
-> [SegHistSlug]
-> KernelBody KernelsMem
-> CallKernelGen (TExp Bool, CallKernelGen ())
localMemoryCase [PatElem KernelsMem]
[PatElemT LetDecMem]
map_pes TExp Int32
hist_T SegSpace
space TExp Int64
hist_H TExp Int64
hist_el_size TExp Int64
hist_N TExp Int32
hist_RF [SegHistSlug]
slugs KernelBody KernelsMem
kbody
TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf TExp Bool
use_local_memory CallKernelGen ()
run_in_local_memory (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
[PatElem KernelsMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegHistSlug]
-> KernelBody KernelsMem
-> CallKernelGen ()
histKernelGlobal [PatElem KernelsMem]
[PatElemT LetDecMem]
map_pes Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegHistSlug]
slugs KernelBody KernelsMem
kbody
let pes_per_op :: [[PatElemT LetDecMem]]
pes_per_op = [Int] -> [PatElemT LetDecMem] -> [[PatElemT LetDecMem]]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp KernelsMem -> Int) -> [HistOp KernelsMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int)
-> (HistOp KernelsMem -> [VName]) -> HistOp KernelsMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> [VName]
forall lore. HistOp lore -> [VName]
histDest) [HistOp KernelsMem]
ops) [PatElemT LetDecMem]
all_red_pes
[(SegHistSlug, [PatElemT LetDecMem], HistOp KernelsMem)]
-> ((SegHistSlug, [PatElemT LetDecMem], HistOp KernelsMem)
-> CallKernelGen ())
-> CallKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [[PatElemT LetDecMem]]
-> [HistOp KernelsMem]
-> [(SegHistSlug, [PatElemT LetDecMem], HistOp KernelsMem)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegHistSlug]
slugs [[PatElemT LetDecMem]]
pes_per_op [HistOp KernelsMem]
ops) (((SegHistSlug, [PatElemT LetDecMem], HistOp KernelsMem)
-> CallKernelGen ())
-> CallKernelGen ())
-> ((SegHistSlug, [PatElemT LetDecMem], HistOp KernelsMem)
-> CallKernelGen ())
-> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, [PatElemT LetDecMem]
red_pes, HistOp KernelsMem
op) -> do
let num_histos :: TV Int64
num_histos = SegHistSlug -> TV Int64
slugNumSubhistos SegHistSlug
slug
subhistos :: [VName]
subhistos = (SubhistosInfo -> VName) -> [SubhistosInfo] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map SubhistosInfo -> VName
subhistosArray ([SubhistosInfo] -> [VName]) -> [SubhistosInfo] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> [SubhistosInfo]
slugSubhistos SegHistSlug
slug
let unitHistoCase :: CallKernelGen ()
unitHistoCase =
[(PatElemT LetDecMem, VName)]
-> ((PatElemT LetDecMem, VName) -> CallKernelGen ())
-> CallKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem] -> [VName] -> [(PatElemT LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LetDecMem]
red_pes [VName]
subhistos) (((PatElemT LetDecMem, VName) -> CallKernelGen ())
-> CallKernelGen ())
-> ((PatElemT LetDecMem, VName) -> CallKernelGen ())
-> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
pe, VName
subhisto) -> do
VName
pe_mem <-
MemLocation -> VName
memLocationName (MemLocation -> VName)
-> (ArrayEntry -> MemLocation) -> ArrayEntry -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLocation
entryArrayLocation
(ArrayEntry -> VName)
-> ImpM KernelsMem HostEnv HostOp ArrayEntry
-> ImpM KernelsMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM KernelsMem HostEnv HostOp ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe)
VName
subhisto_mem <-
MemLocation -> VName
memLocationName (MemLocation -> VName)
-> (ArrayEntry -> MemLocation) -> ArrayEntry -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLocation
entryArrayLocation
(ArrayEntry -> VName)
-> ImpM KernelsMem HostEnv HostOp ArrayEntry
-> ImpM KernelsMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM KernelsMem HostEnv HostOp ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray VName
subhisto
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> Space -> Code HostOp
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
pe_mem VName
subhisto_mem (Space -> Code HostOp) -> Space -> Code HostOp
forall a b. (a -> b) -> a -> b
$ String -> Space
Space String
"device"
TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
num_histos TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
1) CallKernelGen ()
unitHistoCase (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let num_buckets :: SubExp
num_buckets = HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth HistOp KernelsMem
op
VName
bucket_id <- String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"bucket_id"
VName
subhistogram_id <- String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"subhistogram_id"
[VName]
vector_ids <-
(SubExp -> ImpM KernelsMem HostEnv HostOp VName)
-> [SubExp] -> ImpM KernelsMem HostEnv HostOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (ImpM KernelsMem HostEnv HostOp VName
-> SubExp -> ImpM KernelsMem HostEnv HostOp VName
forall a b. a -> b -> a
const (ImpM KernelsMem HostEnv HostOp VName
-> SubExp -> ImpM KernelsMem HostEnv HostOp VName)
-> ImpM KernelsMem HostEnv HostOp VName
-> SubExp
-> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"vector_id") ([SubExp] -> ImpM KernelsMem HostEnv HostOp [VName])
-> [SubExp] -> ImpM KernelsMem HostEnv HostOp [VName]
forall a b. (a -> b) -> a -> b
$
Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp KernelsMem
op
VName
flat_gtid <- String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"flat_gtid"
let lvl :: SegLevel
lvl = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegVirt
SegVirt
segred_space :: SegSpace
segred_space =
VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
flat_gtid ([(VName, SubExp)] -> SegSpace) -> [(VName, SubExp)] -> SegSpace
forall a b. (a -> b) -> a -> b
$
[(VName, SubExp)]
segment_dims
[(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
bucket_id, SubExp
num_buckets)]
[(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
vector_ids (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp KernelsMem
op)
[(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
subhistogram_id, VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
num_histos)]
let segred_op :: SegBinOp KernelsMem
segred_op = Commutativity
-> LambdaT KernelsMem -> [SubExp] -> Shape -> SegBinOp KernelsMem
forall lore.
Commutativity -> Lambda lore -> [SubExp] -> Shape -> SegBinOp lore
SegBinOp Commutativity
Commutative (HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp HistOp KernelsMem
op) (HistOp KernelsMem -> [SubExp]
forall lore. HistOp lore -> [SubExp]
histNeutral HistOp KernelsMem
op) Shape
forall a. Monoid a => a
mempty
Pattern KernelsMem
-> SegLevel
-> SegSpace
-> [SegBinOp KernelsMem]
-> DoSegBody
-> CallKernelGen ()
compileSegRed' ([PatElemT LetDecMem] -> [PatElemT LetDecMem] -> PatternT LetDecMem
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT LetDecMem]
red_pes) SegLevel
lvl SegSpace
segred_space [SegBinOp KernelsMem
segred_op] (DoSegBody -> CallKernelGen ()) -> DoSegBody -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [TExp Int64])] -> InKernelGen ()
red_cont ->
[(SubExp, [TExp Int64])] -> InKernelGen ()
red_cont ([(SubExp, [TExp Int64])] -> InKernelGen ())
-> [(SubExp, [TExp Int64])] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
((VName -> (SubExp, [TExp Int64]))
-> [VName] -> [(SubExp, [TExp Int64])])
-> [VName]
-> (VName -> (SubExp, [TExp Int64]))
-> [(SubExp, [TExp Int64])]
forall a b c. (a -> b -> c) -> b -> a -> c
flip (VName -> (SubExp, [TExp Int64]))
-> [VName] -> [(SubExp, [TExp Int64])]
forall a b. (a -> b) -> [a] -> [b]
map [VName]
subhistos ((VName -> (SubExp, [TExp Int64])) -> [(SubExp, [TExp Int64])])
-> (VName -> (SubExp, [TExp Int64])) -> [(SubExp, [TExp Int64])]
forall a b. (a -> b) -> a -> b
$ \VName
subhisto ->
( VName -> SubExp
Var VName
subhisto,
(VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 ([VName] -> [TExp Int64]) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$
((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst [(VName, SubExp)]
segment_dims [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
subhistogram_id, VName
bucket_id] [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
vector_ids
)
where
segment_dims :: [(VName, SubExp)]
segment_dims = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
init ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space