{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}

-- | This module defines a translation from imperative code with
-- kernels to imperative code with OpenCL calls.
module Futhark.CodeGen.ImpGen.Kernels.ToOpenCL
  ( kernelsToOpenCL,
    kernelsToCUDA,
  )
where

import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.FileEmbed
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.CodeGen.Backends.GenericC as GC
import Futhark.CodeGen.Backends.SimpleRep
import Futhark.CodeGen.ImpCode.Kernels hiding (Program)
import qualified Futhark.CodeGen.ImpCode.Kernels as ImpKernels
import Futhark.CodeGen.ImpCode.OpenCL hiding (Program)
import qualified Futhark.CodeGen.ImpCode.OpenCL as ImpOpenCL
import Futhark.Error (compilerLimitationS)
import Futhark.IR.Prop (isBuiltInFunction)
import Futhark.MonadFreshNames
import Futhark.Util (zEncodeString)
import Futhark.Util.Pretty (prettyOneLine)
import qualified Language.C.Quote.CUDA as CUDAC
import qualified Language.C.Quote.OpenCL as C
import qualified Language.C.Syntax as C

-- | Translate a kernels-program into a program with embedded device
-- code, for the CUDA and the OpenCL target respectively.  Both are
-- 'translateKernels' specialised to a fixed 'KernelTarget'.
kernelsToCUDA, kernelsToOpenCL :: ImpKernels.Program -> ImpOpenCL.Program
kernelsToCUDA = translateKernels TargetCUDA
kernelsToOpenCL = translateKernels TargetOpenCL

-- | Translate a kernels-program to an OpenCL-program.
--
-- Every host operation in the constants and in all functions is
-- passed through 'onHostOp'.  The 'ToOpenCL' state accumulated along
-- the way (kernels, device functions, used primitive types, sizes and
-- failure messages) is then used to assemble the resulting
-- 'ImpOpenCL.Program'.
translateKernels ::
  KernelTarget ->
  ImpKernels.Program ->
  ImpOpenCL.Program
