{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.IR.KernelsMem
( KernelsMem,
simplifyProg,
simplifyStms,
simpleKernelsMem,
module Futhark.IR.Mem,
module Futhark.IR.Kernels.Kernel,
)
where
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.IR.Kernels.Kernel
import Futhark.IR.Kernels.Simplify (simplifyKernelOp)
import Futhark.IR.Mem
import Futhark.IR.Mem.Simplify
import Futhark.MonadFreshNames
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (BinderOps (..), mkLetNamesB', mkLetNamesB'')
import qualified Futhark.TypeCheck as TC
-- | Type-level tag for this representation.  It has no values; it only
-- indexes the type-family and class instances defined in this module.
data KernelsMem

-- | Decorations for 'KernelsMem'.
--
-- Let-bound names, function parameters, lambda parameters, return types
-- and branch types each carry the corresponding memory-annotated decoration
-- ('LetDecMem', 'FParamMem', 'LParamMem', 'RetTypeMem', 'BranchTypeMem').
-- Operations are 'MemOp' wrapped around a 'HostOp' whose extra payload is
-- unit.
instance Decorations KernelsMem where
  type LetDec KernelsMem = LetDecMem
  type FParamInfo KernelsMem = FParamMem
  type LParamInfo KernelsMem = LParamMem
  type RetType KernelsMem = RetTypeMem
  type BranchType KernelsMem = BranchTypeMem
  type Op KernelsMem = MemOp (HostOp KernelsMem ())
-- | Expression types are derived from the pattern alone.  Take the second
-- list of @(name, returns)@ pairs produced by 'bodyReturnsFromPattern'
-- and drop the names.
--
-- NOTE(review): the second list is presumably the value (non-context)
-- part of the pattern; confirm against 'bodyReturnsFromPattern'.
instance ASTLore KernelsMem where
  expTypesFromPattern = return . map snd . snd . bodyReturnsFromPattern
-- | What each operation returns:
--
-- * an 'Alloc' returns a single 'MemMem' in the allocation's space;
--
-- * a 'SegOp' delegates to 'segOpReturns';
--
-- * any other operation falls back to its 'opType', converted with
--   'extReturns'.
instance OpReturns KernelsMem where
  opReturns (Alloc _ space) =
    return [MemMem space]
  opReturns (Inner (SegOp op)) = segOpReturns op
  opReturns k = extReturns <$> opType k
-- | Pretty-printing for 'KernelsMem' uses the class's default methods;
-- this instance overrides nothing.
instance PrettyLore KernelsMem
-- | Type checking of operations.
--
-- Checking starts with no enclosing segment level ('Nothing').  When a
-- 'HostOp' is checked, 'typeCheckHostOp' hands a 'SegLevel' back to the
-- recursive checker, which wraps it in 'Just' for the nested operations.
instance TC.CheckableOp KernelsMem where
  checkOp = typeCheckMemoryOp Nothing
    where
      -- An allocation is accepted when 'TC.require' confirms that its
      -- size has type @Prim int64@; the memory space is not checked here.
      typeCheckMemoryOp _ (Alloc size _) =
        TC.require [Prim int64] size
      -- Host operations go through 'typeCheckHostOp'.  Nested memory ops
      -- recurse back into this function, and the unit payload of 'HostOp'
      -- needs no checking, hence @const (return ())@.
      typeCheckMemoryOp lvl (Inner op) =
        typeCheckHostOp (typeCheckMemoryOp . Just) lvl (const $ return ()) op
-- | Type-checking hooks for 'KernelsMem'.
--
-- Function-parameter, lambda-parameter and let-bound decorations are all
-- checked with 'checkMemInfo'.  Return types are checked through their
-- declared extended type.  A primitive function parameter is built with a
-- 'MemPrim' decoration.  Pattern, return, branch and loop-result matching
-- delegate to the matchers that are generic over memory-annotated lores.
instance TC.Checkable KernelsMem where
  checkFParamLore = checkMemInfo
  checkLParamLore = checkMemInfo
  checkLetBoundLore = checkMemInfo
  checkRetType = mapM_ $ TC.checkExtType . declExtTypeOf
  primFParam name t = return $ Param name (MemPrim t)
  matchPattern = matchPatternToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType
  matchLoopResult = matchLoopResultMem
-- | Construction helpers for plain 'KernelsMem' code.
--
-- Expression and body decorations are both unit.  Let-bindings are created
-- by 'mkLetNamesB'' with a unit expression decoration.
instance BinderOps KernelsMem where
  mkExpDecB _ _ = return ()
  mkBodyB stms res = return $ Body () stms res
  mkLetNamesB = mkLetNamesB' ()
-- | Construction helpers for simplifier ('Engine.Wise') code.
--
-- The unit expression and body decorations of 'KernelsMem' are wrapped by
-- 'Engine.mkWiseExpDec' and 'Engine.mkWiseBody'.  Let-bindings are
-- created by 'mkLetNamesB'''.
instance BinderOps (Engine.Wise KernelsMem) where
  mkExpDecB pat e = return $ Engine.mkWiseExpDec pat () e
  mkBodyB stms res = return $ Engine.mkWiseBody () stms res
  mkLetNamesB = mkLetNamesB''
-- | Simplify a whole program, using 'simpleKernelsMem' as the set of
-- simplification operations.
simplifyProg :: Prog KernelsMem -> PassM (Prog KernelsMem)
simplifyProg = simplifyProgGeneric simpleKernelsMem
-- | Simplify a sequence of statements, using 'simpleKernelsMem' as the set
-- of simplification operations.  The simplified statements are returned
-- together with an 'Engine.SymbolTable'.
--
-- NOTE(review): the table is presumably the one in effect after the
-- statements; confirm against 'simplifyStmsGeneric'.
simplifyStms ::
  (HasScope KernelsMem m, MonadFreshNames m) =>
  Stms KernelsMem ->
  m
    ( Engine.SymbolTable (Engine.Wise KernelsMem),
      Stms KernelsMem
    )
simplifyStms = simplifyStmsGeneric simpleKernelsMem
-- | Simplification operations for 'KernelsMem'.
--
-- Host operations are simplified with 'simplifyKernelOp'.  Their unit
-- payload is left unchanged and produces no extra statements.
--
-- The usage function handles only a 'SegMap' at 'SegGroup' level.  For
-- that case it reports a size usage ('UT.sizeUsage') for every variable
-- used as the size of an 'Alloc' in the @"local"@ space.  Only the kernel
-- body's top-level statements ('kernelBodyStms') are inspected.  Every
-- other operation, and every constant-sized allocation, contributes no
-- extra usage.
--
-- NOTE(review): this presumably exists so the simplifier treats those
-- allocation sizes as size-like and may move their computation out of the
-- kernel; confirm against how the simplifier consumes size usages.
simpleKernelsMem :: Engine.SimpleOps KernelsMem
simpleKernelsMem =
  simpleGeneric usage $ simplifyKernelOp $ const $ return ((), mempty)
  where
    usage (SegOp (SegMap SegGroup {} _ _ kbody)) = localAllocs kbody
    usage _ = mempty

    -- Combine the usages of all top-level statements of a kernel body.
    localAllocs = foldMap stmLocalAlloc . kernelBodyStms

    stmLocalAlloc = expLocalAlloc . stmExp

    -- Only variable-sized allocations in the "local" space count.
    expLocalAlloc (Op (Alloc (Var v) (Space "local"))) =
      UT.sizeUsage v
    expLocalAlloc _ =
      mempty