{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen
(
compileProg,
OpCompiler,
ExpCompiler,
CopyCompiler,
StmsCompiler,
AllocCompiler,
Operations (..),
defaultOperations,
MemLocation (..),
MemEntry (..),
ScalarEntry (..),
ImpM,
localDefaultSpace,
askFunction,
newVNameForFun,
nameForFun,
askEnv,
localEnv,
localOps,
VTable,
getVTable,
localVTable,
subImpM,
subImpM_,
emit,
emitFunction,
hasFunction,
collect,
collect',
comment,
VarEntry (..),
ArrayEntry (..),
lookupVar,
lookupArray,
lookupMemory,
lookupAcc,
TV,
mkTV,
tvSize,
tvExp,
tvVar,
ToExp (..),
compileAlloc,
everythingVolatile,
compileBody,
compileBody',
compileLoopBody,
defCompileStms,
compileStms,
compileExp,
defCompileExp,
fullyIndexArray,
fullyIndexArray',
copy,
copyDWIM,
copyDWIMFix,
copyElementWise,
typeSize,
inBounds,
isMapTransposeCopy,
dLParams,
dFParams,
dScope,
dArray,
dPrim,
dPrimVol,
dPrim_,
dPrimV_,
dPrimV,
dPrimVE,
sFor,
sWhile,
sComment,
sIf,
sWhen,
sUnless,
sOp,
sDeclareMem,
sAlloc,
sAlloc_,
sArray,
sArrayInMem,
sAllocArray,
sAllocArrayPerm,
sStaticArray,
sWrite,
sUpdate,
sLoopNest,
(<--),
(<~~),
function,
warn,
module Language.Futhark.Warnings,
)
where
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Control.Parallel.Strategies
import Data.Bifunctor (first)
import qualified Data.DList as DL
import Data.Either
import Data.List (find, genericLength, sortOn)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Data.String
import Futhark.CodeGen.ImpCode
( Bytes,
Count,
Elements,
bytes,
elements,
withElemType,
)
import qualified Futhark.CodeGen.ImpCode as Imp
import Futhark.CodeGen.ImpGen.Transpose
import Futhark.Construct hiding (ToExp (..))
import Futhark.IR.Mem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.IR.SOACS (SOACS)
import Futhark.Util
import Futhark.Util.Loc (noLoc)
import Language.Futhark.Warnings
-- | How to compile an 'Op': given the pattern that binds its results
-- and the operation itself, emit imperative code.
type OpCompiler lore r op = Pattern lore -> Op lore -> ImpM lore r op ()

-- | How to compile some 'Stms'.  The 'Names' argument and the trailing
-- action are interpreted by the implementation; 'defCompileStms' is the
-- default one (see 'defaultOperations').
type StmsCompiler lore r op =
  Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()

-- | How to compile an 'Exp' whose results are bound by the given
-- pattern.  'defCompileExp' is the default implementation.
type ExpCompiler lore r op = Pattern lore -> Exp lore -> ImpM lore r op ()

-- | How to compile a copy of elements of the given 'PrimType' between
-- two sliced memory locations.
-- NOTE(review): presumably the first location/slice pair is the
-- destination and the second the source -- confirm against 'defaultCopy'.
type CopyCompiler lore r op =
  PrimType ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  ImpM lore r op ()

-- | An alternative way of compiling an allocation of the given number of
-- bytes for the named variable (presumably the memory block being
-- allocated).  Installed per memory space via 'opsAllocCompilers'.
type AllocCompiler lore r op =
  VName -> Count Bytes (Imp.TExp Int64) -> ImpM lore r op ()
-- | The set of compilers 'ImpM' uses for each kind of construct.  It is
-- installed into the environment by 'runImpM' (via 'newEnv') and by
-- 'subImpM'; 'defaultOperations' builds one from just an 'OpCompiler'.
data Operations lore r op = Operations
  { opsExpCompiler :: ExpCompiler lore r op,
    opsOpCompiler :: OpCompiler lore r op,
    opsStmsCompiler :: StmsCompiler lore r op,
    opsCopyCompiler :: CopyCompiler lore r op,
    -- | Allocation compilers, keyed by the memory space they handle.
    opsAllocCompilers :: M.Map Space (AllocCompiler lore r op)
  }
-- | An 'Operations' that uses the given 'OpCompiler' for 'Op's and the
-- default compilers ('defCompileExp', 'defCompileStms', 'defaultCopy')
-- for everything else.  No space-specific allocation compilers are
-- installed.
defaultOperations ::
  (Mem lore, FreeIn op) =>
  OpCompiler lore r op ->
  Operations lore r op
defaultOperations opc =
  Operations
    { opsExpCompiler = defCompileExp,
      opsOpCompiler = opc,
      opsStmsCompiler = defCompileStms,
      opsCopyCompiler = defaultCopy,
      opsAllocCompilers = mempty
    }
-- | Where an array is stored.
data MemLocation = MemLocation
  { -- | Name of the memory block holding the array (presumably a
    -- variable bound as a 'MemVar').
    memLocationName :: VName,
    -- | The array's shape.
    memLocationShape :: [Imp.DimSize],
    -- | The index function describing the array's layout in memory.
    memLocationIxFun :: IxFun.IxFun (Imp.TExp Int64)
  }
  deriving (Eq, Show)
-- | Symbol-table information about an array variable: where it is
-- stored and the primitive type of its elements.
data ArrayEntry = ArrayEntry
  { entryArrayLocation :: MemLocation,
    entryArrayElemType :: PrimType
  }
  deriving (Show)
-- | The shape of an array entry, as recorded in its memory location.
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape = memLocationShape . entryArrayLocation
-- | Symbol-table information about a memory block: the memory space it
-- resides in.
newtype MemEntry = MemEntry {entryMemSpace :: Imp.Space}
  deriving (Show)
-- | Symbol-table information about a scalar variable: its primitive type.
newtype ScalarEntry = ScalarEntry
  { entryScalarType :: PrimType
  }
  deriving (Show)
-- | An entry in the symbol table ('VTable').  Every constructor carries
-- an optional expression, presumably the expression that bound the
-- variable when it is known (TODO confirm against the code that builds
-- entries).
data VarEntry lore
  = ArrayVar (Maybe (Exp lore)) ArrayEntry
  | ScalarVar (Maybe (Exp lore)) ScalarEntry
  | MemVar (Maybe (Exp lore)) MemEntry
  | -- | An accumulator.  The triple holds the arguments of its 'Acc'
    -- type: accumulator name, shape, and the list of types.
    AccVar (Maybe (Exp lore)) (VName, Shape, [Type])
  deriving (Show)
-- | Where the results of compiling an expression should end up.
data Destination = Destination
  { -- | An optional tag.  NOTE(review): presumably used to relate
    -- generated code back to its source construct -- confirm.
    destinationTag :: Maybe Int,
    -- | Presumably one destination per value produced.
    valueDestinations :: [ValueDestination]
  }
  deriving (Show)
-- | Where a single value should be stored.
data ValueDestination
  = -- | A scalar result, bound to the named variable.
    ScalarDestination VName
  | -- | A memory result, bound to the named variable.
    MemoryDestination VName
  | -- | An array result.  When a 'MemLocation' is given, the array is
    -- written there.  NOTE(review): presumably 'Nothing' means no copy
    -- is needed because the backing memory is assigned elsewhere --
    -- confirm.
    ArrayDestination (Maybe MemLocation)
  deriving (Show)
-- | The read-only environment of 'ImpM'.
data Env lore r op = Env
  { envExpCompiler :: ExpCompiler lore r op,
    envStmsCompiler :: StmsCompiler lore r op,
    envOpCompiler :: OpCompiler lore r op,
    envCopyCompiler :: CopyCompiler lore r op,
    -- | Allocation compilers, keyed by memory space.
    envAllocCompilers :: M.Map Space (AllocCompiler lore r op),
    -- | The default memory space, as passed to 'runImpM'.
    envDefaultSpace :: Imp.Space,
    envVolatility :: Imp.Volatility,
    -- | User-extensible environment, supplied to 'runImpM' or 'subImpM'.
    envEnv :: r,
    -- | Name of the function being compiled, if any.
    envFunction :: Maybe Name,
    -- | Attributes in effect (initially empty).
    envAttrs :: Attrs
  }
-- | The initial environment used by 'runImpM'.  It contains all the
-- compilers from the given 'Operations', the given user environment and
-- default space, non-volatile accesses, no enclosing function and no
-- attributes.
newEnv :: r -> Operations lore r op -> Imp.Space -> Env lore r op
newEnv r ops ds =
  Env
    { envExpCompiler = opsExpCompiler ops,
      envStmsCompiler = opsStmsCompiler ops,
      envOpCompiler = opsOpCompiler ops,
      envCopyCompiler = opsCopyCompiler ops,
      -- Honour the caller's allocation compilers, exactly as 'subImpM'
      -- does, instead of silently discarding them.
      envAllocCompilers = opsAllocCompilers ops,
      envDefaultSpace = ds,
      envVolatility = Imp.Nonvolatile,
      envEnv = r,
      envFunction = Nothing,
      envAttrs = mempty
    }
-- | The symbol table: what is known about each variable in scope.
type VTable lore = M.Map VName (VarEntry lore)

-- | The mutable state threaded through 'ImpM'.
data ImpState lore r op = ImpState
  { stateVTable :: VTable lore,
    -- | Functions emitted so far with 'emitFunction'.
    stateFunctions :: Imp.Functions op,
    -- | Code emitted so far with 'emit' in the current 'collect' scope.
    stateCode :: Imp.Code op,
    -- | Warnings recorded so far with 'warnings' / 'warn'.
    stateWarnings :: Warnings,
    -- | Per-accumulator information, presumably keyed by accumulator
    -- name (TODO confirm against 'lookupAcc').
    stateAccs :: M.Map VName ([VName], Maybe (Lambda lore, [SubExp])),
    -- | Source of fresh names (see the 'MonadFreshNames' instance).
    stateNameSource :: VNameSource
  }
-- | An empty state (no variables, functions, code, warnings or
-- accumulators) that draws fresh names from the given source.
newState :: VNameSource -> ImpState lore r op
newState src =
  -- Record syntax rather than a positional constructor, so that adding
  -- or reordering 'ImpState' fields cannot silently misassign them.
  ImpState
    { stateVTable = mempty,
      stateFunctions = mempty,
      stateCode = mempty,
      stateWarnings = mempty,
      stateAccs = mempty,
      stateNameSource = src
    }
-- | The code generation monad: a reader over 'Env' layered on a state
-- monad over 'ImpState'.  @r@ is the user-extensible environment
-- ('envEnv') and @op@ is the type of operations in the generated
-- 'Imp.Code'.
newtype ImpM lore r op a
  = ImpM (ReaderT (Env lore r op) (State (ImpState lore r op)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState (ImpState lore r op),
      MonadReader (Env lore r op)
    )
-- | Fresh names are read from and written back to 'stateNameSource'.
instance MonadFreshNames (ImpM lore r op) where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s {stateNameSource = src}
-- | View the symbol table as a 'SOACS' scope.  Every entry becomes a
-- 'LetName' of the corresponding type; the optional binding expressions
-- stored in the entries are not used.
instance HasScope SOACS (ImpM lore r op) where
  askScope = gets $ M.map (LetName . entryType) . stateVTable
    where
      entryType (MemVar _ memEntry) =
        Mem (entryMemSpace memEntry)
      entryType (ArrayVar _ arrayEntry) =
        Array
          (entryArrayElemType arrayEntry)
          (Shape $ entryArrayShape arrayEntry)
          NoUniqueness
      entryType (ScalarVar _ scalarEntry) =
        Prim $ entryScalarType scalarEntry
      entryType (AccVar _ (acc, ispace, ts)) =
        Acc acc ispace ts NoUniqueness
-- | Run an 'ImpM' computation from the given initial state, in a fresh
-- environment built by 'newEnv' from the user environment, operations
-- and default memory space.  Returns the result and the final state.
runImpM ::
  ImpM lore r op a ->
  r ->
  Operations lore r op ->
  Imp.Space ->
  ImpState lore r op ->
  (a, ImpState lore r op)
runImpM (ImpM m) r ops space = runState (runReaderT m $ newEnv r ops space)
-- | Like 'subImpM', but discards the computation's result and returns
-- only the code it emitted.
subImpM_ ::
  r' ->
  Operations lore r' op' ->
  ImpM lore r' op' a ->
  ImpM lore r op (Imp.Code op')
subImpM_ r ops m = snd <$> subImpM r ops m
-- | Run an 'ImpM' computation with a different user environment and
-- operations (and therefore possibly a different op type).  Returns its
-- result together with the code it emitted.
--
-- The sub-computation starts from the current symbol table, accumulator
-- table and name source, but with no code, functions or warnings.
-- Afterwards only the name source and the warnings are propagated back.
-- Changes to the symbol table and accumulators, and any functions
-- emitted with 'emitFunction', are dropped.
-- NOTE(review): the emitted functions cannot be kept because they have
-- type @Imp.Functions op'@ rather than @Imp.Functions op@; confirm that
-- no caller emits functions inside a sub-computation.
subImpM ::
  r' ->
  Operations lore r' op' ->
  ImpM lore r' op' a ->
  ImpM lore r op (a, Imp.Code op')
subImpM r ops (ImpM m) = do
  env <- ask
  s <- get

  -- A type-changing record update: every field mentioning @r@ or @op@
  -- has to be replaced.
  let env' =
        env
          { envExpCompiler = opsExpCompiler ops,
            envStmsCompiler = opsStmsCompiler ops,
            envCopyCompiler = opsCopyCompiler ops,
            envOpCompiler = opsOpCompiler ops,
            envAllocCompilers = opsAllocCompilers ops,
            envEnv = r
          }
      s' =
        ImpState
          { stateVTable = stateVTable s,
            stateFunctions = mempty,
            stateCode = mempty,
            stateNameSource = stateNameSource s,
            stateWarnings = mempty,
            stateAccs = stateAccs s
          }
      (x, s'') = runState (runReaderT m env') s'

  -- Keep names unique across the boundary and preserve diagnostics.
  putNameSource $ stateNameSource s''
  warnings $ stateWarnings s''
  return (x, stateCode s'')
-- | Run an action and return the code it emits, instead of appending
-- that code to the current scope.  See 'collect''.
collect :: ImpM lore r op () -> ImpM lore r op (Imp.Code op)
collect = fmap snd . collect'
-- | Run an action in a fresh code scope and return its result together
-- with the code it emitted.  That code is not appended to the
-- surrounding scope; the code emitted before the call is restored
-- afterwards.  Only 'stateCode' is scoped this way.  Every other state
-- change made by the action (symbol table, warnings, functions, names)
-- persists.
collect' :: ImpM lore r op a -> ImpM lore r op (a, Imp.Code op)
collect' m = do
  prev_code <- gets stateCode
  modify $ \s -> s {stateCode = mempty}
  x <- m
  new_code <- gets stateCode
  modify $ \s -> s {stateCode = prev_code}
  return (x, new_code)
-- | Run an action and emit all the code it produces wrapped in an
-- 'Imp.Comment' with the given description.
comment :: String -> ImpM lore r op () -> ImpM lore r op ()
comment desc m = do
  code <- collect m
  emit $ Imp.Comment desc code
-- | Append code after the code emitted so far in the current scope.
emit :: Imp.Code op -> ImpM lore r op ()
emit code = modify $ \s -> s {stateCode = stateCode s <> code}
-- | Record warnings.  The new warnings are placed in front of those
-- already recorded.
warnings :: Warnings -> ImpM lore r op ()
warnings ws = modify $ \s -> s {stateWarnings = ws <> stateWarnings s}
-- | Record a single warning at @loc@, with the further locations @locs@
-- attached, described by the message @problem@.
warn :: Located loc => loc -> [loc] -> String -> ImpM lore r op ()
warn loc locs problem =
  warnings $
    singleWarning' (srclocOf loc) (map srclocOf locs) (fromString problem)
-- | Record a function definition under the given name.  The new
-- definition is prepended to the function list.  No duplicate check is
-- performed, so callers may want to consult 'hasFunction' first.
emitFunction :: Name -> Imp.Function op -> ImpM lore r op ()
emitFunction fname fun =
  -- One state update instead of a separate read followed by a write.
  modify $ \s ->
    let Imp.Functions fs = stateFunctions s
     in s {stateFunctions = Imp.Functions $ (fname, fun) : fs}
-- | Whether a function with the given name has already been recorded
-- with 'emitFunction'.
hasFunction :: Name -> ImpM lore r op Bool
hasFunction fname = gets $ \s ->
  let Imp.Functions fs = stateFunctions s
   in isJust $ lookup fname fs
-- | Build a symbol table for the top-level constants.  Every name bound
-- by the pattern of a statement is mapped to the 'VarEntry' that
-- 'memBoundToVarEntry' derives from the pattern element's annotation and
-- the statement's defining expression.
constsVTable :: Mem lore => Stms lore -> VTable lore
constsVTable = foldMap entriesOf
  where
    entriesOf (Let pat _ e) =
      mconcat
        [ M.singleton name $ memBoundToVarEntry (Just e) dec
          | PatElem name dec <- patternElements pat
        ]
-- | Compile a whole program to imperative code.
--
-- Each function is compiled independently (via @parMap rpar@). Every
-- compilation starts from the same name source and from a symbol table
-- that already contains the top-level constants.  The per-function states
-- are then merged.  The constants are compiled last, on top of that merged
-- state, and the set of names free in the compiled functions is passed to
-- 'compileConsts'.
compileProg ::
  (Mem lore, FreeIn op, MonadFreshNames m) =>
  r ->
  Operations lore r op ->
  Imp.Space ->
  Prog lore ->
  m (Warnings, Imp.Definitions op)
compileProg r ops space (Prog consts funs) =
  modifyNameSource $ \src ->
    -- NOTE(review): rpar only evaluates each ((), state) pair to WHNF; how
    -- much of the per-function work that actually forces depends on the
    -- strictness of runImpM's underlying state monad - worth confirming.
    let fun_states = map snd $ parMap rpar (compileFunDef' src) funs
        used_in_funs = freeIn $ foldMap stateFunctions fun_states
        (consts', final_state) =
          runImpM (compileConsts used_in_funs consts) r ops space $
            combineStates fun_states
        definitions = Imp.Definitions consts' $ stateFunctions final_state
     in ((stateWarnings final_state, definitions), stateNameSource final_state)
  where
    -- Compile a single function in a fresh state whose symbol table
    -- already knows about the constants.
    compileFunDef' src fdef =
      runImpM (compileFunDef fdef) r ops space $
        (newState src) {stateVTable = constsVTable consts}

    -- Merge the states of the separately compiled functions.  Functions
    -- are keyed by name through a Map.  For a name that occurs more than
    -- once in the concatenated list, only the last occurrence is kept, and
    -- the result is ordered by name.  Name sources and warnings are
    -- combined with their Monoid instances.
    combineStates states =
      let Imp.Functions all_funs = foldMap stateFunctions states
       in (newState $ foldMap stateNameSource states)
            { stateFunctions = Imp.Functions $ M.toList $ M.fromList all_funs,
              stateWarnings = foldMap stateWarnings states
            }
-- | Compile the top-level constant statements.
--
-- The generated code is post-processed.  Memory and scalar declarations
-- whose names are in @used_consts@ are removed from the code and become
-- the parameter list of the resulting 'Imp.Constants'.  All remaining code
-- forms its body.
compileConsts :: Names -> Stms lore -> ImpM lore r op (Imp.Constants op)
compileConsts used_consts stms = do
  code <- collect $ compileStms used_consts stms $ pure ()
  let (params, body) = hoistDecls code
  pure $ Imp.Constants (DL.toList params) body
  where
    -- Split code into (hoisted declarations as parameters, leftover code).
    -- A declaration whose name is not used falls through to the last case
    -- and stays in the code.
    hoistDecls = \case
      x Imp.:>>: y ->
        hoistDecls x <> hoistDecls y
      Imp.DeclareMem name space
        | name `nameIn` used_consts ->
          (DL.singleton $ Imp.MemParam name space, mempty)
      Imp.DeclareScalar name _ t
        | name `nameIn` used_consts ->
          (DL.singleton $ Imp.ScalarParam name t, mempty)
      other ->
        (mempty, other)
-- | Translate a single function parameter.
--
-- Scalar and memory-block parameters become 'Imp.Param's ('Left').
-- Array parameters become an 'ArrayDecl' ('Right'), which records their
-- element type and memory location; the index function is converted to
-- refer to scalar variables.  Accumulator parameters are rejected with
-- 'error'.
compileInParam ::
  Mem lore =>
  FParam lore ->
  ImpM lore r op (Either Imp.Param ArrayDecl)
compileInParam fparam =
  case paramDec fparam of
    MemPrim bt ->
      pure . Left $ Imp.ScalarParam pname bt
    MemMem space ->
      pure . Left $ Imp.MemParam pname space
    MemArray bt shape _ (ArrayIn mem ixfun) ->
      let location =
            MemLocation mem (shapeDims shape) $ fmap (fmap Imp.ScalarVar) ixfun
       in pure . Right $ ArrayDecl pname bt location
    MemAcc {} ->
      error "Functions may not have accumulator parameters."
  where
    pname = paramName fparam
-- | An array-typed function parameter, as produced by 'compileInParam'.
-- It holds the parameter's name, its element type, and its
-- 'MemLocation' (the memory it lives in, its shape and index function).
data ArrayDecl
  = ArrayDecl
      VName
      PrimType
      MemLocation
-- | Compile the parameters of a function.
--
-- Returns three things:
--
-- * the scalar and memory parameters;
-- * the array parameters, as 'ArrayDecl's;
-- * a description of the external (entry point) values.
--
-- The last @sum (map entryPointSize orig_epts)@ parameters are the value
-- parameters described by @orig_epts@.  Any parameters before them are
-- context: they are still compiled, but they contribute no external value
-- of their own.
compileInParams ::
  Mem lore =>
  [FParam lore] ->
  [EntryPointType] ->
  ImpM lore r op ([Imp.Param], [ArrayDecl], [Imp.ExternalValue])
compileInParams params orig_epts = do
  let num_ctx = length params - sum (map entryPointSize orig_epts)
      val_params = drop num_ctx params
  (inparams, arrayds) <- partitionEithers <$> mapM compileInParam params
  let findArray x = find (isArrayDecl x) arrayds

      findMemInfo :: VName -> Maybe Space
      findMemInfo = (`M.lookup` summaries)

      -- Describe one value parameter.  This yields Nothing for an array
      -- whose memory is not a memory parameter, or for a non-array,
      -- non-primitive parameter.
      mkValueDesc fparam signedness =
        case (findArray $ paramName fparam, paramType fparam) of
          (Just (ArrayDecl _ bt (MemLocation mem shape _)), _) -> do
            memspace <- findMemInfo mem
            Just $ Imp.ArrayValue mem memspace bt signedness shape
          (_, Prim bt) ->
            Just $ Imp.ScalarValue bt signedness $ paramName fparam
          _ ->
            Nothing

      transparent signedness fparam =
        maybeToList $ Imp.TransparentValue <$> mkValueDesc fparam signedness

      -- An opaque entry consumes n parameters.  An unsigned or direct
      -- entry consumes exactly one parameter.
      mkExts (TypeOpaque desc n : epts) fparams =
        let (opaque_params, rest) = splitAt n fparams
         in Imp.OpaqueValue
              desc
              (mapMaybe (`mkValueDesc` Imp.TypeDirect) opaque_params)
              : mkExts epts rest
      mkExts (TypeUnsigned : epts) (fparam : fparams) =
        transparent Imp.TypeUnsigned fparam ++ mkExts epts fparams
      mkExts (TypeDirect : epts) (fparam : fparams) =
        transparent Imp.TypeDirect fparam ++ mkExts epts fparams
      mkExts _ _ = []

  pure (inparams, arrayds, mkExts orig_epts val_params)
  where
    isArrayDecl x (ArrayDecl y _ _) = x == y

    -- Memory space of every memory-block parameter, keyed by its name.
    summaries = M.fromList $ mapMaybe memSummary params

    memSummary param
      | MemMem space <- paramDec param = Just (paramName param, space)
      | otherwise = Nothing
compileOutParams ::
Mem lore =>
[RetType lore] ->
[EntryPointType] ->
ImpM lore r op ([Imp.ExternalValue], [Imp.Param], Destination)
compileOutParams :: forall lore r op.
Mem lore =>
[RetType lore]
-> [EntryPointType]
-> ImpM lore r op ([ExternalValue], [Param], Destination)
compileOutParams [RetType lore]
orig_rts [EntryPointType]
orig_epts = do
(([ExternalValue]
extvs, [ValueDestination]
dests), ([Param]
outparams, Map Int ValueDestination
ctx_dests)) <-
WriterT
([Param], Map Int ValueDestination)
(ImpM lore r op)
([ExternalValue], [ValueDestination])
-> ImpM
lore
r
op
(([ExternalValue], [ValueDestination]),
([Param], Map Int ValueDestination))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
([Param], Map Int ValueDestination)
(ImpM lore r op)
([ExternalValue], [ValueDestination])
-> ImpM
lore
r
op
(([ExternalValue], [ValueDestination]),
([Param], Map Int ValueDestination)))
-> WriterT
([Param], Map Int ValueDestination)
(ImpM lore r op)
([ExternalValue], [ValueDestination])
-> ImpM
lore
r
op
(([ExternalValue], [ValueDestination]),
([Param], Map Int ValueDestination))
forall a b. (a -> b) -> a -> b
$ StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
-> (Map Any Any, Map Int VName)
-> WriterT
([Param], Map Int ValueDestination)
(ImpM lore r op)
([ExternalValue], [ValueDestination])
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
evalStateT ([EntryPointType]
-> [RetTypeMem]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
orig_epts [RetType lore]
[RetTypeMem]
orig_rts) (Map Any Any
forall k a. Map k a
M.empty, Map Int VName
forall k a. Map k a
M.empty)
let ctx_dests' :: [ValueDestination]
ctx_dests' = ((Int, ValueDestination) -> ValueDestination)
-> [(Int, ValueDestination)] -> [ValueDestination]
forall a b. (a -> b) -> [a] -> [b]
map (Int, ValueDestination) -> ValueDestination
forall a b. (a, b) -> b
snd ([(Int, ValueDestination)] -> [ValueDestination])
-> [(Int, ValueDestination)] -> [ValueDestination]
forall a b. (a -> b) -> a -> b
$ ((Int, ValueDestination) -> Int)
-> [(Int, ValueDestination)] -> [(Int, ValueDestination)]
forall b a. Ord b => (a -> b) -> [a] -> [a]
sortOn (Int, ValueDestination) -> Int
forall a b. (a, b) -> a
fst ([(Int, ValueDestination)] -> [(Int, ValueDestination)])
-> [(Int, ValueDestination)] -> [(Int, ValueDestination)]
forall a b. (a -> b) -> a -> b
$ Map Int ValueDestination -> [(Int, ValueDestination)]
forall k a. Map k a -> [(k, a)]
M.toList Map Int ValueDestination
ctx_dests
([ExternalValue], [Param], Destination)
-> ImpM lore r op ([ExternalValue], [Param], Destination)
forall (m :: * -> *) a. Monad m => a -> m a
return ([ExternalValue]
extvs, [Param]
outparams, Maybe Int -> [ValueDestination] -> Destination
Destination Maybe Int
forall a. Maybe a
Nothing ([ValueDestination] -> Destination)
-> [ValueDestination] -> Destination
forall a b. (a -> b) -> a -> b
$ [ValueDestination]
ctx_dests' [ValueDestination] -> [ValueDestination] -> [ValueDestination]
forall a. Semigroup a => a -> a -> a
<> [ValueDestination]
dests)
where
imp :: ImpM lore r op a
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
a
imp = WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
a)
-> (ImpM lore r op a
-> WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a)
-> ImpM lore r op a
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ImpM lore r op a
-> WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift
mkExts :: [EntryPointType]
-> [RetTypeMem]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
mkExts (TypeOpaque [Char]
desc Int
n : [EntryPointType]
epts) [RetTypeMem]
rts = do
let ([RetTypeMem]
rts', [RetTypeMem]
rest) = Int -> [RetTypeMem] -> ([RetTypeMem], [RetTypeMem])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
n [RetTypeMem]
rts
([ValueDesc]
evs, [ValueDestination]
dests) <- [(ValueDesc, ValueDestination)]
-> ([ValueDesc], [ValueDestination])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(ValueDesc, ValueDestination)]
-> ([ValueDesc], [ValueDestination]))
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
[(ValueDesc, ValueDestination)]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ValueDesc], [ValueDestination])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (RetTypeMem
-> Signedness
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination))
-> [RetTypeMem]
-> [Signedness]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
[(ValueDesc, ValueDestination)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM RetTypeMem
-> Signedness
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination)
mkParam [RetTypeMem]
rts' (Signedness -> [Signedness]
forall a. a -> [a]
repeat Signedness
Imp.TypeDirect)
([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [RetTypeMem]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [RetTypeMem]
rest
([ExternalValue], [ValueDestination])
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return
( [Char] -> [ValueDesc] -> ExternalValue
Imp.OpaqueValue [Char]
desc [ValueDesc]
evs ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
[ValueDestination]
dests [ValueDestination] -> [ValueDestination] -> [ValueDestination]
forall a. [a] -> [a] -> [a]
++ [ValueDestination]
more_dests
)
mkExts (EntryPointType
TypeUnsigned : [EntryPointType]
epts) (RetTypeMem
rt : [RetTypeMem]
rts) = do
(ValueDesc
ev, ValueDestination
dest) <- RetTypeMem
-> Signedness
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination)
mkParam RetTypeMem
rt Signedness
Imp.TypeUnsigned
([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [RetTypeMem]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [RetTypeMem]
rts
([ExternalValue], [ValueDestination])
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return
( ValueDesc -> ExternalValue
Imp.TransparentValue ValueDesc
ev ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
ValueDestination
dest ValueDestination -> [ValueDestination] -> [ValueDestination]
forall a. a -> [a] -> [a]
: [ValueDestination]
more_dests
)
mkExts (EntryPointType
TypeDirect : [EntryPointType]
epts) (RetTypeMem
rt : [RetTypeMem]
rts) = do
(ValueDesc
ev, ValueDestination
dest) <- RetTypeMem
-> Signedness
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination)
mkParam RetTypeMem
rt Signedness
Imp.TypeDirect
([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [RetTypeMem]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [RetTypeMem]
rts
([ExternalValue], [ValueDestination])
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return
( ValueDesc -> ExternalValue
Imp.TransparentValue ValueDesc
ev ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
ValueDestination
dest ValueDestination -> [ValueDestination] -> [ValueDestination]
forall a. a -> [a] -> [a]
: [ValueDestination]
more_dests
)
mkExts [EntryPointType]
_ [RetTypeMem]
_ = ([ExternalValue], [ValueDestination])
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return ([], [])
mkParam :: RetTypeMem
-> Signedness
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination)
mkParam MemMem {} Signedness
_ =
[Char]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination)
forall a. HasCallStack => [Char] -> a
error [Char]
"Functions may not explicitly return memory blocks."
mkParam MemAcc {} Signedness
_ =
[Char]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination)
forall a. HasCallStack => [Char] -> a
error [Char]
"Functions may not return accumulators."
mkParam (MemPrim PrimType
t) Signedness
ept = do
VName
out <- ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall {a}.
ImpM lore r op a
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
a
imp (ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName)
-> ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall a b. (a -> b) -> a -> b
$ [Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"scalar_out"
([Param], Map Int ValueDestination)
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([VName -> PrimType -> Param
Imp.ScalarParam VName
out PrimType
t], Map Int ValueDestination
forall a. Monoid a => a
mempty)
(ValueDesc, ValueDestination)
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination)
forall (m :: * -> *) a. Monad m => a -> m a
return (PrimType -> Signedness -> VName -> ValueDesc
Imp.ScalarValue PrimType
t Signedness
ept VName
out, VName -> ValueDestination
ScalarDestination VName
out)
mkParam (MemArray PrimType
t ShapeBase (Ext SubExp)
shape Uniqueness
_ MemReturn
dec) Signedness
ept = do
Space
space <- (Env lore r op -> Space)
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
Space
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Space
forall lore r op. Env lore r op -> Space
envDefaultSpace
VName
memout <- case MemReturn
dec of
ReturnsNewBlock Space
_ Int
x ExtIxFun
_ixfun -> do
VName
memout <- ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall {a}.
ImpM lore r op a
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
a
imp (ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName)
-> ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall a b. (a -> b) -> a -> b
$ [Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"out_mem"
([Param], Map Int ValueDestination)
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell
( [VName -> Space -> Param
Imp.MemParam VName
memout Space
space],
Int -> ValueDestination -> Map Int ValueDestination
forall k a. k -> a -> Map k a
M.singleton Int
x (ValueDestination -> Map Int ValueDestination)
-> ValueDestination -> Map Int ValueDestination
forall a b. (a -> b) -> a -> b
$ VName -> ValueDestination
MemoryDestination VName
memout
)
VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
memout
ReturnsInBlock VName
memout ExtIxFun
_ ->
VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
memout
[SubExp]
resultshape <- (Ext SubExp
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
SubExp)
-> [Ext SubExp]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
[SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Ext SubExp
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
SubExp
inspectExtSize ([Ext SubExp]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
[SubExp])
-> [Ext SubExp]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
[SubExp]
forall a b. (a -> b) -> a -> b
$ ShapeBase (Ext SubExp) -> [Ext SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase (Ext SubExp)
shape
(ValueDesc, ValueDestination)
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination)
forall (m :: * -> *) a. Monad m => a -> m a
return
( VName -> Space -> PrimType -> Signedness -> [SubExp] -> ValueDesc
Imp.ArrayValue VName
memout Space
space PrimType
t Signedness
ept [SubExp]
resultshape,
Maybe MemLocation -> ValueDestination
ArrayDestination Maybe MemLocation
forall a. Maybe a
Nothing
)
inspectExtSize :: Ext SubExp
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
SubExp
inspectExtSize (Ext Int
x) = do
(Map Any Any
memseen, Map Int VName
arrseen) <- StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(Map Any Any, Map Int VName)
forall s (m :: * -> *). MonadState s m => m s
get
case Int -> Map Int VName -> Maybe VName
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup Int
x Map Int VName
arrseen of
Maybe VName
Nothing -> do
VName
out <- ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall {a}.
ImpM lore r op a
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
a
imp (ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName)
-> ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall a b. (a -> b) -> a -> b
$ [Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"out_arrsize"
([Param], Map Int ValueDestination)
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell
( [VName -> PrimType -> Param
Imp.ScalarParam VName
out PrimType
int64],
Int -> ValueDestination -> Map Int ValueDestination
forall k a. k -> a -> Map k a
M.singleton Int
x (ValueDestination -> Map Int ValueDestination)
-> ValueDestination -> Map Int ValueDestination
forall a b. (a -> b) -> a -> b
$ VName -> ValueDestination
ScalarDestination VName
out
)
(Map Any Any, Map Int VName)
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
()
forall s (m :: * -> *). MonadState s m => s -> m ()
put (Map Any Any
memseen, Int -> VName -> Map Int VName -> Map Int VName
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert Int
x VName
out Map Int VName
arrseen)
SubExp
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
SubExp)
-> SubExp
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
out
Just VName
out ->
SubExp
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
SubExp)
-> SubExp
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
out
inspectExtSize (Free SubExp
se) =
SubExp
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return SubExp
se
-- | Compile a single function definition into an Imp function and
-- register it under its own name via 'emitFunction'.
--
-- The resulting Imp function is flagged as an entry point exactly when
-- the 'FunDef' carries entry point information.  While the body is
-- being compiled, 'envFunction' in the environment is set to the name
-- of the function.
compileFunDef ::
  Mem lore =>
  FunDef lore ->
  ImpM lore r op ()
compileFunDef (FunDef entry _ fname rettype params body) =
  local inThisFunction $ do
    ((outparams, inparams, results, args), body') <- collect' compile
    emitFunction fname $
      Imp.Function (isJust entry) outparams inparams body' results args
  where
    inThisFunction env = env {envFunction = Just fname}

    -- Without entry point information, every parameter and every
    -- return value is treated as 'TypeDirect'.
    params_entry = maybe (map (const TypeDirect) params) fst entry
    ret_entry = maybe (map (const TypeDirect) rettype) snd entry

    -- Sets up the parameters, compiles the body statements, and copies
    -- each body result into its corresponding output destination.
    -- Yields (output params, input params, result values, argument
    -- values) in the order 'Imp.Function' expects them.
    compile = do
      (inparams, arrayds, args) <- compileInParams params params_entry
      (results, outparams, Destination _ dests) <-
        compileOutParams rettype ret_entry
      addFParams params
      addArrays arrayds
      let Body _ stms ses = body
          copyResult (d, se) = copyDWIMDest d [] se []
      compileStms (freeIn ses) stms $ mapM_ copyResult (zip dests ses)
      pure (outparams, inparams, results, args)
-- | Compile a body, writing each of its results into the destination
-- obtained from the corresponding element of the given pattern.
--
-- Results and destinations are paired positionally; any surplus on
-- either side is ignored (as with 'zip').
compileBody :: (Mem lore) => Pattern lore -> Body lore -> ImpM lore r op ()
compileBody pat (Body _ bnds ses) = do
  Destination _ dests <- destinationFromPattern pat
  compileStms (freeIn ses) bnds $ mapM_ writeResult (zip dests ses)
  where
    writeResult (dest, se) = copyDWIMDest dest [] se []
-- | Compile a body, copying each of its results into the variable named
-- by the corresponding parameter.
--
-- Parameters and results are paired positionally; any surplus on
-- either side is ignored (as with 'zip').
compileBody' :: [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' params (Body _ bnds ses) =
  compileStms (freeIn ses) bnds . mapM_ copyIntoParam $ zip params ses
  where
    copyIntoParam (param, se) = copyDWIM (paramName param) [] se []
-- | Compile a loop body and write its results back into the loop's
-- merge parameters.
--
-- The results cannot be assigned to the merge parameters directly: a
-- result may itself be one of the merge parameters, and would then be
-- clobbered by an earlier assignment.  The write-back therefore happens
-- in two stages:
--
-- 1. Each scalar or memory result is copied into a freshly named
--    temporary (named after the parameter, with a @_tmp@ suffix).
-- 2. Only after all temporaries are set is each one copied into its
--    merge parameter.
--
-- A merge parameter of any other type gets no copy here, and neither
-- does a memory parameter whose result is not a variable.
compileLoopBody :: Typed dec => [Param dec] -> Body lore -> ImpM lore r op ()
compileLoopBody mergeparams (Body _ bnds ses) = do
  tmpnames <- mapM tmpNameFor mergeparams
  compileStms (freeIn ses) bnds $ do
    -- Stage 1 is emitted right here; every returned action is the
    -- deferred stage 2 for one merge parameter.
    writebacks <- sequence $ zipWith3 stageResult mergeparams tmpnames ses
    sequence_ writebacks
  where
    tmpNameFor p = newVName $ baseString (paramName p) ++ "_tmp"

    stageResult p tmp se =
      case typeOf p of
        Prim pt -> do
          emit $ Imp.DeclareScalar tmp Imp.Nonvolatile pt
          emit $ Imp.SetScalar tmp $ toExp' pt se
          pure $ emit $ Imp.SetScalar (paramName p) $ Imp.var tmp pt
        Mem space
          | Var v <- se -> do
              emit $ Imp.DeclareMem tmp space
              emit $ Imp.SetMem tmp v space
              pure $ emit $ Imp.SetMem (paramName p) tmp space
        _ -> pure (pure ())
-- | Compile a sequence of statements together with a continuation
-- action by delegating to the 'StmsCompiler' stored in the environment
-- ('envStmsCompiler').
--
-- All three arguments are forwarded unchanged.  Callers in this module
-- pass the names free in a body's result as the 'Names' argument.
compileStms :: Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms alive_after_stms all_stms m =
  asks envStmsCompiler >>= \stmsCompiler ->
    stmsCompiler alive_after_stms all_stms m
defCompileStms ::
(Mem lore, FreeIn op) =>
Names ->
Stms lore ->
ImpM lore r op () ->
ImpM lore r op ()
defCompileStms :: forall lore op r.
(Mem lore, FreeIn op) =>
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
defCompileStms Names
alive_after_stms Stms lore
all_stms ImpM lore r op ()
m =
ImpM lore r op Names -> ImpM lore r op ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ImpM lore r op Names -> ImpM lore r op ())
-> ImpM lore r op Names -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Set (VName, Space) -> [Stm lore] -> ImpM lore r op Names
compileStms' Set (VName, Space)
forall a. Monoid a => a
mempty ([Stm lore] -> ImpM lore r op Names)
-> [Stm lore] -> ImpM lore r op Names
forall a b. (a -> b) -> a -> b
$ Stms lore -> [Stm lore]
forall lore. Stms lore -> [Stm lore]
stmsToList Stms lore
all_stms
where
compileStms' :: Set (VName, Space) -> [Stm lore] -> ImpM lore r op Names
compileStms' Set (VName, Space)
allocs (Let Pattern lore
pat StmAux (ExpDec lore)
aux Exp lore
e : [Stm lore]
bs) = do
Maybe (Exp lore) -> [PatElem lore] -> ImpM lore r op ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> [PatElem lore] -> ImpM lore r op ()
dVars (Exp lore -> Maybe (Exp lore)
forall a. a -> Maybe a
Just Exp lore
e) (PatternT LParamMem -> [PatElemT LParamMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
PatternT LParamMem
pat)
Code op
e_code <-
Attrs -> ImpM lore r op (Code op) -> ImpM lore r op (Code op)
forall lore r op a. Attrs -> ImpM lore r op a -> ImpM lore r op a
localAttrs (StmAux (ExpDec lore) -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux (ExpDec lore)
aux) (ImpM lore r op (Code op) -> ImpM lore r op (Code op))
-> ImpM lore r op (Code op) -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$
ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect (ImpM lore r op () -> ImpM lore r op (Code op))
-> ImpM lore r op () -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$ Pattern lore -> Exp lore -> ImpM lore r op ()
forall lore r op. Pattern lore -> Exp lore -> ImpM lore r op ()
compileExp Pattern lore
pat Exp lore
e
(Names
live_after, Code op
bs_code) <- ImpM lore r op Names -> ImpM lore r op (Names, Code op)
forall lore r op a. ImpM lore r op a -> ImpM lore r op (a, Code op)
collect' (ImpM lore r op Names -> ImpM lore r op (Names, Code op))
-> ImpM lore r op Names -> ImpM lore r op (Names, Code op)
forall a b. (a -> b) -> a -> b
$ Set (VName, Space) -> [Stm lore] -> ImpM lore r op Names
compileStms' (PatternT LParamMem -> Set (VName, Space)
patternAllocs Pattern lore
PatternT LParamMem
pat Set (VName, Space) -> Set (VName, Space) -> Set (VName, Space)
forall a. Semigroup a => a -> a -> a
<> Set (VName, Space)
allocs) [Stm lore]
bs
let dies_here :: VName -> Bool
dies_here VName
v =
Bool -> Bool
not (VName
v VName -> Names -> Bool
`nameIn` Names
live_after)
Bool -> Bool -> Bool
&& VName
v VName -> Names -> Bool
`nameIn` Code op -> Names
forall a. FreeIn a => a -> Names
freeIn Code op
e_code
to_free :: Set (VName, Space)
to_free = ((VName, Space) -> Bool)
-> Set (VName, Space) -> Set (VName, Space)
forall a. (a -> Bool) -> Set a -> Set a
S.filter (VName -> Bool
dies_here (VName -> Bool)
-> ((VName, Space) -> VName) -> (VName, Space) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, Space) -> VName
forall a b. (a, b) -> a
fst) Set (VName, Space)
allocs
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit Code op
e_code
((VName, Space) -> ImpM lore r op ())
-> Set (VName, Space) -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ())
-> ((VName, Space) -> Code op)
-> (VName, Space)
-> ImpM lore r op ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> Space -> Code op) -> (VName, Space) -> Code op
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry VName -> Space -> Code op
forall a. VName -> Space -> Code a
Imp.Free) Set (VName, Space)
to_free
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit Code op
bs_code
Names -> ImpM lore r op Names
forall (m :: * -> *) a. Monad m => a -> m a
return (Names -> ImpM lore r op Names) -> Names -> ImpM lore r op Names
forall a b. (a -> b) -> a -> b
$ Code op -> Names
forall a. FreeIn a => a -> Names
freeIn Code op
e_code Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
live_after
compileStms' Set (VName, Space)
_ [] = do
Code op
code <- ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect ImpM lore r op ()
m
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit Code op
code
Names -> ImpM lore r op Names
forall (m :: * -> *) a. Monad m => a -> m a
return (Names -> ImpM lore r op Names) -> Names -> ImpM lore r op Names
forall a b. (a -> b) -> a -> b
$ Code op -> Names
forall a. FreeIn a => a -> Names
freeIn Code op
code Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
alive_after_stms
patternAllocs :: PatternT LParamMem -> Set (VName, Space)
patternAllocs = [(VName, Space)] -> Set (VName, Space)
forall a. Ord a => [a] -> Set a
S.fromList ([(VName, Space)] -> Set (VName, Space))
-> (PatternT LParamMem -> [(VName, Space)])
-> PatternT LParamMem
-> Set (VName, Space)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElemT LParamMem -> Maybe (VName, Space))
-> [PatElemT LParamMem] -> [(VName, Space)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe PatElemT LParamMem -> Maybe (VName, Space)
forall {dec}. Typed dec => PatElemT dec -> Maybe (VName, Space)
isMemPatElem ([PatElemT LParamMem] -> [(VName, Space)])
-> (PatternT LParamMem -> [PatElemT LParamMem])
-> PatternT LParamMem
-> [(VName, Space)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatternT LParamMem -> [PatElemT LParamMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements
isMemPatElem :: PatElemT dec -> Maybe (VName, Space)
isMemPatElem PatElemT dec
pe = case PatElemT dec -> Type
forall dec. Typed dec => PatElemT dec -> Type
patElemType PatElemT dec
pe of
Mem Space
space -> (VName, Space) -> Maybe (VName, Space)
forall a. a -> Maybe a
Just (PatElemT dec -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT dec
pe, Space
space)
Type
_ -> Maybe (VName, Space)
forall a. Maybe a
Nothing
-- | Compile an expression whose result is bound by the given pattern.
--
-- This does not compile the expression itself: it looks up the
-- expression compiler stored in the environment ('envExpCompiler')
-- and applies it to the pattern and expression.
compileExp :: Pattern lore -> Exp lore -> ImpM lore r op ()
compileExp pat e = do
  ec <- asks envExpCompiler
  ec pat e
-- | The default expression compiler.
--
-- Each expression form is compiled directly, except:
--
-- * 'BasicOp', which is delegated to 'defCompileBasicOp';
--
-- * 'Op', which is delegated to the 'envOpCompiler' found in the
--   environment.
defCompileExp ::
  (Mem lore) =>
  Pattern lore ->
  Exp lore ->
  ImpM lore r op ()
-- Conditional: both branches are compiled into the same pattern, and
-- 'sIf' selects between them at run time.
defCompileExp pat (If cond tbranch fbranch _) =
  sIf (toBoolExp cond) (compileBody pat tbranch) (compileBody pat fbranch)
-- Function call: the pattern's destination determines the call targets.
defCompileExp pat (Apply fname args _ _) = do
  dest <- destinationFromPattern pat
  targets <- funcallTargets dest
  args' <- catMaybes <$> mapM compileArg args
  emit $ Imp.Call targets fname args'
  where
    -- Primitive-typed arguments are passed as expressions, and memory
    -- block variables as memory arguments.  Every other argument is
    -- dropped here.  NOTE(review): presumably array arguments reach the
    -- callee through their memory blocks, which are passed separately;
    -- confirm.
    compileArg (se, _) = do
      t <- subExpType se
      case (se, t) of
        (_, Prim pt) -> return $ Just $ Imp.ExpArg $ toExp' pt se
        (Var v, Mem {}) -> return $ Just $ Imp.MemArg v
        _ -> return Nothing
defCompileExp pat (BasicOp op) = defCompileBasicOp pat op
defCompileExp pat (DoLoop ctx val form body) = do
  attrs <- askAttrs
  -- NOTE(review): a loop still carrying #[unroll] at this point was
  -- presumably not unrolled by an earlier pass.  No source location is
  -- available here, hence 'noLoc'.
  when ("unroll" `inAttrs` attrs) $
    warn (noLoc :: SrcLoc) [] "#[unroll] on loop with unknown number of iterations."

  -- Declare all merge parameters.  Only rank-0 (non-array) parameters
  -- receive their initial value by copying here.
  dFParams mergepat
  forM_ merge $ \(p, se) ->
    when ((== 0) $ arrayRank $ paramType p) $
      copyDWIM (paramName p) [] se []

  let doBody = compileLoopBody mergepat body

  case form of
    ForLoop i _ bound loopvars -> do
      -- At the start of each iteration, a primitive-typed loop parameter
      -- is set to element @i@ of its source array.  Other loop
      -- parameters are not assigned here.
      let setLoopParam (p, a)
            | Prim _ <- paramType p =
              copyDWIM (paramName p) [] (Var a) [DimFix $ Imp.vi64 i]
            | otherwise =
              return ()

      bound' <- toExp bound

      dLParams $ map fst loopvars
      sFor' i bound' $
        mapM_ setLoopParam loopvars >> doBody
    WhileLoop cond ->
      sWhile (TPrimExp $ Imp.var cond Bool) doBody

  -- After the loop, copy the final merge parameter values (context
  -- first, then values) into the pattern's destinations.
  Destination _ pat_dests <- destinationFromPattern pat
  forM_ (zip pat_dests $ map (Var . paramName . fst) merge) $ \(d, r) ->
    copyDWIMDest d [] r []
  where
    merge = ctx ++ val
    mergepat = map fst merge
defCompileExp pat (WithAcc inputs lam) = do
  dLParams $ lambdaParams lam
  -- Record each accumulator parameter in 'stateAccs' together with its
  -- arrays and its optional (lambda, subexpressions) operator.
  forM_ (zip inputs $ lambdaParams lam) $ \((_, arrs, op), p) ->
    modify $ \s ->
      s {stateAccs = M.insert (paramName p) (arrs, op) $ stateAccs s}
  compileStms mempty (bodyStms $ lambdaBody lam) $ do
    -- The first 'num_accs' results belong to the accumulators and are
    -- skipped.  The remaining results are copied into the trailing
    -- elements of the pattern.
    let nonacc_res = drop num_accs (bodyResult (lambdaBody lam))
        nonacc_pat_names = takeLast (length nonacc_res) (patternNames pat)
    forM_ (zip nonacc_pat_names nonacc_res) $ \(v, se) ->
      copyDWIM v [] se []
  where
    num_accs = length inputs
defCompileExp pat (Op op) = do
  opc <- asks envOpCompiler
  opc pat op
defCompileBasicOp ::
Mem lore =>
Pattern lore ->
BasicOp ->
ImpM lore r op ()
defCompileBasicOp :: forall lore r op.
Mem lore =>
Pattern lore -> BasicOp -> ImpM lore r op ()
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (SubExp SubExp
se) =
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) [] SubExp
se []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Opaque SubExp
se) =
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) [] SubExp
se []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (UnOp UnOp
op SubExp
e) = do
Exp
e' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
e
PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ UnOp -> Exp -> Exp
forall v. UnOp -> PrimExp v -> PrimExp v
Imp.UnOpExp UnOp
op Exp
e'
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (ConvOp ConvOp
conv SubExp
e) = do
Exp
e' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
e
PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ ConvOp -> Exp -> Exp
forall v. ConvOp -> PrimExp v -> PrimExp v
Imp.ConvOpExp ConvOp
conv Exp
e'
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (BinOp BinOp
bop SubExp
x SubExp
y) = do
Exp
x' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
x
Exp
y' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
y
PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp BinOp
bop Exp
x' Exp
y'
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (CmpOp CmpOp
bop SubExp
x SubExp
y) = do
Exp
x' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
x
Exp
y' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
y
PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ CmpOp -> Exp -> Exp -> Exp
forall v. CmpOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.CmpOpExp CmpOp
bop Exp
x' Exp
y'
defCompileBasicOp PatternT (LetDec lore)
_ (Assert SubExp
e ErrorMsg SubExp
msg (SrcLoc, [SrcLoc])
loc) = do
Exp
e' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
e
ErrorMsg Exp
msg' <- (SubExp -> ImpM lore r op Exp)
-> ErrorMsg SubExp -> ImpM lore r op (ErrorMsg Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp ErrorMsg SubExp
msg
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Exp -> ErrorMsg Exp -> (SrcLoc, [SrcLoc]) -> Code op
forall a. Exp -> ErrorMsg Exp -> (SrcLoc, [SrcLoc]) -> Code a
Imp.Assert Exp
e' ErrorMsg Exp
msg' (SrcLoc, [SrcLoc])
loc
Attrs
attrs <- ImpM lore r op Attrs
forall lore r op. ImpM lore r op Attrs
askAttrs
Bool -> ImpM lore r op () -> ImpM lore r op ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Name -> [Attr] -> Attr
AttrComp Name
"warn" [Attr
"safety_checks"] Attr -> Attrs -> Bool
`inAttrs` Attrs
attrs) (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
(SrcLoc -> [SrcLoc] -> [Char] -> ImpM lore r op ())
-> (SrcLoc, [SrcLoc]) -> [Char] -> ImpM lore r op ()
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry SrcLoc -> [SrcLoc] -> [Char] -> ImpM lore r op ()
forall loc lore r op.
Located loc =>
loc -> [loc] -> [Char] -> ImpM lore r op ()
warn (SrcLoc, [SrcLoc])
loc [Char]
"Safety check required at run-time."
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Index VName
src Slice SubExp
slice)
| Just [SubExp]
idxs <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
slice =
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) [] (VName -> SubExp
Var VName
src) ([DimIndex (TExp Int64)] -> ImpM lore r op ())
-> [DimIndex (TExp Int64)] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ (SubExp -> DimIndex (TExp Int64))
-> [SubExp] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> (SubExp -> TExp Int64) -> SubExp -> DimIndex (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) [SubExp]
idxs
defCompileBasicOp PatternT (LetDec lore)
_ Index {} =
() -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Update VName
_ Slice SubExp
slice SubExp
se) =
VName -> [DimIndex (TExp Int64)] -> SubExp -> ImpM lore r op ()
forall lore r op.
VName -> [DimIndex (TExp Int64)] -> SubExp -> ImpM lore r op ()
sUpdate (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) ((DimIndex SubExp -> DimIndex (TExp Int64))
-> Slice SubExp -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map ((SubExp -> TExp Int64) -> DimIndex SubExp -> DimIndex (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) Slice SubExp
slice) SubExp
se
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Replicate (Shape [SubExp]
ds) SubExp
se) = do
[Exp]
ds' <- (SubExp -> ImpM lore r op Exp) -> [SubExp] -> ImpM lore r op [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp [SubExp]
ds
[VName]
is <- Int -> ImpM lore r op VName -> ImpM lore r op [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ds) ([Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"i")
Code op
copy_elem <- ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect (ImpM lore r op () -> ImpM lore r op (Code op))
-> ImpM lore r op () -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$ VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) ((VName -> DimIndex (TExp Int64))
-> [VName] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> (VName -> TExp Int64) -> VName -> DimIndex (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> TExp Int64
Imp.vi64) [VName]
is) SubExp
se []
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ ((Code op -> Code op)
-> (Code op -> Code op) -> Code op -> Code op)
-> (Code op -> Code op)
-> [Code op -> Code op]
-> Code op
-> Code op
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (Code op -> Code op) -> (Code op -> Code op) -> Code op -> Code op
forall b c a. (b -> c) -> (a -> b) -> a -> c
(.) Code op -> Code op
forall a. a -> a
id ((VName -> Exp -> Code op -> Code op)
-> [VName] -> [Exp] -> [Code op -> Code op]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith VName -> Exp -> Code op -> Code op
forall a. VName -> Exp -> Code a -> Code a
Imp.For [VName]
is [Exp]
ds') Code op
copy_elem
defCompileBasicOp PatternT (LetDec lore)
_ Scratch {} =
() -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
-- Iota: for every i in [0, n), store e + i * s at position i of the
-- single destination array.  The arithmetic is done in the integer type
-- 'it' with 'OverflowUndef'; the 64-bit loop counter is converted to 'it'
-- with 'sExt'.
defCompileBasicOp (Pattern [] [pe]) (Iota n e s it) = do
  start <- toExp e
  stride <- toExp s
  sFor "i" (toInt64Exp n) $ \i -> do
    let i_ext = sExt it $ untyped i
        step = BinOpExp (Mul it OverflowUndef) i_ext stride
        elem_val = BinOpExp (Add it OverflowUndef) start step
    x <- dPrimV "x" $ TPrimExp elem_val
    copyDWIM (patElemName pe) [DimFix i] (Var $ tvVar x) []
-- Copy: copy the whole source array (empty slices on both sides) into the
-- pattern's single element.
defCompileBasicOp (Pattern _ [pe]) (Copy src) =
  copyDWIM dest [] (Var src) []
  where
    dest = patElemName pe
-- Manifest: handled exactly like Copy; the permutation argument is not
-- consulted here.
defCompileBasicOp (Pattern _ [pe]) (Manifest _ src) =
  copyDWIM dest [] (Var src) []
  where
    dest = patElemName pe
-- Concat along dimension i: copy each operand, in order, into the
-- destination.  The destination slice spans every dimension before i in
-- full and, in dimension i, starts at a running offset.  After each copy
-- the offset advances by that operand's extent in dimension i.
defCompileBasicOp (Pattern _ [pe]) (Concat i x ys _) = do
  offs_var <- dPrimV "tmp_offs" 0
  forM_ (x : ys) $ \src -> do
    src_dims <- arrayDims <$> lookupType src
    let (outer_dims, inner_dims) = splitAt i src_dims
        rows = case inner_dims of
          r : _ -> toInt64Exp r
          [] ->
            error $
              "defCompileBasicOp Concat: empty array shape for " ++ pretty src
        fullSlice d = DimSlice 0 (toInt64Exp d) 1
        dest_slice =
          map fullSlice outer_dims ++ [DimSlice (tvExp offs_var) rows 1]
    copyDWIM (patElemName pe) dest_slice (Var src) []
    offs_var <-- tvExp offs_var + rows
-- ArrayLit: if the literal is non-empty and every element is a constant,
-- declare a static array holding those values in the destination's memory
-- space, register it, and copy it into the destination with one 'copy'.
-- Otherwise store the elements one at a time with 'copyDWIM'.
defCompileBasicOp (Pattern [] [pe]) (ArrayLit es _)
  | Just vs@(v : _) <- mapM isLiteral es = do
    dest_loc <- entryArrayLocation <$> lookupArray (patElemName pe)
    space <- entryMemSpace <$> lookupMemory (memLocationName dest_loc)
    let elem_t = primValueType v
    arr_name <- newVNameForFun "static_array"
    emit . Imp.DeclareArray arr_name space elem_t $ Imp.ArrayValues vs
    let num_elems = length es
        src_loc =
          MemLocation arr_name [intConst Int64 $ fromIntegral num_elems] $
            IxFun.iota [fromIntegral num_elems]
    addVar arr_name . MemVar Nothing $ MemEntry space
    let whole = [DimSlice 0 (genericLength es) 1]
    copy elem_t dest_loc whole src_loc whole
  | otherwise =
    forM_ (zip [0 :: Integer ..] es) $ \(k, se) ->
      copyDWIM (patElemName pe) [DimFix $ fromInteger k] se []
  where
    isLiteral (Constant pv) = Just pv
    isLiteral _ = Nothing
-- Rearrange, Rotate and Reshape emit no code in this function.  Presumably
-- they are realised through index functions elsewhere; TODO confirm.
defCompileBasicOp _ Rearrange {} = pure ()
defCompileBasicOp _ Rotate {} = pure ()
defCompileBasicOp _ Reshape {} = pure ()
-- UpdateAcc: write the values 'vs' into accumulator 'acc' at index 'is'.
-- Nothing happens when the index is out of bounds of the accumulator's
-- dimensions.
--
-- * If the accumulator has no operator, the values overwrite the
--   underlying arrays.
-- * Otherwise the operator lambda runs on the current array elements and
--   the new values, and its results are written back.
--
-- The generated code is wrapped with 'sComment'.
defCompileBasicOp _ (UpdateAcc acc is vs) = sComment "UpdateAcc" $ do
  let idx = map toInt64Exp is
  (_, _, arrs, dims, red_op) <- lookupAcc acc idx
  sWhen (inBounds (map DimFix idx) dims) $
    case red_op of
      Nothing ->
        -- Plain overwrite of each destination array at the index.
        mapM_ (\(arr, v) -> copyDWIMFix arr idx v []) (zip arrs vs)
      Just lam -> do
        -- The first (length vs) lambda parameters receive the current array
        -- elements at the index; the remaining ones receive the new values.
        let params = lambdaParams lam
            body = lambdaBody lam
            (acc_params, new_params) = splitAt (length vs) $ map paramName params
        dLParams params
        mapM_ (\(p, arr) -> copyDWIMFix p [] (Var arr) idx) (zip acc_params arrs)
        mapM_ (\(p, v) -> copyDWIM p [] v []) (zip new_params vs)
        compileStms mempty (bodyStms body) $
          mapM_ (\(arr, se) -> copyDWIMFix arr idx se []) $
            zip arrs (bodyResult body)
-- Any other pattern/operation combination is rejected with an internal
-- error that shows both the pattern and the expression.
defCompileBasicOp pat e =
  error $
    concat
      [ "ImpGen.defCompileBasicOp: Invalid pattern\n ",
        pretty pat,
        "\nfor expression\n ",
        pretty e
      ]
-- | Record each array declaration in the variable table as an 'ArrayVar'
-- without a defining expression.  This function itself only calls 'addVar'.
addArrays :: [ArrayDecl] -> ImpM lore r op ()
addArrays = mapM_ $ \(ArrayDecl name bt location) ->
  addVar name . ArrayVar Nothing $
    ArrayEntry
      { entryArrayLocation = location,
        entryArrayElemType = bt
      }
-- | Record function parameters in the variable table.  Each parameter's
-- memory information has its uniqueness dropped and is converted with
-- 'memBoundToVarEntry'.  This function itself only calls 'addVar'; it does
-- not go through 'dScope' as 'dFParams' does.
addFParams :: Mem lore => [FParam lore] -> ImpM lore r op ()
addFParams = mapM_ $ \fparam ->
  addVar (paramName fparam) . memBoundToVarEntry Nothing . noUniquenessReturns $
    paramDec fparam
-- | Record a loop index of the given integer type as a scalar in the
-- variable table.  This function itself only calls 'addVar'.
addLoopVar :: VName -> IntType -> ImpM lore r op ()
addLoopVar i it = addVar i entry
  where
    entry = ScalarVar Nothing (ScalarEntry (IntType it))
-- | Declare every pattern element through 'dScope'.  Each element is
-- converted to a singleton scope, and the same optional expression is
-- passed for all of them.
dVars ::
  Mem lore =>
  Maybe (Exp lore) ->
  [PatElem lore] ->
  ImpM lore r op ()
dVars e = mapM_ (dScope e . scopeOfPatElem)
-- | Declare function parameters by building their scope and passing it to
-- 'dScope' with no defining expression.
dFParams :: Mem lore => [FParam lore] -> ImpM lore r op ()
dFParams params = dScope Nothing (scopeOfFParams params)
-- | Declare lambda parameters by building their scope and passing it to
-- 'dScope' with no defining expression.
dLParams :: Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams params = dScope Nothing (scopeOfLParams params)
-- | Create a fresh name based on the given string and declare it as a
-- /volatile/ scalar of the given type.  The variable is registered,
-- assigned the expression, and returned.  The expression's type is not
-- checked against the given 'PrimType'.
dPrimVol :: String -> PrimType -> Imp.TExp t -> ImpM lore r op (TV t)
dPrimVol name t e = do
  v <- newVName name
  emit $ Imp.DeclareScalar v Imp.Volatile t
  addVar v . ScalarVar Nothing $ ScalarEntry t
  v <~~ untyped e
  pure (TV v t)
-- | Emit a non-volatile scalar declaration for the given name and type,
-- and record the name as a scalar in the variable table.
dPrim_ :: VName -> PrimType -> ImpM lore r op ()
dPrim_ name t =
  emit (Imp.DeclareScalar name Imp.Nonvolatile t)
    >> addVar name (ScalarVar Nothing (ScalarEntry t))
-- | Declare a fresh non-volatile scalar of the given type, with a name
-- based on the given string.  The result's type index @t@ is not checked
-- against the 'PrimType'; callers must keep them consistent.
dPrim :: String -> PrimType -> ImpM lore r op (TV t)
dPrim name t = do
  v <- newVName name
  TV v t <$ dPrim_ v t
-- | Declare a scalar with the given name, typed after the expression, and
-- assign the expression to it.
dPrimV_ :: VName -> Imp.TExp t -> ImpM lore r op ()
dPrimV_ name e =
  dPrim_ name e_t >> (TV name e_t <-- e)
  where
    e_t = primExpType (untyped e)
-- | Declare a fresh scalar whose type is that of the expression.  The
-- expression is assigned to it, and the variable is returned.
dPrimV :: String -> Imp.TExp t -> ImpM lore r op (TV t)
dPrimV name e = do
  v <- dPrim name . primExpType $ untyped e
  v <$ (v <-- e)
-- | Like 'dPrimV', but returns the new variable as a typed expression.
dPrimVE :: String -> Imp.TExp t -> ImpM lore r op (Imp.TExp t)
dPrimVE name e = tvExp <$> dPrimV name e
-- | Turn a name's memory information into a variable-table entry, attaching
-- the optional defining expression:
--
-- * a primitive becomes a 'ScalarVar';
-- * a memory block becomes a 'MemVar';
-- * an accumulator becomes an 'AccVar';
-- * an array becomes an 'ArrayVar'.  Its location is built from the memory
--   block, the shape dimensions, and the index function, with leaf names
--   mapped to 'Imp.ScalarVar'.
memBoundToVarEntry ::
  Maybe (Exp lore) ->
  MemBound NoUniqueness ->
  VarEntry lore
memBoundToVarEntry e = \case
  MemPrim bt -> ScalarVar e $ ScalarEntry bt
  MemMem space -> MemVar e $ MemEntry space
  MemAcc acc ispace ts _ -> AccVar e (acc, ispace, ts)
  MemArray bt shape _ (ArrayIn mem ixfun) ->
    ArrayVar e $
      ArrayEntry
        { entryArrayLocation =
            MemLocation mem (shapeDims shape) (fmap (fmap Imp.ScalarVar) ixfun),
          entryArrayElemType = bt
        }
-- | Get the memory information of any kind of name binding.  Function
-- parameters have their uniqueness dropped, and index names become
-- primitive integers of their 'IntType'.
infoDec ::
  Mem lore =>
  NameInfo lore ->
  MemInfo SubExp NoUniqueness MemBind
infoDec = \case
  LetName dec -> dec
  FParamName dec -> noUniquenessReturns dec
  LParamName dec -> dec
  IndexName it -> MemPrim (IntType it)
-- | Declare a name according to its binding information:
--
-- * memory blocks get a 'Imp.DeclareMem';
-- * scalars get a non-volatile 'Imp.DeclareScalar';
-- * arrays and accumulators emit nothing.
--
-- In every case the resulting entry is then added to the variable table.
dInfo ::
  Mem lore =>
  Maybe (Exp lore) ->
  VName ->
  NameInfo lore ->
  ImpM lore r op ()
dInfo e name info = do
  declare entry
  addVar name entry
  where
    entry = memBoundToVarEntry e $ infoDec info
    declare (MemVar _ mem) =
      emit $ Imp.DeclareMem name $ entryMemSpace mem
    declare (ScalarVar _ scalar) =
      emit $ Imp.DeclareScalar name Imp.Nonvolatile $ entryScalarType scalar
    declare ArrayVar {} = pure ()
    declare AccVar {} = pure ()
-- | Declare every name in a 'Scope' via 'dInfo', in the key order of the
-- underlying map, passing the same optional expression to each.
dScope ::
  Mem lore =>
  Maybe (Exp lore) ->
  Scope lore ->
  ImpM lore r op ()
dScope e = mapM_ declareOne . M.toList
  where
    declareOne (name, info) = dInfo e name info
-- | Register an array in the symbol table with the given element type,
-- shape and memory binding. No Imp code is emitted here; the array entry
-- is built with 'memBoundToVarEntry' and no associated expression.
dArray :: VName -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op ()
dArray name bt shape membind =
  addVar name . memBoundToVarEntry Nothing $
    MemArray bt shape NoUniqueness membind
-- | Run an action with the environment's 'envVolatility' set to
-- 'Imp.Volatile'. (Presumably consulted by declaration helpers elsewhere
-- in this module; verify at the use sites.)
everythingVolatile :: ImpM lore r op a -> ImpM lore r op a
everythingVolatile = local makeVolatile
  where
    makeVolatile env = env {envVolatility = Imp.Volatile}
-- | The names that receive results of a function call writing to this
-- destination. Scalar and memory destinations contribute their name;
-- array destinations contribute nothing.
funcallTargets :: Destination -> ImpM lore r op [VName]
funcallTargets (Destination _ dests) =
  concat <$> traverse targetsOf dests
  where
    targetsOf (ScalarDestination name) = pure [name]
    targetsOf (MemoryDestination name) = pure [name]
    targetsOf (ArrayDestination _) = pure []
-- | A variable name paired with its primitive type. The phantom parameter
-- @t@ records the type at the Haskell level, so that 'tvExp' can produce
-- a typed 'Imp.TExp'. 'mkTV' does not check that @t@ agrees with the
-- stored 'PrimType'.
data TV t = TV !VName !PrimType
-- | Build a typed variable from a name and its primitive type. The
-- phantom type is chosen by the caller and is not checked.
mkTV :: VName -> PrimType -> TV t
mkTV name pt = TV name pt
-- | Use a typed variable as a dimension size, i.e. as a 'Var' subexpression.
tvSize :: TV t -> Imp.DimSize
tvSize (TV name _) = Var name
-- | Read a typed variable as a typed Imp expression of its stored type.
tvExp :: TV t -> Imp.TExp t
tvExp (TV name pt) = Imp.TPrimExp (Imp.var name pt)
-- | The underlying name of a typed variable.
tvVar :: TV t -> VName
tvVar (TV name _) = name
-- | Things that can be compiled to an Imp expression.
class ToExp a where
  -- | Compile to an 'Imp.Exp', determining its type monadically.
  toExp :: a -> ImpM lore r op Imp.Exp

  -- | Compile to an 'Imp.Exp' when the type is already known.
  toExp' :: PrimType -> a -> Imp.Exp

  -- | Compile as a typed 64-bit integer expression (via 'toExp'' at 'int64').
  toInt64Exp :: a -> Imp.TExp Int64
  toInt64Exp x = TPrimExp (toExp' int64 x)

  -- | Compile as a typed boolean expression (via 'toExp'' at 'Bool').
  toBoolExp :: a -> Imp.TExp Bool
  toBoolExp x = TPrimExp (toExp' Bool x)
-- | Constants become value expressions. In 'toExp', variables must be
-- bound as scalars in the symbol table; their recorded type is used.
instance ToExp SubExp where
  toExp (Constant v) = pure (Imp.ValueExp v)
  toExp (Var v) = do
    entry <- lookupVar v
    case entry of
      ScalarVar _ (ScalarEntry pt) -> pure (Imp.var v pt)
      _ -> error $ "toExp SubExp: SubExp is not a primitive type: " ++ pretty v

  -- The given type is only needed for variables; constants carry their own.
  toExp' _ (Constant v) = Imp.ValueExp v
  toExp' pt (Var v) = Imp.var v pt
-- | A 'PrimExp' over names already carries its types, so compiling it only
-- wraps each leaf name as an 'Imp.ScalarVar'. The requested type is ignored.
instance ToExp (PrimExp VName) where
  toExp = pure . toExp' int64
  toExp' _ = fmap Imp.ScalarVar
-- | Bind a name in the symbol table, replacing any previous binding.
addVar :: VName -> VarEntry lore -> ImpM lore r op ()
addVar name entry = modify insertEntry
  where
    insertEntry st = st {stateVTable = M.insert name entry (stateVTable st)}
-- | Run an action with a different default memory space.
localDefaultSpace :: Imp.Space -> ImpM lore r op a -> ImpM lore r op a
localDefaultSpace space = local withSpace
  where
    withSpace env = env {envDefaultSpace = space}
-- | The name of the function currently being compiled, if any.
askFunction :: ImpM lore r op (Maybe Name)
askFunction = envFunction <$> ask
-- | Generate a fresh name from the given base. Inside a function, the base
-- is prefixed with the function's name followed by a dot.
newVNameForFun :: String -> ImpM lore r op VName
newVNameForFun s = do
  current <- fmap nameToString <$> askFunction
  let prefix = maybe "" (++ ".") current
  newVName (prefix ++ s)
-- | Build a (non-fresh) 'Name' from the given string. Inside a function,
-- the string is prefixed with the function's name followed by a dot.
nameForFun :: String -> ImpM lore r op Name
nameForFun s = do
  current <- askFunction
  let prefix = maybe "" (<> ".") current
  pure (prefix <> nameFromString s)
-- | The user-supplied part (@r@) of the compilation environment.
askEnv :: ImpM lore r op r
askEnv = envEnv <$> ask
-- | Run an action with the user-supplied environment transformed by @f@.
localEnv :: (r -> r) -> ImpM lore r op a -> ImpM lore r op a
localEnv f = local updateUserEnv
  where
    updateUserEnv env = env {envEnv = f (envEnv env)}
-- | The attributes currently in effect.
askAttrs :: ImpM lore r op Attrs
askAttrs = envAttrs <$> ask
-- | Run an action with extra attributes combined (on the left, via '<>')
-- with the ones already in effect.
localAttrs :: Attrs -> ImpM lore r op a -> ImpM lore r op a
localAttrs attrs = local addAttrs
  where
    addAttrs env = env {envAttrs = attrs <> envAttrs env}
-- | Run an action with the expression, statement, copy, op and allocation
-- compilers replaced by those in the given 'Operations'.
localOps :: Operations lore r op -> ImpM lore r op a -> ImpM lore r op a
localOps ops = local installOps
  where
    installOps env =
      env
        { envExpCompiler = opsExpCompiler ops,
          envStmsCompiler = opsStmsCompiler ops,
          envCopyCompiler = opsCopyCompiler ops,
          envOpCompiler = opsOpCompiler ops,
          envAllocCompilers = opsAllocCompilers ops
        }
-- | The current symbol table.
getVTable :: ImpM lore r op (VTable lore)
getVTable = stateVTable <$> get
-- | Replace the whole symbol table.
putVTable :: VTable lore -> ImpM lore r op ()
putVTable vtable = modify replaceTable
  where
    replaceTable st = st {stateVTable = vtable}
-- | Run an action with the symbol table transformed by @f@, then restore
-- the table that was in effect before. Any bindings the action added are
-- discarded once it finishes.
localVTable :: (VTable lore -> VTable lore) -> ImpM lore r op a -> ImpM lore r op a
localVTable f m = do
  saved <- getVTable
  putVTable (f saved)
  result <- m
  putVTable saved
  pure result
-- | Look up a name in the symbol table. An unbound name is an internal
-- compiler error.
lookupVar :: VName -> ImpM lore r op (VarEntry lore)
lookupVar name = do
  vtable <- gets stateVTable
  case M.lookup name vtable of
    Just entry -> pure entry
    Nothing -> error $ "Unknown variable: " ++ pretty name
-- | Look up the 'ArrayEntry' of an array. It is an internal compiler error
-- if the name is bound to something that is not an array.
lookupArray :: VName -> ImpM lore r op ArrayEntry
lookupArray name =
  lookupVar name >>= \case
    ArrayVar _ entry -> pure entry
    _ -> error $ "ImpGen.lookupArray: not an array: " ++ pretty name
-- | Look up the 'MemEntry' of a memory block.
--
-- Unbound names are rejected by 'lookupVar'. It is also an internal
-- compiler error if the name is bound to something other than memory.
lookupMemory :: VName -> ImpM lore r op MemEntry
lookupMemory name =
  lookupVar name >>= \case
    MemVar _ entry -> pure entry
    -- 'lookupVar' has already failed for unbound names, so reaching this
    -- branch means the name is known but is not a memory block; the old
    -- "Unknown memory block" message was misleading.
    _ -> error $ "ImpGen.lookupMemory: not a memory block: " ++ pretty name
-- | The memory space of the block an array is stored in: look up the
-- array's location, then the memory entry of that location's block.
lookupArraySpace :: VName -> ImpM lore r op Space
lookupArraySpace arr = do
  mem <- memLocationName . entryArrayLocation <$> lookupArray arr
  entryMemSpace <$> lookupMemory mem
-- | Look up an accumulator at the given indices.
--
-- Returns, in order:
--
-- * the accumulator's underlying name (from its 'AccVar' entry),
-- * the space of the memory holding its first array,
-- * all arrays belonging to the accumulator,
-- * the dimensions of its index space as 'Int64' expressions, and
-- * its combining lambda, if one was registered, with the first
--   @length is@ parameters removed.
--
-- When a lambda exists, those removed index parameters are first bound to
-- @is@ with 'dPrimV_', which emits code. It is an internal compiler error
-- for the name not to be an accumulator, for the accumulator to be absent
-- from 'stateAccs', or for it to have no arrays.
lookupAcc ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM lore r op (VName, Imp.Space, [VName], [Imp.TExp Int64], Maybe (Lambda lore))
lookupAcc name is =
  lookupVar name >>= \case
    AccVar _ (acc, ispace, _) ->
      gets (M.lookup acc . stateAccs) >>= \case
        Nothing ->
          error $ "ImpGen.lookupAcc: unlisted accumulator: " ++ pretty name
        Just ([], _) ->
          error $ "Accumulator with no arrays: " ++ pretty name
        Just (arrs@(arr : _), Nothing) -> do
          space <- lookupArraySpace arr
          pure (acc, space, arrs, map toInt64Exp (shapeDims ispace), Nothing)
        Just (arrs@(arr : _), Just (lam, _)) -> do
          space <- lookupArraySpace arr
          -- The leading lambda parameters are the indices; bind them to
          -- the given index expressions and hand back the remainder.
          let (indexParams, valueParams) = splitAt (length is) (lambdaParams lam)
          zipWithM_ dPrimV_ (map paramName indexParams) is
          pure
            ( acc,
              space,
              arrs,
              map toInt64Exp (shapeDims ispace),
              Just lam {lambdaParams = valueParams}
            )
    _ -> error $ "ImpGen.lookupAcc: not an accumulator: " ++ pretty name
-- | Build a 'Destination' for a pattern whose names are already in the
-- symbol table.
--
-- The destination's tag is the 'baseTag' of the first pattern name, if
-- there is one. Each element is mapped by its entry kind: arrays and
-- accumulators become @ArrayDestination Nothing@, memory blocks become
-- 'MemoryDestination', and scalars become 'ScalarDestination'.
destinationFromPattern :: Mem lore => Pattern lore -> ImpM lore r op Destination
destinationFromPattern pat =
  Destination (baseTag <$> maybeHead (patternNames pat))
    <$> mapM elemDestination (patternElements pat)
  where
    elemDestination pe = do
      let name = patElemName pe
      lookupVar name >>= \case
        ArrayVar _ (ArrayEntry MemLocation {} _) ->
          pure $ ArrayDestination Nothing
        MemVar {} ->
          pure $ MemoryDestination name
        ScalarVar {} ->
          pure $ ScalarDestination name
        AccVar {} ->
          pure $ ArrayDestination Nothing
-- | Locate a single element of the array bound to the given name.
--
-- Looks the array up with 'lookupArray' and delegates to 'fullyIndexArray''
-- on its memory location, returning the memory block, its space, and the
-- element offset for the given indices.
fullyIndexArray ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM lore r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray name indices =
  lookupArray name >>= \entry ->
    fullyIndexArray' (entryArrayLocation entry) indices
-- | Locate a single element given a 'MemLocation' and indices.
--
-- Returns the memory block of the location, the space that block lives in
-- (from 'lookupMemory'), and the element offset computed by applying the
-- location's index function to the indices.
--
-- When the memory is in 'ScalarSpace' with dimensions @ds@, only the last
-- @length ds@ indices are kept as given; every index before them is
-- replaced by zero before the index function is applied.
fullyIndexArray' ::
  MemLocation ->
  [Imp.TExp Int64] ->
  ImpM lore r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray' (MemLocation mem _ ixfun) indices = do
  space <- entryMemSpace <$> lookupMemory mem
  let adjusted = adjustFor space
  return (mem, space, elements $ IxFun.index ixfun adjusted)
  where
    adjustFor (ScalarSpace ds _) =
      let (outer, inner) = splitFromEnd (length ds) indices
       in (0 <$ outer) ++ inner
    adjustFor _ = indices
-- | Copy between two (sliced) memory locations using whichever
-- 'CopyCompiler' is installed in the environment ('envCopyCompiler').
copy :: CopyCompiler lore r op
copy bt dest destslice src srcslice =
  asks envCopyCompiler >>= \compileCopy ->
    compileCopy bt dest destslice src srcslice
-- | Decide whether a copy can be carried out by the built-in map-transpose
-- routine.
--
-- Each slice is first applied to its location's index function.  Then one
-- side must be a rearrangement ('IxFun.rearrangeWithOffset') whose
-- permutation is accepted by 'isMapTranspose', and the other side must be
-- linear ('IxFun.linearWithOffset').  Both analyses are given the element
-- size from 'primByteSize'.
--
-- On success the result is @(dest_offset, src_offset, num_arrays, size_x,
-- size_y)@: the two offsets reported by the analyses, the product of the
-- first @r1@ dimensions of the rearranged side, and the products of the two
-- groups obtained by splitting the remaining dimensions at @r2@.  The two
-- groups are swapped when the destination is the rearranged side.
isMapTransposeCopy ::
  PrimType ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  Maybe
    ( Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64
    )
isMapTransposeCopy
  bt
  (MemLocation _ _ destIxFun)
  destslice
  (MemLocation _ _ srcIxFun)
  srcslice
    -- The destination is the rearranged side; the source is linear.
    | Just (dest_off, dest_perm_dims) <-
        IxFun.rearrangeWithOffset dest_fun elem_bytes,
      (perm, dims) <- unzip dest_perm_dims,
      Just src_off <- IxFun.linearWithOffset src_fun elem_bytes,
      Just (r1, r2, _) <- isMapTranspose perm =
      transposeSizes (\(a, b) -> (b, a)) dims r1 r2 dest_off src_off
    -- The source is the rearranged side; the destination is linear.
    | Just dest_off <- IxFun.linearWithOffset dest_fun elem_bytes,
      Just (src_off, src_perm_dims) <-
        IxFun.rearrangeWithOffset src_fun elem_bytes,
      (perm, dims) <- unzip src_perm_dims,
      Just (r1, r2, _) <- isMapTranspose perm =
      transposeSizes id dims r1 r2 dest_off src_off
    | otherwise = Nothing
    where
      elem_bytes = primByteSize bt
      dest_fun = IxFun.slice destIxFun destslice
      src_fun = IxFun.slice srcIxFun srcslice

      -- Split the dimensions into the mapped prefix (first r1) and the two
      -- transposed groups (the rest, split at r2 and possibly reordered by
      -- 'orient'), multiplying each group out.
      transposeSizes orient dims r1 r2 dest_off src_off =
        let (mapped, unmapped) = splitAt r1 dims
            (pre, post) = orient $ splitAt r2 unmapped
         in Just
              ( dest_off,
                src_off,
                product mapped,
                product pre,
                product post
              )
-- | Name of the map-transpose routine for an element type: the prefix
-- @map_transpose_@ followed by the pretty-printed type.
mapTransposeName :: PrimType -> String
mapTransposeName bt = "map_transpose_" <> pretty bt
-- | Get the name of the built-in map-transpose function for an element
-- type (@builtin#@ followed by 'mapTransposeName').  If no function with
-- that name has been emitted yet, its definition is generated with
-- 'mapTransposeFunction' and emitted first.
mapTransposeForType :: PrimType -> ImpM lore r op Name
mapTransposeForType bt = do
  let fname = nameFromString $ "builtin#" <> mapTransposeName bt
  already_emitted <- hasFunction fname
  if already_emitted
    then return ()
    else emitFunction fname $ mapTransposeFunction fname bt
  return fname
-- | The default 'CopyCompiler'.  It tries three strategies in order:
--
-- 1. If 'isMapTransposeCopy' recognises the copy as a transposition, emit a
--    call to the built-in map-transpose function for the element type
--    (obtained from 'mapTransposeForType').
-- 2. If both sliced index functions are linear ('IxFun.linearWithOffset'),
--    emit a single bulk 'Imp.Copy' covering every element of the source
--    slice.  If either memory block lives in 'ScalarSpace', fall back to
--    'copyElementWise' instead.
-- 3. Otherwise, copy element by element with 'copyElementWise'.
defaultCopy :: CopyCompiler lore r op
defaultCopy
  pt
  dest@(MemLocation destmem _ dest_ixfun)
  destslice
  src@(MemLocation srcmem _ src_ixfun)
  srcslice
    | Just (destoffset, srcoffset, num_arrays, size_x, size_y) <-
        isMapTransposeCopy pt dest destslice src srcslice = do
      fname <- mapTransposeForType pt
      emit . Imp.Call [] fname $
        transposeArgs
          pt
          destmem
          (bytes destoffset)
          srcmem
          (bytes srcoffset)
          num_arrays
          size_x
          size_y
    | Just destoffset <- linearOffset dest_ixfun destslice,
      Just srcoffset <- linearOffset src_ixfun srcslice = do
      srcspace <- memSpace srcmem
      destspace <- memSpace destmem
      if any isScalarSpace [srcspace, destspace]
        then copyElementWise pt dest destslice src srcslice
        else
          emit $
            Imp.Copy
              destmem
              (bytes destoffset)
              destspace
              srcmem
              (bytes srcoffset)
              srcspace
              (num_elems `withElemType` pt)
    | otherwise =
      copyElementWise pt dest destslice src srcslice
    where
      -- Number of elements in the source slice; the size of a bulk copy.
      num_elems = Imp.elements $ product $ sliceDims srcslice

      -- Offset of a sliced index function, if that function is linear.
      linearOffset ixfun slice =
        IxFun.linearWithOffset (IxFun.slice ixfun slice) (primByteSize pt)

      memSpace mem = entryMemSpace <$> lookupMemory mem

      isScalarSpace ScalarSpace {} = True
      isScalarSpace _ = False
-- | Copy a slice one element at a time.
--
-- Emits one 'Imp.For' loop per dimension of the source slice, each with a
-- fresh index variable, with the first dimension as the outermost loop.
-- The loop body is a single 'Imp.Write' that stores the corresponding
-- source element into the destination, reading and writing with the
-- environment's volatility.
--
-- NOTE(review): the same index variables are used for both slices (via
-- 'fixSlice') and the loop bounds come from the source slice only, so this
-- assumes the destination slice has the same dimensions.  That is not
-- checked here.
copyElementWise :: CopyCompiler lore r op
copyElementWise bt dest destslice src srcslice = do
  let bounds = sliceDims srcslice
  is <- replicateM (length bounds) (newVName "i")
  let ivars = map Imp.vi64 is
  (destmem, destspace, destidx) <-
    fullyIndexArray' dest $ fixSlice destslice ivars
  (srcmem, srcspace, srcidx) <-
    fullyIndexArray' src $ fixSlice srcslice ivars
  vol <- asks envVolatility
  -- Nest the loop constructors with a right fold: the first index variable
  -- becomes the outermost loop.  'foldr (.) id' gives the same nesting as
  -- the previous 'foldl (.) id', without the left-nested composition chain.
  emit $
    foldr (.) id (zipWith Imp.For is $ map untyped bounds) $
      Imp.Write destmem destidx bt destspace vol $
        Imp.index srcmem srcidx bt srcspace vol
-- | Generate, but do not emit, the code for copying between two possibly
-- sliced array locations.
--
-- * If both slices fix every dimension of their location's shape, the
--   result is a single 'Imp.Write' of one element.
-- * Otherwise both slices are completed to full slices with 'fullSliceNum'.
--   Mismatched ranks are a compiler error.  Identical locations with
--   identical slices produce no code.  In every other case the code emitted
--   by 'copy' is collected and returned.
copyArrayDWIM ::
  PrimType ->
  MemLocation ->
  [DimIndex (Imp.TExp Int64)] ->
  MemLocation ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM lore r op (Imp.Code op)
copyArrayDWIM
  bt
  destlocation@(MemLocation _ destshape _)
  destslice
  srclocation@(MemLocation _ srcshape _)
  srcslice
    | Just destis <- mapM dimFix destslice,
      Just srcis <- mapM dimFix srcslice,
      length srcis == length srcshape,
      length destis == length destshape = do
      (targetmem, destspace, targetoffset) <-
        fullyIndexArray' destlocation destis
      (srcmem, srcspace, srcoffset) <-
        fullyIndexArray' srclocation srcis
      vol <- asks envVolatility
      return . Imp.Write targetmem targetoffset bt destspace vol $
        Imp.index srcmem srcoffset bt srcspace vol
    | otherwise =
      copySliced
        (fullSliceNum (map toInt64Exp destshape) destslice)
        (fullSliceNum (map toInt64Exp srcshape) srcslice)
    where
      -- Handle the general case once both slices cover every dimension.
      copySliced destslice' srcslice'
        | destrank /= srcrank =
          error $
            concat
              [ "copyArrayDWIM: cannot copy to ",
                pretty (memLocationName destlocation),
                " from ",
                pretty (memLocationName srclocation),
                " because ranks do not match (",
                pretty destrank,
                " vs ",
                pretty srcrank,
                ")"
              ]
        | destlocation == srclocation && destslice' == srcslice' =
          -- Copying a slice onto itself would be a no-op.
          return mempty
        | otherwise =
          collect $ copy bt destlocation destslice' srclocation srcslice'
        where
          destrank = length $ sliceDims destslice'
          srcrank = length $ sliceDims srcslice'
copyDWIMDest ::
ValueDestination ->
[DimIndex (Imp.TExp Int64)] ->
SubExp ->
[DimIndex (Imp.TExp Int64)] ->
ImpM lore r op ()
copyDWIMDest :: forall lore r op.
ValueDestination
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIMDest ValueDestination
_ [DimIndex (TExp Int64)]
_ (Constant PrimValue
v) (DimIndex (TExp Int64)
_ : [DimIndex (TExp Int64)]
_) =
[Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
[[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: constant source", PrimValue -> [Char]
forall a. Pretty a => a -> [Char]
pretty PrimValue
v, [Char]
"cannot be indexed."]
copyDWIMDest ValueDestination
pat [DimIndex (TExp Int64)]
dest_slice (Constant PrimValue
v) [] =
case (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe (Shape (TExp Int64))
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
dest_slice of
Maybe (Shape (TExp Int64))
Nothing ->
[Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
[[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: constant source", PrimValue -> [Char]
forall a. Pretty a => a -> [Char]
pretty PrimValue
v, [Char]
"with slice destination."]
Just Shape (TExp Int64)
dest_is ->
case ValueDestination
pat of
ScalarDestination VName
name ->
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar VName
name (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ PrimValue -> Exp
forall v. PrimValue -> PrimExp v
Imp.ValueExp PrimValue
v
MemoryDestination {} ->
[Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
[[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: constant source", PrimValue -> [Char]
forall a. Pretty a => a -> [Char]
pretty PrimValue
v, [Char]
"cannot be written to memory destination."]
ArrayDestination (Just MemLocation
dest_loc) -> do
(VName
dest_mem, Space
dest_space, Count Elements (TExp Int64)
dest_i) <-
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall lore r op.
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLocation
dest_loc Shape (TExp Int64)
dest_is
Volatility
vol <- (Env lore r op -> Volatility) -> ImpM lore r op Volatility
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Volatility
forall lore r op. Env lore r op -> Volatility
envVolatility
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code op
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Imp.Write VName
dest_mem Count Elements (TExp Int64)
dest_i PrimType
bt Space
dest_space Volatility
vol (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ PrimValue -> Exp
forall v. PrimValue -> PrimExp v
Imp.ValueExp PrimValue
v
ArrayDestination Maybe MemLocation
Nothing ->
[Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error [Char]
"copyDWIMDest: ArrayDestination Nothing"
where
bt :: PrimType
bt = PrimValue -> PrimType
primValueType PrimValue
v
copyDWIMDest ValueDestination
dest [DimIndex (TExp Int64)]
dest_slice (Var VName
src) [DimIndex (TExp Int64)]
src_slice = do
VarEntry lore
src_entry <- VName -> ImpM lore r op (VarEntry lore)
forall lore r op. VName -> ImpM lore r op (VarEntry lore)
lookupVar VName
src
case (ValueDestination
dest, VarEntry lore
src_entry) of
(MemoryDestination VName
mem, MemVar Maybe (Exp lore)
_ (MemEntry Space
space)) ->
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> Space -> Code op
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
mem VName
src Space
space
(MemoryDestination {}, VarEntry lore
_) ->
[Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
[[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: cannot write", VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
src, [Char]
"to memory destination."]
(ValueDestination
_, MemVar {}) ->
[Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
[[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: source", VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
src, [Char]
"is a memory block."]
(ValueDestination
_, ScalarVar Maybe (Exp lore)
_ (ScalarEntry PrimType
_))
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [DimIndex (TExp Int64)]
src_slice ->
[Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
[[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: prim-typed source", VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
src, [Char]
"with slice", [DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
pretty [DimIndex (TExp Int64)]
src_slice]
(ScalarDestination VName
name, VarEntry lore
_)
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [DimIndex (TExp Int64)]
dest_slice ->
[Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
[[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: prim-typed target", VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
name, [Char]
"with slice", [DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
pretty [DimIndex (TExp Int64)]
dest_slice]
(ScalarDestination VName
name, ScalarVar Maybe (Exp lore)
_ (ScalarEntry PrimType
pt)) ->
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar VName
name (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
src PrimType
pt
(ScalarDestination VName
name, ArrayVar Maybe (Exp lore)
_ ArrayEntry
arr)
| Just Shape (TExp Int64)
src_is <- (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe (Shape (TExp Int64))
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
src_slice,
[DimIndex (TExp Int64)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex (TExp Int64)]
src_slice Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (ArrayEntry -> [SubExp]
entryArrayShape ArrayEntry
arr) -> do
let bt :: PrimType
bt = ArrayEntry -> PrimType
entryArrayElemType ArrayEntry
arr
(VName
mem, Space
space, Count Elements (TExp Int64)
i) <-
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall lore r op.
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' (ArrayEntry -> MemLocation
entryArrayLocation ArrayEntry
arr) Shape (TExp Int64)
src_is
Volatility
vol <- (Env lore r op -> Volatility) -> ImpM lore r op Volatility
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Volatility
forall lore r op. Env lore r op -> Volatility
envVolatility
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar VName
name (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
Imp.index VName
mem Count Elements (TExp Int64)
i PrimType
bt Space
space Volatility
vol
| Bool
otherwise ->
[Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
[[Char]] -> [Char]
unwords
[ [Char]
"copyDWIMDest: prim-typed target",
VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
name,
[Char]
"and array-typed source",
VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
src,
[Char]
"with slice",
[DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
pretty [DimIndex (TExp Int64)]
src_slice
]
(ArrayDestination (Just MemLocation
dest_loc), ArrayVar Maybe (Exp lore)
_ ArrayEntry
src_arr) -> do
let src_loc :: MemLocation
src_loc = ArrayEntry -> MemLocation
entryArrayLocation ArrayEntry
src_arr
bt :: PrimType
bt = ArrayEntry -> PrimType
entryArrayElemType ArrayEntry
src_arr
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ())
-> ImpM lore r op (Code op) -> ImpM lore r op ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PrimType
-> MemLocation
-> [DimIndex (TExp Int64)]
-> MemLocation
-> [DimIndex (TExp Int64)]
-> ImpM lore r op (Code op)
forall lore r op.
PrimType
-> MemLocation
-> [DimIndex (TExp Int64)]
-> MemLocation
-> [DimIndex (TExp Int64)]
-> ImpM lore r op (Code op)
copyArrayDWIM PrimType
bt MemLocation
dest_loc [DimIndex (TExp Int64)]
dest_slice MemLocation
src_loc [DimIndex (TExp Int64)]
src_slice
(ArrayDestination (Just MemLocation
dest_loc), ScalarVar Maybe (Exp lore)
_ (ScalarEntry PrimType
bt))
| Just Shape (TExp Int64)
dest_is <- (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe (Shape (TExp Int64))
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
dest_slice -> do
(VName
dest_mem, Space
dest_space, Count Elements (TExp Int64)
dest_i) <- MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall lore r op.
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLocation
dest_loc Shape (TExp Int64)
dest_is
Volatility
vol <- (Env lore r op -> Volatility) -> ImpM lore r op Volatility
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Volatility
forall lore r op. Env lore r op -> Volatility
envVolatility
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code op
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Imp.Write VName
dest_mem Count Elements (TExp Int64)
dest_i PrimType
bt Space
dest_space Volatility
vol (VName -> PrimType -> Exp
Imp.var VName
src PrimType
bt)
| Bool
otherwise ->
[Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
[[Char]] -> [Char]
unwords
[ [Char]
"copyDWIMDest: array-typed target and prim-typed source",
VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
src,
[Char]
"with slice",
[DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
pretty [DimIndex (TExp Int64)]
dest_slice
]
(ArrayDestination Maybe MemLocation
Nothing, VarEntry lore
_) ->
() -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
(ValueDestination
_, AccVar {}) ->
() -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
-- | Copy @src@ (indexed by @src_slice@) into the variable @dest@ (indexed
-- by @dest_slice@).  The kind of destination is chosen from how @dest@ is
-- currently bound in the variable table.
copyDWIM ::
  VName ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM lore r op ()
copyDWIM dest dest_slice src src_slice = do
  target <- destinationOf <$> lookupVar dest
  copyDWIMDest target dest_slice src src_slice
  where
    destinationOf (ScalarVar _ _) = ScalarDestination dest
    destinationOf (ArrayVar _ (ArrayEntry loc _)) = ArrayDestination (Just loc)
    destinationOf (MemVar _ _) = MemoryDestination dest
    -- Accumulators get an array destination with no memory location.
    destinationOf AccVar {} = ArrayDestination Nothing
-- | Like 'copyDWIM', but the destination and the source are each indexed
-- only by fixed (scalar) indices.
copyDWIMFix ::
  VName ->
  [Imp.TExp Int64] ->
  SubExp ->
  [Imp.TExp Int64] ->
  ImpM lore r op ()
copyDWIMFix dest dest_is src src_is =
  copyDWIM dest (asFixed dest_is) src (asFixed src_is)
  where
    asFixed = map DimFix
-- | Compile an allocation of @e@ bytes in @space@, binding the memory
-- block named by the single context-free element of the pattern.
--
-- The actual allocation is delegated to 'sAlloc_'.  It uses the
-- allocator registered for @space@ in the environment if there is one,
-- and otherwise emits a plain 'Imp.Allocate'.  Any other pattern shape
-- is an internal error.
compileAlloc ::
  Mem lore =>
  Pattern lore ->
  SubExp ->
  Space ->
  ImpM lore r op ()
compileAlloc (Pattern [] [mem]) e space =
  sAlloc_ (patElemName mem) (Imp.bytes $ toInt64Exp e) space
compileAlloc pat _ _ =
  error $ "compileAlloc: Invalid pattern: " ++ pretty pat
-- | The size in bytes of a value of type @t@: the byte size of its element
-- type multiplied by the product of its array dimensions.
typeSize :: Type -> Count Bytes (Imp.TExp Int64)
typeSize t = Imp.bytes (elem_bytes * num_elems)
  where
    elem_bytes = primByteSize (elemType t)
    num_elems = product [toInt64Exp d | d <- arrayDims t]
-- | Construct a condition that holds when every index of @slice@ lies
-- within the corresponding dimension in @dims@.
--
-- * A 'DimFix' index @i@ is in bounds when @0 <= i < d@.
-- * A 'DimSlice' @i n s@ is in bounds when its first element @i@ is
--   non-negative and its last element, @i + (n-1)*s@, is below @d@.
--
-- NOTE(review): the per-dimension conditions are combined with 'foldl1',
-- so @slice@ and @dims@ must be non-empty.  'zipWith' silently ignores
-- the surplus elements of the longer list.
inBounds :: Slice (Imp.TExp Int64) -> [Imp.TExp Int64] -> Imp.TExp Bool
inBounds slice dims =
  let condInBounds (DimFix i) d =
        0 .<=. i .&&. i .<. d
      condInBounds (DimSlice i n s) d =
        -- The last element touched is at offset (n-1)*s.  Using n*s would
        -- wrongly reject in-bounds slices such as the full slice 0:d.
        0 .<=. i .&&. i + (n - 1) * s .<. d
   in foldl1 (.&&.) $ zipWith condInBounds slice dims
-- | Emit an 'Imp.For' loop with loop variable @i@ and upper bound @bound@.
-- The loop body is the code generated by @body@.  The bound must have an
-- integer type, which is also registered as the type of @i@.
sFor' :: VName -> Imp.Exp -> ImpM lore r op () -> ImpM lore r op ()
sFor' i bound body = do
  addLoopVar i loop_int_type
  body_code <- collect body
  emit $ Imp.For i bound body_code
  where
    loop_int_type = case primExpType bound of
      IntType bound_t -> bound_t
      t -> error $ "sFor': bound " ++ pretty bound ++ " is of type " ++ pretty t
-- | Emit a loop up to @bound@ using a fresh loop variable derived from the
-- name @i@.  The body receives the loop variable as a typed expression
-- whose primitive type matches that of @bound@.
sFor :: String -> Imp.TExp t -> (Imp.TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor i bound body = do
  loop_var <- newVName i
  let raw_bound = untyped bound
      loop_exp = TPrimExp $ Imp.var loop_var $ primExpType raw_bound
  sFor' loop_var raw_bound (body loop_exp)
-- | Emit a while loop that repeats the code generated by @body@ for as long
-- as @cond@ holds.
sWhile :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhile cond body =
  collect body >>= emit . Imp.While cond
-- | Wrap the code generated by @code@ in an 'Imp.Comment' carrying the
-- text @s@.
sComment :: String -> ImpM lore r op () -> ImpM lore r op ()
sComment s code =
  collect code >>= emit . Imp.Comment s
-- | Emit a conditional.  The code generated by @tbranch@ runs when @cond@
-- holds, and the code generated by @fbranch@ runs otherwise.  The true
-- branch is generated before the false branch.
sIf :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf cond tbranch fbranch = do
  on_true <- collect tbranch
  on_false <- collect fbranch
  emit $ Imp.If cond on_true on_false
-- | Emit the code generated by @tbranch@, guarded so that it runs only when
-- @cond@ holds.
sWhen :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen cond tbranch = sIf cond tbranch (pure ())
-- | Emit the code generated by @fbranch@, guarded so that it runs only when
-- @cond@ does not hold.
sUnless :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sUnless cond fbranch = sIf cond (pure ()) fbranch
-- | Emit a single backend-specific operation as an 'Imp.Op' statement.
sOp :: op -> ImpM lore r op ()
sOp x = emit (Imp.Op x)
-- | Declare a fresh memory block, named after @name@, in @space@.  The
-- block is registered in the variable table and its name is returned.
-- No allocation is performed.
sDeclareMem :: String -> Space -> ImpM lore r op VName
sDeclareMem name space = do
  mem <- newVName name
  emit (Imp.DeclareMem mem space)
  addVar mem (MemVar Nothing (MemEntry space))
  pure mem
-- | Allocate @size'@ bytes in @space@ for the already-declared memory block
-- @name'@.  If the environment has an allocator registered for @space@, that
-- allocator is used.  Otherwise a plain 'Imp.Allocate' is emitted.
sAlloc_ :: VName -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM lore r op ()
sAlloc_ name' size' space =
  asks (M.lookup space . envAllocCompilers)
    >>= maybe
      (emit $ Imp.Allocate name' size' space)
      (\custom_alloc -> custom_alloc name' size')
-- | Declare a fresh memory block, named after @name@, in @space@, allocate
-- @size@ bytes for it, and return its name.
sAlloc :: String -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM lore r op VName
sAlloc name size space = do
  mem <- sDeclareMem name space
  sAlloc_ mem size space
  pure mem
-- | Declare a fresh array variable, named after @name@, with element type
-- @bt@, the given shape and memory binding.  Returns the new name.
sArray :: String -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op VName
sArray name bt shape membind = do
  arr <- newVName name
  dArray arr bt shape membind
  pure arr
-- | Declare an array, named after @name@, that lives in the existing memory
-- block @mem@.  Its index function is built by 'IxFun.iota' from the
-- dimensions of @shape@.
sArrayInMem :: String -> PrimType -> ShapeBase SubExp -> VName -> ImpM lore r op VName
sArrayInMem name pt shape mem =
  sArray name pt shape . ArrayIn mem . IxFun.iota $
    [isInt64 (primExpFromSubExp int64 d) | d <- shapeDims shape]
-- | Allocate a fresh memory block (named @name ++ "_mem"@) big enough for
-- an array of element type @pt@ and the given shape, and declare an array,
-- named after @name@, in it.
--
-- The index function is an 'IxFun.iota' over the dimensions rearranged by
-- @perm@, which is then permuted by the inverse of @perm@.
sAllocArrayPerm :: String -> PrimType -> ShapeBase SubExp -> Space -> [Int] -> ImpM lore r op VName
sAllocArrayPerm name pt shape space perm = do
  let layout_dims = rearrangeShape perm (shapeDims shape)
  mem <- sAlloc (name ++ "_mem") (typeSize (Array pt shape NoUniqueness)) space
  let base_ixfun = IxFun.iota [isInt64 (primExpFromSubExp int64 d) | d <- layout_dims]
  sArray name pt shape . ArrayIn mem . IxFun.permute base_ixfun $ rearrangeInverse perm
-- | Like 'sAllocArrayPerm', but with the identity permutation, so the array
-- gets a plain iota index function over its own shape.
sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM lore r op VName
sAllocArray name pt shape space =
  sAllocArrayPerm name pt shape space identityPerm
  where
    -- [0, 1, ..., rank-1]; empty for a rank-0 shape.
    identityPerm = take (shapeRank shape) [0 ..]
-- | Emit an 'Imp.DeclareArray' for a block called @name ++ "_mem"@ in
-- @space@ holding the contents @vs@ of element type @pt@.  Register that
-- block as a memory variable and bind a one-dimensional array called @name@
-- over it, with an iota index function.
--
-- The array's length is the number of listed values for
-- 'Imp.ArrayValues', or the given count for 'Imp.ArrayZeros'.
sStaticArray :: String -> Space -> PrimType -> Imp.ArrayContents -> ImpM lore r op VName
sStaticArray name space pt vs = do
  let num_elems = contentsLength vs
      shape = Shape [intConst Int64 (toInteger num_elems)]
  mem <- newVNameForFun (name ++ "_mem")
  emit (Imp.DeclareArray mem space pt vs)
  addVar mem (MemVar Nothing (MemEntry space))
  sArray name pt shape (ArrayIn mem (IxFun.iota [fromIntegral num_elems]))
  where
    contentsLength (Imp.ArrayValues xs) = length xs
    contentsLength (Imp.ArrayZeros n) = fromIntegral n
-- | Emit an 'Imp.Write' storing the scalar @v@ into array @arr@ at the
-- index @is@.  The index is resolved to a memory block, space and element
-- offset by 'fullyIndexArray'.  The write uses the current
-- 'envVolatility', and its element type is taken from @v@.
sWrite :: VName -> [Imp.TExp Int64] -> Imp.Exp -> ImpM lore r op ()
sWrite arr is v = do
  vol <- asks envVolatility
  (mem, space, offset) <- fullyIndexArray arr is
  emit (Imp.Write mem offset (primExpType v) space vol v)
-- | Store @v@ into the part of array @arr@ selected by @slice@, by
-- delegating to 'copyDWIM'.
sUpdate :: VName -> Slice (Imp.TExp Int64) -> SubExp -> ImpM lore r op ()
sUpdate arr slice v =
  -- The final argument is copyDWIM's slice of the source value.  It is left
  -- empty here, presumably meaning "use all of @v@" (verify against
  -- copyDWIM).
  copyDWIM arr slice v []
-- | Generate one 'sFor' loop (each named @nest_i@) per dimension of @shape@.
-- The first dimension is the outermost loop.  The innermost code is
-- produced by @body@, which receives the loop indices in dimension order.
-- A rank-0 shape runs @body []@ once, with no loops.
sLoopNest ::
  Shape ->
  ([Imp.TExp Int64] -> ImpM lore r op ()) ->
  ImpM lore r op ()
sLoopNest shape body = nest [] (shapeDims shape)
  where
    -- @outer@ collects the indices of the enclosing loops, innermost first,
    -- so it is reversed before being handed to @body@.
    nest outer [] = body (reverse outer)
    nest outer (d : ds) =
      sFor "nest_i" (toInt64Exp d) $ \i -> nest (i : outer) ds
-- | Untyped scalar assignment: emit an 'Imp.SetScalar' that stores the
-- expression @e@ into the variable @x@.
(<~~) :: VName -> Imp.Exp -> ImpM lore r op ()
(<~~) x e = emit (Imp.SetScalar x e)

infixl 3 <~~
-- | Typed scalar assignment: emit an 'Imp.SetScalar' that stores the typed
-- expression @e@ (after 'untyped') into the variable wrapped by the 'TV'.
-- The types line up because the 'TV' and the expression share the phantom
-- type @t@.
(<--) :: TV t -> Imp.TExp t -> ImpM lore r op ()
(<--) (TV x _) e = emit (Imp.SetScalar x (untyped e))

infixl 3 <--
-- | Build and emit an ad-hoc function named @fname@, not derived from any
-- function of the input program.
--
-- Inside the generator @m@:
--
--   * 'envFunction' is set to @fname@.
--   * Every output and input parameter is registered with 'addVar':
--     memory parameters as 'MemVar', scalar parameters as 'ScalarVar'.
--
-- The code @m@ emits becomes the function body.  The result is passed to
-- 'emitFunction' as an 'Imp.Function' with flag 'False' and empty external
-- value lists.  NOTE(review): the 'False' flag presumably marks a non-entry
-- function; confirm against ImpCode.
function ::
  Name ->
  [Imp.Param] ->
  [Imp.Param] ->
  ImpM lore r op () ->
  ImpM lore r op ()
function fname outputs inputs m =
  local (\env -> env {envFunction = Just fname}) $ do
    body <- collect (mapM_ declareParam (outputs ++ inputs) >> m)
    emitFunction fname (Imp.Function False outputs inputs body [] [])
  where
    declareParam (Imp.MemParam name space) =
      addVar name (MemVar Nothing (MemEntry space))
    declareParam (Imp.ScalarParam name bt) =
      addVar name (ScalarVar Nothing (ScalarEntry bt))