translateKernels target prog =
  let ( prog',
        ToOpenCL kernels device_funs used_types sizes failures
        ) =
          (`runState` initialOpenCL) . (`runReaderT` defFuns prog) $ do
            let ImpKernels.Definitions
                  (ImpKernels.Constants ps consts)
                  (ImpKernels.Functions funs) = prog
            consts' <- traverse (onHostOp target) consts
            funs' <- forM funs $ \(fname, fun) ->
              (fname,) <$> traverse (onHostOp target) fun

            return $
              ImpOpenCL.Definitions
                (ImpOpenCL.Constants ps consts')
                (ImpOpenCL.Functions funs')

      -- Device functions are stored as (prototype, definition) pairs.
      (device_prototypes, device_defs) = unzip $ M.elems device_funs
      kernels' = M.map fst kernels
      opencl_code = openClCode $ map snd $ M.elems kernels

      -- The prelude lists all device function prototypes before any
      -- device function definition.
      opencl_prelude =
        unlines
          [ pretty $ genPrelude target used_types,
            unlines $ map pretty device_prototypes,
            unlines $ map pretty device_defs
          ]
   in ImpOpenCL.Program
        opencl_code
        opencl_prelude
        kernels'
        (S.toList used_types)
        (cleanSizes sizes)
        failures
        prog'
  where
    -- The CUDA prelude does not depend on the set of used types.
    genPrelude TargetOpenCL = genOpenClPrelude
    genPrelude TargetCUDA = const genCUDAPrelude

-- | Due to simplifications after kernel extraction, some threshold
-- parameters may contain KernelPaths that reference threshold
-- parameters that no longer exist.  We remove these here.
--
-- Only 'SizeThreshold' entries are modified; every other size class
-- is returned unchanged.
cleanSizes :: M.Map Name SizeClass -> M.Map Name SizeClass
cleanSizes m = M.map clean m
  where
    -- Keep only those path elements whose name is still a key of the
    -- size map.  Membership is tested on the map itself rather than
    -- on a list of its keys.
    clean (SizeThreshold path def) =
      SizeThreshold (filter ((`M.member` m) . fst) path) def
    clean s = s

-- | Map the textual name of a memory space or qualifier to the
-- corresponding OpenCL type qualifiers.  Any name not listed below is
-- a hard error raised through 'error'.
pointerQuals :: Monad m => String -> m [C.TypeQual]
pointerQuals "global" = return [C.ctyquals|__global|]
pointerQuals "local" = return [C.ctyquals|__local|]
pointerQuals "private" = return [C.ctyquals|__private|]
pointerQuals "constant" = return [C.ctyquals|__constant|]
pointerQuals "write_only" = return [C.ctyquals|__write_only|]
pointerQuals "read_only" = return [C.ctyquals|__read_only|]
pointerQuals "kernel" = return [C.ctyquals|__kernel|]
pointerQuals s = error $ "'" ++ s ++ "' is not an OpenCL kernel address space."

-- | In-kernel name and per-workgroup size in bytes.
type LocalMemoryUse = (VName, Count Bytes Exp)

-- | User state threaded through the C code generator while compiling
-- device code (see 'genGPUCode').
data KernelState = KernelState
  { -- | Local memory uses recorded so far.
    kernelLocalMemory :: [LocalMemoryUse],
    -- | Failure messages known so far.
    kernelFailures :: [FailureMsg],
    -- | Counter from which 'errorLabel' derives the label name.
    kernelNextSync :: Int,
    -- | Has a potential failure occurred since the last
    -- ErrorSync?
    kernelSyncPending :: Bool,
    kernelHasBarriers :: Bool
  }

-- | A fresh 'KernelState' that starts out with the given failure
-- messages, no local memory uses, a sync counter of zero, and both
-- flags cleared.
newKernelState :: [FailureMsg] -> KernelState
newKernelState failures =
  KernelState
    { kernelLocalMemory = mempty,
      kernelFailures = failures,
      kernelNextSync = 0,
      kernelSyncPending = False,
      kernelHasBarriers = False
    }

-- | The label name for the current sync counter: @error_@ followed by
-- the value of 'kernelNextSync'.
errorLabel :: KernelState -> String
errorLabel = ("error_" ++) . show . kernelNextSync

-- | State accumulated while translating a whole program.
data ToOpenCL = ToOpenCL
  { -- | Generated kernels by name, with their safety level.
    clKernels :: M.Map KernelName (KernelSafety, C.Func),
    -- | Generated device functions by name, as a pair of prototype
    -- and definition.
    clDevFuns :: M.Map Name (C.Definition, C.Func),
    -- | Primitive types used by the generated device code.
    clUsedTypes :: S.Set PrimType,
    -- | Sizes registered through 'addSize'.
    clSizes :: M.Map Name SizeClass,
    -- | Failure messages collected so far.
    clFailures :: [FailureMsg]
  }

-- | The empty translation state: nothing has been generated yet.
initialOpenCL :: ToOpenCL
initialOpenCL =
  ToOpenCL
    { clKernels = mempty,
      clDevFuns = mempty,
      clUsedTypes = mempty,
      clSizes = mempty,
      clFailures = mempty
    }

-- | The functions of the kernels-program being translated, with host
-- operations.
type AllFunctions = ImpKernels.Functions ImpKernels.HostOp

-- | Find a function by name, or 'Nothing' if no function of that name
-- is present.
lookupFunction :: Name -> AllFunctions -> Maybe ImpKernels.Function
lookupFunction fname (ImpKernels.Functions fs) = lookup fname fs

-- | The translation monad: a reader over 'AllFunctions' on top of a
-- state monad carrying 'ToOpenCL'.
type OnKernelM = ReaderT AllFunctions (State ToOpenCL)

-- | Record a size and its class in 'clSizes'.  An existing entry
-- under the same key is replaced.
addSize :: Name -> SizeClass -> OnKernelM ()
addSize key sclass =
  modify $ \s -> s {clSizes = M.insert key sclass $ clSizes s}

-- | Translate a single host operation.  Kernel calls are handed to
-- 'onKernel'.  'GetSize' and 'CmpSizeLe' register their size via
-- 'addSize' and drop the size class from the resulting operation.
-- 'GetSizeMax' is passed through unchanged.
onHostOp :: KernelTarget -> HostOp -> OnKernelM OpenCL
onHostOp target (CallKernel k) = onKernel target k
onHostOp _ (ImpKernels.GetSize v key size_class) = do
  addSize key size_class
  return $ ImpOpenCL.GetSize v key
onHostOp _ (ImpKernels.CmpSizeLe v key size_class x) = do
  addSize key size_class
  return $ ImpOpenCL.CmpSizeLe v key x
onHostOp _ (ImpKernels.GetSizeMax v size_class) =
  return $ ImpOpenCL.GetSizeMax v size_class

-- | Run a C code generation action for device code.  The action runs
-- with the in-kernel operations for the given mode and body, a blank
-- name source, and a fresh 'KernelState' seeded with the given
-- failure messages.  Returns the action's result together with the
-- final compiler state.
genGPUCode ::
  OpsMode ->
  KernelCode ->
  [FailureMsg] ->
  GC.CompilerM KernelOp KernelState a ->
  (a, GC.CompilerState KernelState)
genGPUCode mode body failures =
  GC.runCompilerM
    (inKernelOperations mode body)
    blankNameSource
    (newKernelState failures)

-- Compilation of a device function that is not invoked from the
-- host, but is invoked by (perhaps multiple) kernels.
--
-- The generated prototype and definition are stored in 'clDevFuns'
-- under the function's name, the primitive types of its body are
-- added to 'clUsedTypes', and 'clFailures' is replaced by the failure
-- list produced while compiling the function.  Functions with memory
-- parameters, or containing any host operation, are rejected as a
-- compiler limitation.
generateDeviceFun :: Name -> ImpKernels.Function -> OnKernelM ()
generateDeviceFun fname host_func = do
  -- Functions are a priori always considered host-level, so we have
  -- to convert them to device code.  This is where most of our
  -- limitations on device-side functions (no arrays, no parallelism)
  -- comes from.
  let device_func = fmap toDevice host_func
  when (any memParam $ functionInput host_func) bad

  failures <- gets clFailures

  -- Every device function takes the failure-reporting pointers as
  -- extra parameters.
  let params =
        [ [C.cparam|__global int *global_failure|],
          [C.cparam|__global typename int64_t *global_failure_args|]
        ]
      (func, cstate) =
        genGPUCode FunMode (functionBody device_func) failures $
          GC.compileFun mempty params (fname, device_func)
      kstate = GC.compUserState cstate

  modify $ \s ->
    s
      { clUsedTypes = typesInCode (functionBody device_func) <> clUsedTypes s,
        clDevFuns = M.insert fname func $ clDevFuns s,
        clFailures = kernelFailures kstate
      }

  -- Important to do this after the 'modify' call, so we propagate the
  -- right clFailures.
  void $ ensureDeviceFuns $ functionBody device_func
  where
    -- Any host operation inside a device function is unsupported.
    toDevice :: HostOp -> KernelOp
    toDevice _ = bad

    memParam MemParam {} = True
    memParam ScalarParam {} = False

    -- Used both as a monadic action and as a 'KernelOp', hence the
    -- fully polymorphic type.
    bad :: a
    bad = compilerLimitationS "Cannot generate GPU functions that use arrays."

-- Ensure that this device function is available, but don't regenerate
-- it if it already exists.  "Exists" means that the name is already a
-- key of 'clDevFuns'.
ensureDeviceFun :: Name -> ImpKernels.Function -> OnKernelM ()
ensureDeviceFun fname host_func = do
  exists <- gets $ M.member fname . clDevFuns
  unless exists $ generateDeviceFun fname host_func

-- Make sure a device version exists for every function called from
-- the given code, and return the names of the called functions that
-- were found in the program.  Called names with no definition among
-- 'AllFunctions' are silently left out of the result.
ensureDeviceFuns :: ImpKernels.KernelCode -> OnKernelM [Name]
ensureDeviceFuns code = do
  let called = calledFuncs code
  fmap catMaybes $
    forM (S.toList called) $ \fname -> do
      def <- asks $ lookupFunction fname
      case def of
        Just func -> do
          ensureDeviceFun fname func
          return $ Just fname
        Nothing -> return Nothing

-- | Translate a single kernel into a C kernel function and the host
-- operation that launches it.
--
-- The generated function is recorded in 'clKernels' (keyed by the
-- kernel name, together with its 'KernelSafety'), the primitive types
-- used by the kernel are merged into 'clUsedTypes', and 'clFailures'
-- is replaced by the failure list found in the kernel compiler state.
-- The result is a 'LaunchKernel' carrying the launch arguments, the
-- number of groups and the group size.
onKernel :: KernelTarget -> Kernel -> OnKernelM OpenCL
onKernel target kernel = do
  called <- ensureDeviceFuns $ kernelBody kernel

  -- Crucial that this is done after 'ensureDeviceFuns', as the device
  -- functions may themselves define failure points.
  failures <- gets clFailures

  let (kernel_body, cstate) =
        genGPUCode KernelMode (kernelBody kernel) failures $
          GC.blockScope $ GC.compileCode $ kernelBody kernel
      kstate = GC.compUserState cstate

      -- Scalars and memory blocks become kernel parameters; constants
      -- are handled separately via 'constDef'.
      use_params = mapMaybe useAsParam $ kernelUses kernel

      -- One (argument, parameter, initialiser) triple per local
      -- memory block recorded while compiling the body.  Names are
      -- drawn from a blank name source that is local to this kernel.
      (local_memory_args, local_memory_params, local_memory_init) =
        unzip3 $
          flip evalState (blankNameSource :: VNameSource) $
            mapM (prepareLocalMemory target) $ kernelLocalMemory kstate

      -- CUDA has very strict restrictions on the number of blocks
      -- permitted along the 'y' and 'z' dimensions of the grid
      -- (1<<16).  To work around this, we are going to dynamically
      -- permute the block dimensions to move the largest one to the
      -- 'x' dimension, which has a higher limit (1<<31).  This means
      -- we need to extend the kernel with extra parameters that
      -- contain information about this permutation, but we only do
      -- this for multidimensional kernels (at the time of this
      -- writing, only transposes).  The corresponding arguments are
      -- added automatically in CCUDA.hs.
      (perm_params, block_dim_init) =
        case (target, num_groups) of
          (TargetCUDA, [_, _, _]) ->
            ( [ [C.cparam|const int block_dim0|],
                [C.cparam|const int block_dim1|],
                [C.cparam|const int block_dim2|]
              ],
              mempty
            )
          _ ->
            ( mempty,
              [ [C.citem|const int block_dim0 = 0;|],
                [C.citem|const int block_dim1 = 1;|],
                [C.citem|const int block_dim2 = 2;|]
              ]
            )

      -- Constants are #defined at the top of the kernel and #undef'd
      -- at the bottom.
      (const_defs, const_undefs) = unzip $ mapMaybe constDef $ kernelUses kernel

  let (safety, error_init)
        -- We conservatively assume that any called function can fail.
        | not $ null called =
          (SafetyFull, [])
        -- The failure list has the same length after compiling the
        -- body as before it.
        | length (kernelFailures kstate) == length failures =
          if kernelFailureTolerant kernel
            then (SafetyNone, [])
            else
              -- No possible failures in this kernel, so if we make
              -- it past an initial check, then we are good to go.
              ( SafetyCheap,
                [C.citems|if (*global_failure >= 0) { return; }|]
              )
        | otherwise =
          if not (kernelHasBarriers kstate)
            then
              ( SafetyFull,
                [C.citems|if (*global_failure >= 0) { return; }|]
              )
            else
              ( SafetyFull,
                [C.citems|
                     volatile __local bool local_failure;
                     if (failure_is_an_option) {
                       int failed = *global_failure >= 0;
                       if (failed) {
                         return;
                       }
                     }
                     // All threads write this value - it looks like CUDA has a compiler bug otherwise.
                     local_failure = false;
                     barrier(CLK_LOCAL_MEM_FENCE);
                  |]
              )

      failure_params =
        [ [C.cparam|__global int *global_failure|],
          [C.cparam|int failure_is_an_option|],
          [C.cparam|__global typename int64_t *global_failure_args|]
        ]

      -- Only a prefix of the failure parameters is passed, as decided
      -- by 'numFailureParams' for the chosen safety level.
      params =
        perm_params
          ++ take (numFailureParams safety) failure_params
          ++ catMaybes local_memory_params
          ++ use_params

      kernel_fun =
        [C.cfun|__kernel void $id:name ($params:params) {
                  $items:const_defs
                  $items:block_dim_init
                  $items:local_memory_init
                  $items:error_init
                  $items:kernel_body

                  $id:(errorLabel kstate): return;

                  $items:const_undefs
                }|]
  modify $ \s ->
    s
      { clKernels = M.insert name (safety, kernel_fun) $ clKernels s,
        clUsedTypes = typesInKernel kernel <> clUsedTypes s,
        clFailures = kernelFailures kstate
      }

  -- The argument corresponding to the global_failure parameters is
  -- added automatically later.
  let args =
        catMaybes local_memory_args
          ++ kernelArgs kernel

  return $ LaunchKernel safety name args num_groups group_size
  where
    name = kernelName kernel
    num_groups = kernelNumGroups kernel
    group_size = kernelGroupSize kernel

    -- Produce the launch argument, the kernel parameter and the
    -- in-kernel declaration for one local memory block.  For OpenCL
    -- the block arrives as a __local pointer parameter; for CUDA the
    -- parameter is an offset into the 'shared_mem' array.
    prepareLocalMemory TargetOpenCL (mem, size) = do
      mem_aligned <- newVName $ baseString mem ++ "_aligned"
      return
        ( Just $ SharedMemoryKArg size,
          Just [C.cparam|__local volatile typename int64_t* $id:mem_aligned|],
          [C.citem|__local volatile char* restrict $id:mem = (__local volatile char*)$id:mem_aligned;|]
        )
    prepareLocalMemory TargetCUDA (mem, size) = do
      param <- newVName $ baseString mem ++ "_offset"
      return
        ( Just $ SharedMemoryKArg size,
          Just [C.cparam|uint $id:param|],
          [C.citem|volatile char *$id:mem = &shared_mem[$id:param];|]
        )

-- | The C kernel parameter corresponding to a 'KernelUse', if any.
-- Scalars are passed by value, memory blocks as @__global unsigned
-- char*@ pointers, and constants produce no parameter.
useAsParam :: KernelUse -> Maybe C.Param
useAsParam (ScalarUse name bt) =
  let ctp = case bt of
        -- OpenCL does not permit bool as a kernel parameter type.
        Bool -> [C.cty|unsigned char|]
        _ -> GC.primTypeToCType bt
   in Just [C.cparam|$ty:ctp $id:name|]
useAsParam (MemoryUse name) =
  Just [C.cparam|__global unsigned char *$id:name|]
useAsParam ConstUse {} =
  Nothing

-- | Constants are #defined as macros.  Since a constant name in one
-- kernel might potentially (although unlikely) also be used for
-- something else in another kernel, we #undef them after the kernel.
--
-- For a 'ConstUse' this returns the @#define@ item paired with the
-- matching @#undef@ item; every other use yields 'Nothing'.
constDef :: KernelUse -> Maybe (C.BlockItem, C.BlockItem)
constDef (ConstUse v e) =
  Just
    ( [C.citem|$escstm:def|],
      [C.citem|$escstm:undef|]
    )
  where
    e' = compilePrimExp e
    -- The macro name, rendered exactly as the C identifier for 'v'.
    ident = pretty (C.toIdent v mempty)
    -- The value is printed on one line and parenthesised, since a
    -- preprocessor definition cannot span lines unescaped.
    def = "#define " ++ ident ++ " (" ++ prettyOneLine e' ++ ")"
    undef = "#undef " ++ ident
constDef _ = Nothing

-- | Pretty-print a list of kernel functions as a single C translation
-- unit, one external declaration per function, in the given order.
openClCode :: [C.Func] -> String
openClCode kernels =
  pretty [C.cunit|$edecls:funcs|]
  where
    funcs =
      [ [C.cedecl|$func:kernel_func|]
        | kernel_func <- kernels
      ]

-- | The contents of @rts/c/atomics.h@, embedded at compile time via
-- Template Haskell.
atomicsDefs :: String
atomicsDefs = $(embedStringFile "rts/c/atomics.h")

-- | The definitions placed ahead of the kernels in an OpenCL program:
-- extension pragmas, a dummy kernel, fixed-width integer typedefs,
-- memory fence helpers, the scalar operation definitions, and the
-- embedded atomics header.
--
-- The argument is the set of primitive types used by the program.
-- The @cl_khr_fp64@ pragma, the @FUTHARK_F64_ENABLED@ macro and the
-- 64-bit float operations are only emitted when that set contains
-- 'Float64'.
genOpenClPrelude :: S.Set PrimType -> [C.Definition]
genOpenClPrelude ts =
  -- Clang-based OpenCL implementations need this for 'static' to work.
  [ [C.cedecl|$esc:("#ifdef cl_clang_storage_class_specifiers")|],
    [C.cedecl|$esc:("#pragma OPENCL EXTENSION cl_clang_storage_class_specifiers : enable")|],
    [C.cedecl|$esc:("#endif")|],
    [C.cedecl|$esc:("#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable")|]
  ]
    ++ concat
      [ [C.cunit|$esc:("#pragma OPENCL EXTENSION cl_khr_fp64 : enable")
                 $esc:("#define FUTHARK_F64_ENABLED")|]
        | uses_float64
      ]
    ++ [C.cunit|
/* Some OpenCL programs dislike empty programs, or programs with no kernels.
 * Declare a dummy kernel to ensure they remain our friends. */
__kernel void dummy_kernel(__global unsigned char *dummy, int n)
{
    const int thread_gid = get_global_id(0);
    if (thread_gid >= n) return;
}

$esc:("#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable")
$esc:("#pragma OPENCL EXTENSION cl_khr_int64_extended_atomics : enable")

typedef char int8_t;
typedef short int16_t;
typedef int int32_t;
typedef long int64_t;

typedef uchar uint8_t;
typedef ushort uint16_t;
typedef uint uint32_t;
typedef ulong uint64_t;

// NVIDIAs OpenCL does not create device-wide memory fences (see #734), so we
// use inline assembly if we detect we are on an NVIDIA GPU.
$esc:("#ifdef cl_nv_pragma_unroll")
static inline void mem_fence_global() {
  asm("membar.gl;");
}
$esc:("#else")
static inline void mem_fence_global() {
  mem_fence(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE);
}
$esc:("#endif")
static inline void mem_fence_local() {
  mem_fence(CLK_LOCAL_MEM_FENCE);
}
|]
    ++ cIntOps
    ++ cFloat32Ops
    ++ cFloat32Funs
    ++ (if uses_float64 then cFloat64Ops ++ cFloat64Funs ++ cFloatConvOps else [])
    ++ [[C.cedecl|$esc:atomicsDefs|]]
  where
    uses_float64 = FloatType Float64 `S.member` ts

-- | The definitions placed ahead of the kernels in a CUDA program.
--
-- The first part ('cudafy') maps the OpenCL vocabulary used by the
-- generated kernels onto CUDA: address-space qualifiers become empty
-- macros, @__kernel@ becomes an @extern "C" __global__@ declaration,
-- and the work-item query functions are implemented in terms of
-- @blockIdx@, @gridDim@, @threadIdx@ and @blockDim@.  The second part
-- ('ops') is the scalar operation definitions and the embedded
-- atomics header; 64-bit floats are always enabled here.
--
-- The group-level queries take the @block_dim0..2@ permutation as
-- extra parameters.  Kernels call them through one-argument macros
-- that pick up the @block_dim*@ names in scope at the call site.
-- @get_global_size@ follows the same @_fn@ + macro scheme as the
-- other queries, so that the one-argument call emitted for
-- 'GetGlobalSize' compiles.
genCUDAPrelude :: [C.Definition]
genCUDAPrelude =
  cudafy ++ ops
  where
    ops =
      cIntOps ++ cFloat32Ops ++ cFloat32Funs ++ cFloat64Ops
        ++ cFloat64Funs
        ++ cFloatConvOps
        ++ [[C.cedecl|$esc:atomicsDefs|]]
    cudafy =
      [CUDAC.cunit|
$esc:("#define FUTHARK_CUDA")
$esc:("#define FUTHARK_F64_ENABLED")

typedef char int8_t;
typedef short int16_t;
typedef int int32_t;
typedef long long int64_t;
typedef unsigned char uint8_t;
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
typedef unsigned long long uint64_t;
typedef uint8_t uchar;
typedef uint16_t ushort;
typedef uint32_t uint;
typedef uint64_t ulong;
$esc:("#define __kernel extern \"C\" __global__ __launch_bounds__(MAX_THREADS_PER_BLOCK)")
$esc:("#define __global")
$esc:("#define __local")
$esc:("#define __private")
$esc:("#define __constant")
$esc:("#define __write_only")
$esc:("#define __read_only")

static inline int get_group_id_fn(int block_dim0, int block_dim1, int block_dim2, int d)
{
  switch (d) {
    case 0: d = block_dim0; break;
    case 1: d = block_dim1; break;
    case 2: d = block_dim2; break;
  }
  switch (d) {
    case 0: return blockIdx.x;
    case 1: return blockIdx.y;
    case 2: return blockIdx.z;
    default: return 0;
  }
}
$esc:("#define get_group_id(d) get_group_id_fn(block_dim0, block_dim1, block_dim2, d)")

static inline int get_num_groups_fn(int block_dim0, int block_dim1, int block_dim2, int d)
{
  switch (d) {
    case 0: d = block_dim0; break;
    case 1: d = block_dim1; break;
    case 2: d = block_dim2; break;
  }
  switch(d) {
    case 0: return gridDim.x;
    case 1: return gridDim.y;
    case 2: return gridDim.z;
    default: return 0;
  }
}
$esc:("#define get_num_groups(d) get_num_groups_fn(block_dim0, block_dim1, block_dim2, d)")

static inline int get_local_id(int d)
{
  switch (d) {
    case 0: return threadIdx.x;
    case 1: return threadIdx.y;
    case 2: return threadIdx.z;
    default: return 0;
  }
}

static inline int get_local_size(int d)
{
  switch (d) {
    case 0: return blockDim.x;
    case 1: return blockDim.y;
    case 2: return blockDim.z;
    default: return 0;
  }
}

static inline int get_global_id_fn(int block_dim0, int block_dim1, int block_dim2, int d)
{
  return get_group_id(d) * get_local_size(d) + get_local_id(d);
}
$esc:("#define get_global_id(d) get_global_id_fn(block_dim0, block_dim1, block_dim2, d)")

static inline int get_global_size_fn(int block_dim0, int block_dim1, int block_dim2, int d)
{
  return get_num_groups(d) * get_local_size(d);
}
$esc:("#define get_global_size(d) get_global_size_fn(block_dim0, block_dim1, block_dim2, d)")

$esc:("#define CLK_LOCAL_MEM_FENCE 1")
$esc:("#define CLK_GLOBAL_MEM_FENCE 2")
static inline void barrier(int x)
{
  __syncthreads();
}
static inline void mem_fence_local() {
  __threadfence_block();
}
static inline void mem_fence_global() {
  __threadfence();
}

$esc:("#define NAN (0.0/0.0)")
$esc:("#define INFINITY (1.0/0.0)")
extern volatile __shared__ char shared_mem[];
|]

-- | Compile a constant expression over 'KernelConst' leaves to a C
-- expression.  A 'SizeConst' leaf becomes an identifier: the
-- z-encoded pretty-printed form of its key.
compilePrimExp :: PrimExp KernelConst -> C.Exp
compilePrimExp e = runIdentity $ GC.compilePrimExp compileKernelConst e
  where
    compileKernelConst (SizeConst key) =
      return [C.cexp|$id:(zEncodeString (pretty key))|]

-- | The launch arguments arising from the uses of a kernel, in the
-- order of 'kernelUses'.  Memory uses become 'MemKArg', scalar uses
-- become 'ValueKArg' reading the scalar variable, and constant uses
-- contribute no argument.
kernelArgs :: Kernel -> [KernelArg]
kernelArgs = mapMaybe useToArg . kernelUses
  where
    useToArg (MemoryUse mem) = Just $ MemKArg mem
    useToArg (ScalarUse v bt) = Just $ ValueKArg (LeafExp (ScalarVar v) bt) bt
    useToArg ConstUse {} = Nothing

-- | The error label derived (via 'errorLabel') from the current
-- kernel compiler state.  Reading it does not modify the state.
nextErrorLabel :: GC.CompilerM KernelOp KernelState String
nextErrorLabel =
  errorLabel <$> GC.getUserState

-- | Increment the 'kernelNextSync' counter of the kernel compiler
-- state by one.
incErrorLabel :: GC.CompilerM KernelOp KernelState ()
incErrorLabel =
  GC.modifyUserState $ \s -> s {kernelNextSync = kernelNextSync s + 1}

-- | Set the 'kernelSyncPending' flag of the kernel compiler state to
-- the given value.
pendingError :: Bool -> GC.CompilerM KernelOp KernelState ()
pendingError b =
  GC.modifyUserState $ \s -> s {kernelSyncPending = b}

-- | Whether the code contains at least one 'ErrorSync' or 'Barrier'
-- operation.
hasCommunication :: ImpKernels.KernelCode -> Bool
hasCommunication = any communicates
  where
    communicates ErrorSync {} = True
    communicates Barrier {} = True
    communicates _ = False

-- | Whether we are generating code for a kernel or a device function.
-- This has minor effects, such as exactly how failures are
-- propagated.
data OpsMode = KernelMode | FunMode deriving (Eq)

inKernelOperations ::
  OpsMode ->
  ImpKernels.KernelCode ->
  GC.Operations KernelOp KernelState
inKernelOperations :: OpsMode -> KernelCode -> Operations KernelOp KernelState
inKernelOperations OpsMode
mode KernelCode
body =
  Operations :: forall op s.
WriteScalar op s
-> ReadScalar op s
-> Allocate op s
-> Deallocate op s
-> Copy op s
-> StaticArray op s
-> MemoryType op s
-> OpCompiler op s
-> ErrorCompiler op s
-> CallCompiler op s
-> Bool
-> ([BlockItem], [BlockItem])
-> Operations op s
GC.Operations
    { opsCompiler :: OpCompiler KernelOp KernelState
GC.opsCompiler = OpCompiler KernelOp KernelState
kernelOps,
      opsMemoryType :: MemoryType KernelOp KernelState
GC.opsMemoryType = MemoryType KernelOp KernelState
forall (m :: * -> *). Monad m => String -> m Type
kernelMemoryType,
      opsWriteScalar :: WriteScalar KernelOp KernelState
GC.opsWriteScalar = WriteScalar KernelOp KernelState
forall op s. WriteScalar op s
kernelWriteScalar,
      opsReadScalar :: ReadScalar KernelOp KernelState
GC.opsReadScalar = ReadScalar KernelOp KernelState
forall op s. ReadScalar op s
kernelReadScalar,
      opsAllocate :: Allocate KernelOp KernelState
GC.opsAllocate = Allocate KernelOp KernelState
cannotAllocate,
      opsDeallocate :: Deallocate KernelOp KernelState
GC.opsDeallocate = Deallocate KernelOp KernelState
cannotDeallocate,
      opsCopy :: Copy KernelOp KernelState
GC.opsCopy = Copy KernelOp KernelState
copyInKernel,
      opsStaticArray :: StaticArray KernelOp KernelState
GC.opsStaticArray = StaticArray KernelOp KernelState
noStaticArrays,
      opsFatMemory :: Bool
GC.opsFatMemory = Bool
False,
      opsError :: ErrorCompiler KernelOp KernelState
GC.opsError = ErrorCompiler KernelOp KernelState
errorInKernel,
      opsCall :: CallCompiler KernelOp KernelState
GC.opsCall = CallCompiler KernelOp KernelState
callInKernel,
      opsCritical :: ([BlockItem], [BlockItem])
GC.opsCritical = ([BlockItem], [BlockItem])
forall a. Monoid a => a
mempty
    }
  where
    has_communication :: Bool
has_communication = KernelCode -> Bool
hasCommunication KernelCode
body

    fence :: Fence -> Exp
fence Fence
FenceLocal = [C.cexp|CLK_LOCAL_MEM_FENCE|]
    fence Fence
FenceGlobal = [C.cexp|CLK_GLOBAL_MEM_FENCE | CLK_LOCAL_MEM_FENCE|]

    -- Compile a single 'KernelOp' into C statements.
    --
    -- Besides emitting code, some operations update the user state:
    -- 'Barrier' and 'ErrorSync' set 'kernelHasBarriers', and
    -- 'LocalAlloc' records a backing variable in 'kernelLocalMemory'.
    kernelOps :: GC.OpCompiler KernelOp KernelState
    kernelOps kop = case kop of
      GetGroupId v i ->
        GC.stm [C.cstm|$id:v = get_group_id($int:i);|]
      GetLocalId v i ->
        GC.stm [C.cstm|$id:v = get_local_id($int:i);|]
      GetLocalSize v i ->
        GC.stm [C.cstm|$id:v = get_local_size($int:i);|]
      GetGlobalId v i ->
        GC.stm [C.cstm|$id:v = get_global_id($int:i);|]
      GetGlobalSize v i ->
        GC.stm [C.cstm|$id:v = get_global_size($int:i);|]
      GetLockstepWidth v ->
        GC.stm [C.cstm|$id:v = LOCKSTEP_WIDTH;|]
      Barrier f -> do
        GC.stm [C.cstm|barrier($exp:(fence f));|]
        noteBarrier
      MemFence FenceLocal ->
        GC.stm [C.cstm|mem_fence_local();|]
      MemFence FenceGlobal ->
        GC.stm [C.cstm|mem_fence_global();|]
      LocalAlloc name size -> do
        -- The allocation is represented by a fresh "<name>_backing"
        -- variable, remembered together with its (untyped) size.
        backing <- newVName (pretty name ++ "_backing")
        GC.modifyUserState $ \st ->
          st {kernelLocalMemory = (backing, fmap untyped size) : kernelLocalMemory st}
        GC.stm [C.cstm|$id:name = (__local char*) $id:backing;|]
      ErrorSync f -> do
        label <- nextErrorLabel
        pending <- fmap kernelSyncPending GC.getUserState
        -- The labelled barrier and the failure check are emitted only
        -- when a sync is pending; the flag is cleared first.
        when pending $ do
          pendingError False
          GC.stm [C.cstm|$id:label: barrier($exp:(fence f));|]
          GC.stm [C.cstm|if (local_failure) { return; }|]
        GC.stm [C.cstm|barrier(CLK_LOCAL_MEM_FENCE);|] -- intentional
        noteBarrier
        incErrorLabel
      Atomic space aop ->
        atomicOps space aop
      where
        -- Remember that this kernel contains at least one barrier.
        noteBarrier =
          GC.modifyUserState $ \st -> st {kernelHasBarriers = True}

    -- Build the element type used when casting a pointer for an atomic
    -- operation: the given C type, qualified with @volatile@ followed
    -- by the pointer qualifiers of the memory space.  Any space that
    -- is not a named 'Space' is treated as "global".
    atomicCast s t = do
      quals <- pointerQuals $ case s of
        Space sid -> sid
        _ -> "global"
      return [C.cty|$tyquals:(volatile ++ quals) $ty:t|]
      where
        volatile = [C.ctyquals|volatile|]

    -- Name of the memory space as used in the suffix of the atomic
    -- helper function names; anything but a named 'Space' is "global".
    atomicSpace s = case s of
      Space sid -> sid
      _ -> "global"

    -- Emit a call to a binary atomic helper and store its result in
    -- @old@.  The function called is @<op>_<type>_<space>@; the array
    -- pointer is cast via 'atomicCast' and the operand is cast to
    -- @ty@.
    doAtomic s t old arr ind val op ty = do
      index <- GC.compileExp (untyped (unCount ind))
      operand <- GC.compileExp val
      cast <- atomicCast s ty
      GC.stm [C.cstm|$id:old = $id:fun(&(($ty:cast *)$id:arr)[$exp:index], ($ty:ty) $exp:operand);|]
      where
        fun = concat [op, "_", pretty t, "_", atomicSpace s]

    -- Emit a call to the compare-and-exchange helper
    -- @atomic_cmpxchg_<type>_<space>@ and store its result in @old@.
    -- Unlike 'doAtomic', the operands are passed without a cast.
    doAtomicCmpXchg s t old arr ind cmp val ty = do
      index <- GC.compileExp (untyped (unCount ind))
      expected <- GC.compileExp cmp
      replacement <- GC.compileExp val
      cast <- atomicCast s ty
      GC.stm [C.cstm|$id:old = $id:fun(&(($ty:cast *)$id:arr)[$exp:index], $exp:expected, $exp:replacement);|]
      where
        fun = concat ["atomic_cmpxchg_", pretty t, "_", atomicSpace s]
    -- Emit a call to the exchange helper and store its result in
    -- @old@.  The operand is passed without a cast.
    --
    -- NOTE(review): the helper name is built from the prefix
    -- "atomic_chg_", whereas the sibling helper uses
    -- "atomic_cmpxchg_".  Confirm against the kernel runtime that
    -- "atomic_chg_*" (rather than "atomic_xchg_*") is really the name
    -- that is defined there.
    doAtomicXchg s t old arr ind val ty = do
      cast <- atomicCast s ty
      index <- GC.compileExp (untyped (unCount ind))
      replacement <- GC.compileExp val
      GC.stm [C.cstm|$id:old = $id:fun(&(($ty:cast *)$id:arr)[$exp:index], $exp:replacement);|]
      where
        fun = concat ["atomic_chg_", pretty t, "_", atomicSpace s]
    -- Compile an 'AtomicOp' on memory in the given space by
    -- delegating to 'doAtomic', 'doAtomicCmpXchg' or 'doAtomicXchg'
    -- with the helper-name prefix and the C type used for casts.
    --
    -- Clause order matters: the 64-bit clauses must precede the
    -- catch-all clauses for the same constructors.
    --
    -- First the 64-bit operations.
    atomicOps s (AtomicAdd Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_add" [C.cty|typename int64_t|]
    atomicOps s (AtomicFAdd Float64 old arr ind val) =
      doAtomic s Float64 old arr ind val "atomic_fadd" [C.cty|double|]
    atomicOps s (AtomicSMax Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_smax" [C.cty|typename int64_t|]
    atomicOps s (AtomicSMin Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_smin" [C.cty|typename int64_t|]
    -- The unsigned 64-bit cases use the typedef name @uint64_t@;
    -- "unsigned int64_t" is not a valid C type, because a sign
    -- specifier cannot be combined with a typedef name.
    -- NOTE(review): assumes the kernel prelude defines uint64_t
    -- alongside int64_t - confirm.
    atomicOps s (AtomicUMax Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_umax" [C.cty|typename uint64_t|]
    atomicOps s (AtomicUMin Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_umin" [C.cty|typename uint64_t|]
    atomicOps s (AtomicAnd Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_and" [C.cty|typename int64_t|]
    atomicOps s (AtomicOr Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_or" [C.cty|typename int64_t|]
    atomicOps s (AtomicXor Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_xor" [C.cty|typename int64_t|]
    atomicOps s (AtomicCmpXchg (IntType Int64) old arr ind cmp val) =
      doAtomicCmpXchg s (IntType Int64) old arr ind cmp val [C.cty|typename int64_t|]
    atomicOps s (AtomicXchg (IntType Int64) old arr ind val) =
      doAtomicXchg s (IntType Int64) old arr ind val [C.cty|typename int64_t|]
    --
    -- Then every remaining type, using @int@, @unsigned int@ or
    -- @float@ as the C type.
    atomicOps s (AtomicAdd t old arr ind val) =
      doAtomic s t old arr ind val "atomic_add" [C.cty|int|]
    atomicOps s (AtomicFAdd Float32 old arr ind val) =
      doAtomic s Float32 old arr ind val "atomic_fadd" [C.cty|float|]
    atomicOps s (AtomicSMax t old arr ind val) =
      doAtomic s t old arr ind val "atomic_smax" [C.cty|int|]
    atomicOps s (AtomicSMin t old arr ind val) =
      doAtomic s t old arr ind val "atomic_smin" [C.cty|int|]
    atomicOps s (AtomicUMax t old arr ind val) =
      doAtomic s t old arr ind val "atomic_umax" [C.cty|unsigned int|]
    atomicOps s (AtomicUMin t old arr ind val) =
      doAtomic s t old arr ind val "atomic_umin" [C.cty|unsigned int|]
    atomicOps s (AtomicAnd t old arr ind val) =
      doAtomic s t old arr ind val "atomic_and" [C.cty|int|]
    atomicOps s (AtomicOr t old arr ind val) =
      doAtomic s t old arr ind val "atomic_or" [C.cty|int|]
    atomicOps s (AtomicXor t old arr ind val) =
      doAtomic s t old arr ind val "atomic_xor" [C.cty|int|]
    atomicOps s (AtomicCmpXchg t old arr ind cmp val) =
      doAtomicCmpXchg s t old arr ind cmp val [C.cty|int|]
    atomicOps s (AtomicXchg t old arr ind val) =
      doAtomicXchg s t old arr ind val [C.cty|int|]

    -- Allocation handler for kernel code: always fails with 'error'.
    cannotAllocate :: GC.Allocate KernelOp KernelState
    cannotAllocate _ = error "Cannot allocate memory in kernel"

    -- Deallocation handler for kernel code: always fails with 'error'.
    cannotDeallocate :: GC.Deallocate KernelOp KernelState
    cannotDeallocate _ _ = error "Cannot deallocate memory in kernel"

    -- Bulk-copy handler for kernel code: always fails with 'error',
    -- ignoring all seven arguments.
    copyInKernel :: GC.Copy KernelOp KernelState
    copyInKernel _ _ _ _ _ _ _ = error "Cannot bulk copy in kernel."

    -- Static-array handler for kernel code: always fails with 'error'.
    noStaticArrays :: GC.StaticArray KernelOp KernelState
    noStaticArrays _ _ _ _ = error "Cannot create static array in kernel."

    -- C type of a memory block in the named space:
    -- 'defaultMemBlockType' qualified with the space's 'pointerQuals'.
    kernelMemoryType space =
      qualified <$> pointerQuals space
      where
        qualified quals = [C.cty|$tyquals:quals $ty:defaultMemBlockType|]

    -- Scalar writes go through 'GC.writeScalarPointerWithQuals',
    -- parameterised by this module's 'pointerQuals'.
    kernelWriteScalar = GC.writeScalarPointerWithQuals pointerQuals

    -- Scalar reads go through 'GC.readScalarPointerWithQuals',
    -- parameterised by this module's 'pointerQuals'.
    kernelReadScalar = GC.readScalarPointerWithQuals pointerQuals

    -- Produce the C block items to execute once a failure has been
    -- detected, and mark an error as pending via 'pendingError'.
    --
    -- With communication in the body, the items set @local_failure@
    -- and jump to the next error label.  Otherwise they return
    -- directly: @return 1@ in 'FunMode', a bare @return@ in any other
    -- mode.
    whatNext = do
      label <- nextErrorLabel
      pendingError True
      return (bailOut label)
      where
        bailOut label
          | has_communication = [C.citems|local_failure = true; goto $id:label;|]
          | mode == FunMode = [C.citems|return 1;|]
          | otherwise = [C.citems|return;|]

    -- Compile a function call inside a kernel.
    --
    -- Built-in functions are handled by the default call compiler.
    -- Any other function is called with @global_failure@ and
    -- @global_failure_args@ prepended, followed by pointers to the
    -- destinations and then the ordinary arguments; a non-zero return
    -- value triggers the items produced by 'whatNext'.
    callInKernel dests fname args
      | isBuiltInFunction fname =
        GC.opsCall GC.defaultOperations dests fname args
      | otherwise = do
        what_next <- whatNext
        GC.item [C.citem|if ($id:(funName fname)($args:call_args) != 0) { $items:what_next; }|]
      where
        failure_args = [[C.cexp|global_failure|], [C.cexp|global_failure_args|]]
        out_ptrs = map (\d -> [C.cexp|&$id:d|]) dests
        call_args = failure_args ++ out_ptrs ++ args

    -- Compile an in-kernel error.
    --
    -- The message is appended to 'kernelFailures'; its index (the
    -- length of the list before appending) is the value swapped into
    -- @global_failure@.  Only if that compare-and-exchange observes
    -- the old value -1 are the integer message arguments stored into
    -- @global_failure_args@.  The items from 'whatNext' follow in
    -- either case.
    errorInKernel msg@(ErrorMsg parts) backtrace = do
      n <- length . kernelFailures <$> GC.getUserState
      GC.modifyUserState $ \st ->
        st {kernelFailures = kernelFailures st ++ [FailureMsg msg backtrace]}
      argstms <- setArgs (0 :: Int) parts

      what_next <- whatNext

      GC.stm
        [C.cstm|{ if (atomic_cmpxchg_i32_global(global_failure, -1, $int:n) == -1)
                  { $stms:argstms; }
                  $items:what_next
                }|]
      where
        -- Assign each integer message part to consecutive slots of
        -- global_failure_args.  String parts take no slot; 32-bit
        -- parts are cast to int64_t.
        setArgs _ [] = return []
        setArgs i (part : rest) = case part of
          ErrorString {} ->
            setArgs i rest
          ErrorInt32 x -> do
            x' <- GC.compileExp x
            later <- setArgs (i + 1) rest
            return $ [C.cstm|global_failure_args[$int:i] = (typename int64_t)$exp:x';|] : later
          ErrorInt64 x -> do
            x' <- GC.compileExp x
            later <- setArgs (i + 1) rest
            return $ [C.cstm|global_failure_args[$int:i] = $exp:x';|] : later

--- Checking requirements

-- | The primitive types found by 'typesInCode' in the body of a kernel.
typesInKernel :: Kernel -> S.Set PrimType
typesInKernel = typesInCode . kernelBody

-- | Collect the primitive types mentioned in a piece of kernel code:
-- the types of declared scalars and arrays, the element type of
-- writes, and whatever 'typesInExp' finds in the expressions that are
-- inspected.  Memory declarations, frees, memory assignments and
-- 'Op's contribute nothing.
typesInCode :: ImpKernels.KernelCode -> S.Set PrimType
typesInCode code = case code of
  Skip -> mempty
  c1 :>>: c2 -> typesInCode c1 <> typesInCode c2
  For _ e c -> typesInExp e <> typesInCode c
  While (TPrimExp e) c -> typesInExp e <> typesInCode c
  DeclareMem {} -> mempty
  DeclareScalar _ _ t -> S.singleton t
  DeclareArray _ _ t _ -> S.singleton t
  Allocate _ (Count (TPrimExp e)) _ -> typesInExp e
  Free {} -> mempty
  Copy _ (Count (TPrimExp e1)) _ _ (Count (TPrimExp e2)) _ (Count (TPrimExp e3)) ->
    mconcat [typesInExp e1, typesInExp e2, typesInExp e3]
  Write _ (Count (TPrimExp e1)) t _ _ e2 ->
    mconcat [typesInExp e1, S.singleton t, typesInExp e2]
  SetScalar _ e -> typesInExp e
  SetMem {} -> mempty
  Call _ _ es -> foldMap typesInArg es
  If (TPrimExp e) c1 c2 ->
    mconcat [typesInExp e, typesInCode c1, typesInCode c2]
  -- Only the asserted condition is inspected, not the error message.
  Assert e _ _ -> typesInExp e
  Comment _ c -> typesInCode c
  DebugPrint _ v -> maybe mempty typesInExp v
  Op {} -> mempty
  where
    -- Memory arguments carry no primitive type of their own.
    typesInArg MemArg {} = mempty
    typesInArg (ExpArg e) = typesInExp e

-- | Collect the set of primitive types that an expression mentions.
--
-- The result is the union of:
--
--   * the type of every constant ('ValueExp'), via 'primValueType';
--
--   * both the source and the destination type of every conversion
--     ('ConvOpExp'), via 'convOpType';
--
--   * the type attached to every function application ('FunExp');
--
--   * the type attached to every memory read ('Index'), together with
--     whatever types occur in its offset expression;
--
--   * the type named by every 'SizeOf' leaf.
--
-- Binary, comparison and unary operators contribute nothing themselves;
-- only their operands are inspected.  Scalar variable leaves contribute
-- nothing at all.
typesInExp :: Exp -> S.Set PrimType
typesInExp (ValueExp v) = S.singleton $ primValueType v
typesInExp (BinOpExp _ e1 e2) = typesInExp e1 <> typesInExp e2
typesInExp (CmpOpExp _ e1 e2) = typesInExp e1 <> typesInExp e2
typesInExp (ConvOpExp op e) = S.fromList [from, to] <> typesInExp e
  where
    -- A conversion mentions two types: the one it converts from and
    -- the one it converts to.
    (from, to) = convOpType op
typesInExp (UnOpExp _ e) = typesInExp e
typesInExp (FunExp _ args t) = S.singleton t <> foldMap typesInExp args
typesInExp (LeafExp (Index _ (Count (TPrimExp e)) t _ _) _) =
  -- The offset is itself an expression and may mention further types.
  S.singleton t <> typesInExp e
typesInExp (LeafExp ScalarVar {} _) = mempty
typesInExp (LeafExp (SizeOf t) _) = S.singleton t