-- | Code generation for 'SegScan'.  Dispatches to either a
-- single-pass or two-pass implementation, depending on the nature of
-- the scan and the chosen backend.
module Futhark.CodeGen.ImpGen.Kernels.SegScan (compileSegScan) where

import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import Futhark.CodeGen.ImpGen.Kernels.Base
import qualified Futhark.CodeGen.ImpGen.Kernels.SegScan.SinglePass as SinglePass
import qualified Futhark.CodeGen.ImpGen.Kernels.SegScan.TwoPass as TwoPass
import Futhark.IR.KernelsMem

-- The single-pass scan does not support multiple operators, so jam
-- them together here.
--
-- The combined operator is built field by field:
--
-- * commutativity is the 'mconcat' of the individual commutativities;
--
-- * the neutral elements are concatenated in operator order;
--
-- * the shape is 'mempty' (this is assumed, not checked here; the
--   caller 'canBeSinglePass' only combines operators whose shape is
--   already 'mempty');
--
-- * the lambda takes the "x" parameters of every operator followed by
--   the "y" parameters of every operator, runs all the original body
--   statements in sequence, and returns all the original results.
combineScans :: [SegBinOp KernelsMem] -> SegBinOp KernelsMem
combineScans ops =
  SegBinOp
    { segBinOpComm = mconcat (map segBinOpComm ops),
      segBinOpLambda = lam',
      segBinOpNeutral = concatMap segBinOpNeutral ops,
      segBinOpShape = mempty -- Assumed
    }
  where
    lams = map segBinOpLambda ops

    -- A lambda with k return values has its parameters split into the
    -- first k (the "x" group) and whatever follows (the "y" group).
    xParams lam = take (length (lambdaReturnType lam)) (lambdaParams lam)
    yParams lam = drop (length (lambdaReturnType lam)) (lambdaParams lam)

    lam' =
      Lambda
        { -- All "x" groups first, then all "y" groups, each in
          -- operator order.
          lambdaParams = concatMap xParams lams ++ concatMap yParams lams,
          lambdaReturnType = concatMap lambdaReturnType lams,
          lambdaBody =
            Body
              ()
              (mconcat (map (bodyStms . lambdaBody) lams))
              (concatMap (bodyResult . lambdaBody) lams)
        }

-- Decide whether the given scan can be handled by the single-pass
-- implementation.  Returns 'Just' the operators merged into one (via
-- 'combineScans') when both of the following hold, and 'Nothing'
-- otherwise:
--
-- * 'unSegSpace' yields exactly one entry for the space;
--
-- * every operator has an empty ('mempty') shape and a lambda whose
--   return types all satisfy 'primType'.
canBeSinglePass :: SegSpace -> [SegBinOp KernelsMem] -> Maybe (SegBinOp KernelsMem)
canBeSinglePass space ops
  | [_] <- unSegSpace space,
    all ok ops =
    Just $ combineScans ops
  | otherwise =
    Nothing
  where
    -- Per-operator eligibility check.
    ok op =
      segBinOpShape op == mempty
        && all primType (lambdaReturnType (segBinOpLambda op))

-- | Compile 'SegScan' instance to host-level code with calls to
-- various kernels.
--
-- All generated code is guarded by a check that the product of the
-- segment space dimensions is strictly positive.  Inside the guard, a
-- debug print is emitted and the host target is consulted: when it is
-- 'CUDA' and 'canBeSinglePass' accepts the scan, the single-pass
-- implementation is used with the combined operator; in every other
-- case the two-pass implementation is used with the original
-- operator list.
compileSegScan ::
  Pattern KernelsMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp KernelsMem] ->
  KernelBody KernelsMem ->
  CallKernelGen ()
compileSegScan pat lvl space scans kbody = sWhen (0 .<. n) $ do
  emit $ Imp.DebugPrint "\n# SegScan" Nothing
  target <- hostTarget <$> askEnv
  case target of
    CUDA
      | Just scan' <- canBeSinglePass space scans ->
        SinglePass.compileSegScan pat lvl space scan' kbody
    -- Also reached for CUDA when the pattern guard above fails.
    _ -> TwoPass.compileSegScan pat lvl space scans kbody
  where
    -- Product of all dimensions of the segment space, as a 64-bit
    -- integer expression.
    n = product $ map toInt64Exp $ segSpaceDims space