{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Futhark.IR.SegOp
( SegOp (..),
SegVirt (..),
segLevel,
segSpace,
typeCheckSegOp,
SegSpace (..),
scopeOfSegSpace,
segSpaceDims,
HistOp (..),
histType,
SegBinOp (..),
segBinOpResults,
segBinOpChunks,
KernelBody (..),
aliasAnalyseKernelBody,
consumedInKernelBody,
ResultManifest (..),
KernelResult (..),
kernelResultSubExp,
SplitOrdering (..),
SegOpMapper (..),
identitySegOpMapper,
mapSegOpM,
simplifySegOp,
HasSegOp (..),
segOpRules,
segOpReturns,
)
where
import Control.Category
import Control.Monad.Identity hiding (mapM_)
import Control.Monad.State.Strict
import Control.Monad.Writer hiding (mapM_)
import Data.Bifunctor (first)
import Data.Bitraversable
import Data.List
( elemIndex,
foldl',
groupBy,
intersperse,
isPrefixOf,
partition,
)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Futhark.Analysis.Alias as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.IR
import Futhark.IR.Aliases
( Aliases,
removeLambdaAliases,
removeStmAliases,
)
import Futhark.IR.Mem
import Futhark.IR.Prop.Aliases
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Lore
import Futhark.Optimise.Simplify.Rule
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import qualified Futhark.TypeCheck as TC
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty
( Pretty,
commasep,
parens,
ppr,
text,
(<+>),
(</>),
)
import qualified Futhark.Util.Pretty as PP
import Prelude hiding (id, (.))
-- | How the chunks of a 'ConcatReturns' result are laid out: either as
-- one contiguous split, or strided by the given 'SubExp'.
-- NOTE(review): the precise chunk layout is decided by consumers of
-- this type that are not visible here.
data SplitOrdering
  = SplitContiguous
  | SplitStrided SubExp
  deriving (Eq, Ord, Show)
-- | A contiguous split mentions no names; a strided split mentions
-- exactly the names free in its stride.
instance FreeIn SplitOrdering where
  freeIn' SplitContiguous = mempty
  freeIn' (SplitStrided stride) = freeIn' stride
-- | Only the stride of a 'SplitStrided' can contain names, so that is
-- the only part the substitution touches.
instance Substitute SplitOrdering where
  substituteNames _ SplitContiguous = SplitContiguous
  substituteNames subst (SplitStrided stride) =
    SplitStrided $ substituteNames subst stride
-- | Renaming is the identity on 'SplitContiguous' and renames the
-- stride of a 'SplitStrided'.
instance Rename SplitOrdering where
  rename SplitContiguous = pure SplitContiguous
  rename (SplitStrided stride) = SplitStrided <$> rename stride
-- | A histogram operator.
-- NOTE(review): the field descriptions below are inferred from the
-- field names and from 'histType'; confirm against the code generator.
data HistOp lore = HistOp
  { -- | Presumably the number of bins; 'histType' uses it as the
    -- outermost dimension of each result.
    histWidth :: SubExp,
    -- | Presumably a hint about how contended concurrent updates are.
    histRaceFactor :: SubExp,
    -- | Destination arrays (presumably one per operator result).
    histDest :: [VName],
    -- | Neutral elements (presumably one per operator result).
    histNeutral :: [SubExp],
    -- | Extra dimensions wrapped around each operator result type by
    -- 'histType'.
    histShape :: Shape,
    -- | The combining operator.
    histOp :: Lambda lore
  }
  deriving (Eq, Ord, Show)
-- | The result types of a 'HistOp'.  Each return type of 'histOp' is
-- first extended with the dimensions of 'histShape' (via
-- 'arrayOfShape') and then with a further dimension of size
-- 'histWidth' (via 'arrayOfRow').
histType :: HistOp lore -> [Type]
histType op = map histogramOf $ lambdaReturnType $ histOp op
  where
    histogramOf t = (t `arrayOfShape` histShape op) `arrayOfRow` histWidth op
-- | A binary operator together with its commutativity, neutral
-- elements and shape.
data SegBinOp lore = SegBinOp
  { segBinOpComm :: Commutativity,
    segBinOpLambda :: Lambda lore,
    -- | The length of this list is what 'segBinOpResults' and
    -- 'segBinOpChunks' count as the operator's number of results.
    segBinOpNeutral :: [SubExp],
    -- | NOTE(review): presumably extra inner dimensions over which the
    -- operator is applied; confirm.
    segBinOpShape :: Shape
  }
  deriving (Eq, Ord, Show)
-- | The total number of results produced by a list of 'SegBinOp's,
-- counting one result per neutral element of each operator.
segBinOpResults :: [SegBinOp lore] -> Int
segBinOpResults = sum . map (length . segBinOpNeutral)
-- | Split a flat list into consecutive groups, one per 'SegBinOp'.
-- Each group is as long as that operator's list of neutral elements.
-- NOTE(review): if the input length differs from 'segBinOpResults',
-- the outcome is whatever 'chunks' does.
segBinOpChunks :: [SegBinOp lore] -> [a] -> [[a]]
segBinOpChunks = chunks . map (length . segBinOpNeutral)
-- | A body decoration, a sequence of statements, and a list of
-- 'KernelResult's.  Unlike an ordinary 'Body', the results are
-- 'KernelResult's rather than plain 'SubExp's.
data KernelBody lore = KernelBody
  { kernelBodyLore :: BodyDec lore,
    kernelBodyStms :: Stms lore,
    kernelBodyResult :: [KernelResult]
  }

deriving instance Decorations lore => Ord (KernelBody lore)

deriving instance Decorations lore => Show (KernelBody lore)

deriving instance Decorations lore => Eq (KernelBody lore)
-- | Extra metadata attached to a 'Returns' result.
-- NOTE(review): the per-constructor descriptions are inferred from the
-- constructor names only; confirm against the simplifier.
data ResultManifest
  = -- | Presumably: the simplifier must leave this result alone.
    ResultNoSimplify
  | -- | Presumably: the simplifier may simplify this result freely.
    ResultMaySimplify
  | -- | Presumably: the result is private to the producing thread.
    ResultPrivate
  deriving (Eq, Show, Ord)
-- | One result of a 'KernelBody'.
data KernelResult
  = -- | Return a single value, tagged with 'ResultManifest' metadata.
    Returns ResultManifest SubExp
  | -- | Write into an existing array, which counts as consumed (see
    -- 'consumedInKernelBody').  NOTE(review): presumably the 'Shape' is
    -- the array size and the list holds index/value pairs.
    WriteReturns
      Shape
      VName
      [(Slice SubExp, SubExp)]
  | -- | NOTE(review): presumably a concatenation of per-worker chunks.
    -- The fields appear to be: split ordering, total size, per-worker
    -- element count, and the chunk array.
    ConcatReturns
      SplitOrdering
      SubExp
      SubExp
      VName
  | -- | NOTE(review): presumably a tile of an array, with one pair of
    -- sizes per dimension.
    TileReturns
      [(SubExp, SubExp)]
      VName
  | -- | NOTE(review): presumably a register-tiled result, with one
    -- triple of sizes per dimension.
    RegTileReturns
      [(SubExp, SubExp, SubExp)]
      VName
  deriving (Eq, Show, Ord)
-- | The 'SubExp' that stands for a kernel result.  For 'Returns' it is
-- the returned value; every other constructor yields a 'Var' of the
-- array name it carries.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp result =
  case result of
    Returns _ se -> se
    WriteReturns _ arr _ -> Var arr
    ConcatReturns _ _ _ v -> Var v
    TileReturns _ v -> Var v
    RegTileReturns _ v -> Var v
-- | The free names of a result are the union of the free names of its
-- components.  The 'ResultManifest' of 'Returns' carries no names.
instance FreeIn KernelResult where
  freeIn' (Returns _ what) = freeIn' what
  freeIn' (WriteReturns rws arr res) =
    freeIn' rws <> freeIn' arr <> freeIn' res
  freeIn' (ConcatReturns o w per_thread_elems v) =
    freeIn' o <> freeIn' w <> freeIn' per_thread_elems <> freeIn' v
  freeIn' (TileReturns dims v) =
    freeIn' dims <> freeIn' v
  freeIn' (RegTileReturns dims_n_tiles v) =
    freeIn' dims_n_tiles <> freeIn' v
-- | Collects the names free in the decoration, statements and results,
-- then passes them through 'fvBind' with every name bound by the
-- statements.  Presumably this removes the locally bound names.
instance ASTLore lore => FreeIn (KernelBody lore) where
  freeIn' (KernelBody dec stms res) =
    fvBind bound_in_stms $ freeIn' dec <> freeIn' stms <> freeIn' res
    where
      bound_in_stms = foldMap boundByStm stms
-- | Applies the substitution to the decoration, the statements and the
-- results alike.
instance ASTLore lore => Substitute (KernelBody lore) where
  substituteNames subst (KernelBody dec stms res) =
    KernelBody
      (substituteNames subst dec)
      (substituteNames subst stms)
      (substituteNames subst res)
-- | Substitutes in every name-carrying component of a result.  The
-- 'ResultManifest' of 'Returns' is kept as is.
instance Substitute KernelResult where
  substituteNames subst (Returns manifest se) =
    Returns manifest (substituteNames subst se)
  substituteNames subst (WriteReturns rws arr res) =
    WriteReturns
      (substituteNames subst rws)
      (substituteNames subst arr)
      (substituteNames subst res)
  substituteNames subst (ConcatReturns o w per_thread_elems v) =
    ConcatReturns
      (substituteNames subst o)
      (substituteNames subst w)
      (substituteNames subst per_thread_elems)
      (substituteNames subst v)
  substituteNames subst (TileReturns dims v) =
    TileReturns (substituteNames subst dims) (substituteNames subst v)
  substituteNames subst (RegTileReturns dims_n_tiles v) =
    RegTileReturns
      (substituteNames subst dims_n_tiles)
      (substituteNames subst v)
-- | Renames the decoration first, then the statements.  The results are
-- renamed inside the scope of the renamed statements, so references to
-- names bound by the statements pick up their new names.
instance ASTLore lore => Rename (KernelBody lore) where
  rename (KernelBody dec stms res) = do
    dec' <- rename dec
    renamingStms stms $ \stms' ->
      KernelBody dec' stms' <$> rename res
-- | Results bind no names of their own, so renaming is done through the
-- 'Substitute' instance.
instance Rename KernelResult where
  rename = substituteRename
-- | Alias-analyse the statements of a 'KernelBody'.  The statements are
-- wrapped in an ordinary 'Body' with an empty result list and run
-- through 'Alias.analyseBody'.  The annotated decoration and statements
-- are kept, and the original 'KernelResult's are reattached unchanged.
aliasAnalyseKernelBody ::
  ( ASTLore lore,
    CanBeAliased (Op lore)
  ) =>
  AliasTable ->
  KernelBody lore ->
  KernelBody (Aliases lore)
aliasAnalyseKernelBody aliases (KernelBody dec stms res) =
  let Body dec' stms' _ = Alias.analyseBody aliases $ Body dec stms []
   in KernelBody dec' stms' res
-- | Strip alias annotations from a 'KernelBody'.  The aliasing half of
-- the body decoration is dropped, and 'removeStmAliases' is applied to
-- every statement.  The results are kept unchanged.
removeKernelBodyAliases ::
  CanBeAliased (Op lore) =>
  KernelBody (Aliases lore) ->
  KernelBody lore
removeKernelBodyAliases (KernelBody (_, dec) stms res) =
  KernelBody dec (fmap removeStmAliases stms) res
-- | Strip simplifier wisdom from a 'KernelBody'.  The decoration and
-- statements are wrapped in an ordinary 'Body' with no results and
-- passed to 'removeBodyWisdom'.  The original 'KernelResult's are
-- reattached unchanged.
removeKernelBodyWisdom ::
  CanBeWise (Op lore) =>
  KernelBody (Wise lore) ->
  KernelBody lore
removeKernelBodyWisdom (KernelBody dec stms res) =
  let Body dec' stms' _ = removeBodyWisdom $ Body dec stms []
   in KernelBody dec' stms' res
-- | The names consumed by a kernel body.
--
-- This is everything consumed by its statements (computed with
-- 'consumedInBody' on an equivalent 'Body' with no results), plus the
-- destination array of every 'WriteReturns' result.
consumedInKernelBody ::
  Aliased lore =>
  KernelBody lore ->
  Names
consumedInKernelBody (KernelBody dec stms res) =
  fromStms <> fromResults
  where
    fromStms = consumedInBody $ Body dec stms []
    fromResults = foldMap writtenArray res

    -- Only in-place writes consume anything among the results.
    writtenArray (WriteReturns _ arr _) = oneName arr
    writtenArray _ = mempty
-- | Type-check a kernel body against the expected per-result types @ts@.
--
-- The checks run in this order:
--
-- 1. The body decoration is checked with 'TC.checkBodyLore'.
-- 2. The aliases of every 'WriteReturns' destination are marked as
--    consumed.
-- 3. The statements are checked with 'TC.checkStms'.  Inside their scope,
--    the number of results is compared with @ts@ and each result is
--    checked against its expected type.
checkKernelBody ::
  TC.Checkable lore =>
  [Type] ->
  KernelBody (Aliases lore) ->
  TC.TypeM lore ()
checkKernelBody ts (KernelBody (_, dec) stms kres) = do
  TC.checkBodyLore dec
  -- NOTE(review): destinations are consumed before (outside the scope of)
  -- the statements being checked.
  mapM_ consumeKernelResult kres
  TC.checkStms stms $ do
    unless (length ts == length kres) $
      TC.bad $
        TC.TypeError $
          "Kernel return type is " ++ prettyTuple ts
            ++ ", but body returns "
            ++ show (length kres)
            ++ " values."
    zipWithM_ checkKernelResult kres ts
  where
    -- An in-place write consumes its destination array and all aliases.
    consumeKernelResult (WriteReturns _ arr _) =
      TC.consume =<< TC.lookupAliases arr
    consumeKernelResult _ =
      pure ()

    checkKernelResult (Returns _ what) t =
      TC.require [t] what
    checkKernelResult (WriteReturns shape arr res) t = do
      mapM_ (TC.require [Prim int64]) $ shapeDims shape
      arr_t <- lookupType arr
      forM_ res $ \(slice, e) -> do
        -- Every index in the slice must be an int64.
        mapM_ (traverse $ TC.require [Prim int64]) slice
        TC.require [t] e
        unless (arr_t == t `arrayOfShape` shape) $
          TC.bad $
            TC.TypeError $
              "WriteReturns returning "
                ++ pretty e
                ++ " of type "
                ++ pretty t
                ++ ", shape="
                ++ pretty shape
                ++ ", but destination array has type "
                ++ pretty arr_t
    checkKernelResult (ConcatReturns o w per_thread_elems v) t = do
      case o of
        SplitContiguous -> return ()
        SplitStrided stride -> TC.require [Prim int64] stride
      TC.require [Prim int64] w
      TC.require [Prim int64] per_thread_elems
      vt <- lookupType v
      -- The returned array must be rows of @t@; its outer size is not
      -- constrained here.
      unless (vt == t `arrayOfRow` arraySize 0 vt) $
        TC.bad $ TC.TypeError $ "Invalid type for ConcatReturns " ++ pretty v
    checkKernelResult (TileReturns dims v) t = do
      forM_ dims $ \(dim, tile) -> do
        TC.require [Prim int64] dim
        TC.require [Prim int64] tile
      vt <- lookupType v
      -- The per-thread value has the tile sizes as its shape.
      unless (vt == t `arrayOfShape` Shape (map snd dims)) $
        TC.bad $ TC.TypeError $ "Invalid type for TileReturns " ++ pretty v
    checkKernelResult (RegTileReturns dims_n_tiles arr) t = do
      mapM_ (TC.require [Prim int64]) dims
      mapM_ (TC.require [Prim int64]) blk_tiles
      mapM_ (TC.require [Prim int64]) reg_tiles
      arr_t <- lookupType arr
      unless (arr_t == expected) $
        -- Name the construct actually being checked (the original
        -- message said "TileReturns", which is a different constructor).
        TC.bad . TC.TypeError $
          "Invalid type for RegTileReturns. Expected:\n "
            ++ pretty expected
            ++ ",\ngot:\n "
            ++ pretty arr_t
      where
        (dims, blk_tiles, reg_tiles) = unzip3 dims_n_tiles
        -- Block tiles are followed by register tiles in the shape.
        expected = t `arrayOfShape` Shape (blk_tiles ++ reg_tiles)
-- | Record metrics for every statement of a kernel body.  The kernel
-- results themselves contribute nothing.
kernelBodyMetrics :: OpMetrics (Op lore) => KernelBody lore -> MetricsM ()
kernelBodyMetrics kbody =
  forM_ (kernelBodyStms kbody) stmMetrics
-- | A kernel body is rendered as its statements stacked vertically,
-- followed by @return {r1, r2, ...}@ listing the kernel results.  The
-- body decoration is not shown.
instance PrettyLore lore => Pretty (KernelBody lore) where
  ppr (KernelBody _ stms res) =
    stmsDoc </> resultDoc
    where
      stmsDoc = PP.stack [ppr stm | stm <- stmsToList stms]
      resultDoc = text "return" <+> PP.braces (PP.commasep (map ppr res))
-- | Textual rendering of a single kernel result:
--
-- * 'Returns': a keyword that depends on the 'ResultManifest', then the
--   returned value.
-- * 'WriteReturns': the destination and its shape, then @with@ and the
--   @[slice] = value@ writes.
-- * Concat and tile variants: a keyword with parenthesised size
--   arguments, then the returned variable.
instance Pretty KernelResult where
  ppr (Returns manifest what) =
    text keyword <+> ppr what
    where
      keyword = case manifest of
        ResultNoSimplify -> "returns (manifest)"
        ResultPrivate -> "returns (private)"
        ResultMaySimplify -> "returns"
  ppr (WriteReturns shape arr res) =
    destination </> text "with" <+> PP.apply (map ppWrite res)
    where
      destination = ppr arr <+> PP.colon <+> ppr shape
      ppWrite (slice, e) =
        PP.brackets (commasep (map ppr slice)) <+> text "=" <+> ppr e
  ppr (ConcatReturns o w per_thread_elems v) =
    text keyword <> parens (commasep sizes) <+> ppr v
    where
      (keyword, sizes) = case o of
        SplitContiguous ->
          ("concat", [ppr w, ppr per_thread_elems])
        SplitStrided stride ->
          ("concat_strided", [ppr stride, ppr w, ppr per_thread_elems])
  ppr (TileReturns dims v) =
    "tile" <> parens (commasep (map pprDim dims)) <+> ppr v
    where
      pprDim (dim, tile) = ppr dim <+> "/" <+> ppr tile
  ppr (RegTileReturns dims_n_tiles v) =
    "blkreg_tile" <> parens (commasep (map pprDim dims_n_tiles)) <+> ppr v
    where
      pprDim (dim, blk_tile, reg_tile) =
        ppr dim <+> "/" <+> parens (ppr blk_tile <+> "*" <+> ppr reg_tile)
-- | Virtualisation mode of a segmented operation.
--
-- Only the constructor names are visible here.  Presumably 'SegVirt' asks
-- for virtualisation, 'SegNoVirt' states that none is needed, and
-- 'SegNoVirtFull' makes a stronger no-virtualisation promise.
-- NOTE(review): confirm against the code generator.
--
-- Constructor order determines the derived 'Ord' instance.
data SegVirt
  = SegVirt
  | SegNoVirt
  | SegNoVirtFull
  deriving (Eq, Ord, Show)
-- | The iteration space of a segmented operation.
--
-- 'scopeOfSegSpace' binds 'segFlat' and every index name of 'unSegSpace'
-- as an @int64@ index.  'checkSegSpace' requires every extent to have
-- type @int64@.
data SegSpace = SegSpace
  { -- | One extra index name, bound as @phys@ in 'scopeOfSegSpace'.
    -- Presumably this is the flattened thread index.
    -- NOTE(review): confirm.
    segFlat :: VName,
    -- | One @(index name, extent)@ pair per dimension, outermost first.
    unSegSpace :: [(VName, SubExp)]
  }
  deriving (Eq, Ord, Show)
-- | The extents of all dimensions of a segment space, in order.
segSpaceDims :: SegSpace -> [SubExp]
segSpaceDims = map snd . unSegSpace
-- | The scope introduced by a segment space.
--
-- Both the flat index name and every per-dimension index name are bound
-- as @int64@ index names.  If names repeat, 'M.fromList' keeps the last
-- binding.
scopeOfSegSpace :: SegSpace -> Scope lore
scopeOfSegSpace (SegSpace phys space) =
  M.fromList [(name, IndexName Int64) | name <- phys : map fst space]
-- | Check that every dimension extent of a segment space has type
-- @int64@.  The flat index name is not checked.
checkSegSpace :: TC.Checkable lore => SegSpace -> TC.TypeM lore ()
checkSegSpace (SegSpace _ dims) =
  forM_ dims $ \(_, extent) -> TC.require [Prim int64] extent
-- | A segmented operation, parameterised over a level annotation @lvl@
-- and a lore.
--
-- Every constructor carries the level, the iteration space, the result
-- types of the kernel body, and the kernel body itself.  'segOpType'
-- extends these result types with the appropriate outer dimensions.
-- 'SegRed', 'SegScan' and 'SegHist' also carry their operators.
data SegOp lvl lore
  = SegMap lvl SegSpace [Type] (KernelBody lore)
  | -- | 'segOpType' takes all but the last dimension of the space as the
    -- segment dimensions of the reduction results.  The space is
    -- therefore assumed to be non-empty.
    SegRed lvl SegSpace [SegBinOp lore] [Type] (KernelBody lore)
  | SegScan lvl SegSpace [SegBinOp lore] [Type] (KernelBody lore)
  | SegHist lvl SegSpace [HistOp lore] [Type] (KernelBody lore)
  deriving (Eq, Ord, Show)
-- | The level annotation of a segmented operation.  This is always the
-- first constructor field.
segLevel :: SegOp lvl lore -> lvl
segLevel op =
  case op of
    SegMap level _ _ _ -> level
    SegRed level _ _ _ _ -> level
    SegScan level _ _ _ _ -> level
    SegHist level _ _ _ _ -> level
-- | The iteration space of a segmented operation.  This is always the
-- second constructor field.
segSpace :: SegOp lvl lore -> SegSpace
segSpace op =
  case op of
    SegMap _ space _ _ -> space
    SegRed _ space _ _ _ -> space
    SegScan _ space _ _ _ -> space
    SegHist _ space _ _ _ -> space
-- | The full type of one kernel result, given its per-thread type @t@:
--
-- * 'WriteReturns' yields @t@ with the destination shape.
-- * 'Returns' yields @t@ with every dimension of the segment space added
--   (outermost first).
-- * 'ConcatReturns' yields one outer dimension of size @w@.
-- * 'TileReturns' and 'RegTileReturns' yield @t@ with the full
--   (untiled) dimensions as shape.
segResultShape :: SegSpace -> Type -> KernelResult -> Type
segResultShape space t kres =
  case kres of
    WriteReturns shape _ _ ->
      t `arrayOfShape` shape
    Returns _ _ ->
      foldr (\dim acc -> acc `arrayOfRow` dim) t (segSpaceDims space)
    ConcatReturns _ w _ _ ->
      t `arrayOfRow` w
    TileReturns dims _ ->
      t `arrayOfShape` Shape [dim | (dim, _) <- dims]
    RegTileReturns dims_n_tiles _ ->
      t `arrayOfShape` Shape [dim | (dim, _, _) <- dims_n_tiles]
-- | The result types of a 'SegOp'.
--
-- * 'SegMap': every body result, shaped by the segment space.
--
-- * 'SegRed' / 'SegScan': one array per operator result first, followed
--   by the remaining (mapped) body results.  Reduction results range
--   over all but the innermost dimension of the segment space, scan
--   results over the whole space; in both cases the operator's own
--   'Shape' is appended.
--
-- * 'SegHist': one array per operator result, of shape
--   @segment_dims ++ [histWidth] ++ histShape@.  The declared body
--   types are not part of the result.
segOpType :: SegOp lvl lore -> [Type]
segOpType segop =
  case segop of
    SegMap _ space ts kbody ->
      zipWith (segResultShape space) ts (kernelBodyResult kbody)
    SegRed _ space reds ts kbody ->
      withBinOps (init (segSpaceDims space)) space reds ts kbody
    SegScan _ space scans ts kbody ->
      withBinOps (segSpaceDims space) space scans ts kbody
    SegHist _ space hists _ _ ->
      concatMap (histResultTypes (init (segSpaceDims space))) hists
  where
    -- Operator results come first; everything after them is a mapped
    -- result whose type is derived from the matching body result.
    withBinOps outer_dims space binops ts kbody =
      let binop_ts = concatMap (binOpResultTypes outer_dims) binops
          n = length binop_ts
       in binop_ts
            ++ zipWith
              (segResultShape space)
              (drop n ts)
              (drop n (kernelBodyResult kbody))

    binOpResultTypes outer_dims binop =
      let shape = Shape outer_dims <> segBinOpShape binop
       in map (`arrayOfShape` shape) (lambdaReturnType (segBinOpLambda binop))

    histResultTypes segment_dims hist =
      let shape = Shape (segment_dims <> [histWidth hist]) <> histShape hist
       in map (`arrayOfShape` shape) (lambdaReturnType (histOp hist))
-- | The type of a 'SegOp' is always statically known: its 'segOpType'
-- lifted to existential shapes.
instance TypedOp (SegOp lvl lore) where
  opType segop = pure $ staticShapes $ segOpType segop
-- | Alias information for 'SegOp's.
--
-- Every result is reported as aliasing nothing.  Consumption comes from
-- the kernel body and, for 'SegHist', additionally from the histogram
-- destination arrays.
instance
  (ASTLore lore, Aliased lore, ASTConstraints lvl) =>
  AliasedOp (SegOp lvl lore)
  where
  opAliases segop = [mempty | _ <- segOpType segop]

  consumedInOp segop =
    case segop of
      SegMap _ _ _ kbody -> consumedInKernelBody kbody
      SegRed _ _ _ _ kbody -> consumedInKernelBody kbody
      SegScan _ _ _ _ kbody -> consumedInKernelBody kbody
      SegHist _ _ ops _ kbody ->
        -- The destination arrays of a histogram are updated in place.
        namesFromList (concatMap histDest ops) <> consumedInKernelBody kbody
-- | Type check a 'SegOp'.
--
-- The first argument checks the level annotation; it runs before any
-- other check.  'SegMap', 'SegRed' and 'SegScan' are delegated to
-- 'checkScanRed'.  'SegHist' is checked here:
--
-- * widths, race factors and vector dimensions must be @int64@;
--
-- * the operator's return type must equal the neutral element types;
--
-- * each destination must be an array of shape
--   @segment_dims ++ [width] ++ shape@ and is consumed;
--
-- * the declared body types must be one @int64@ per operator, followed
--   by the per-operator value types.
typeCheckSegOp ::
  TC.Checkable lore =>
  (lvl -> TC.TypeM lore ()) ->
  SegOp lvl (Aliases lore) ->
  TC.TypeM lore ()
typeCheckSegOp checkLvl segop =
  case segop of
    SegMap lvl space ts kbody -> do
      checkLvl lvl
      checkScanRed space [] ts kbody
    SegRed lvl space reds ts kbody -> do
      checkLvl lvl
      checkScanRed space (map binOpParts reds) ts kbody
    SegScan lvl space scans ts kbody -> do
      checkLvl lvl
      checkScanRed space (map binOpParts scans) ts kbody
    SegHist lvl space ops ts kbody -> do
      checkLvl lvl
      checkSegSpace space
      mapM_ TC.checkType ts
      let segment_dims = init $ segSpaceDims space
      TC.binding (scopeOfSegSpace space) $ do
        -- Check each operator and collect the (vector-shaped) value
        -- types it expects from the body.
        op_value_ts <- forM ops $ \(HistOp dest_w rf dests nes shape op) -> do
          mapM_ (TC.require [Prim int64]) [dest_w, rf]
          nes' <- mapM TC.checkArg nes
          mapM_ (TC.require [Prim int64]) $ shapeDims shape

          -- The operator is applied to two copies of the neutral
          -- elements, with the vector dimensions stripped off.
          let stripVecDims = stripArray $ shapeRank shape
              op_args = map (TC.noArgAliases . first stripVecDims) nes'
          TC.checkLambda op $ op_args ++ op_args

          let nes_t = map TC.argType nes'
              op_ret = lambdaReturnType op
          unless (nes_t == op_ret) $
            TC.bad . TC.TypeError $
              "SegHist operator has return type "
                ++ prettyTuple op_ret
                ++ " but neutral element has type "
                ++ prettyTuple nes_t

          -- Destinations must have the full histogram shape and are
          -- consumed (together with their aliases).
          let dest_shape = Shape (segment_dims <> [dest_w]) <> shape
          forM_ (zip nes_t dests) $ \(t, dest) -> do
            TC.requireI [t `arrayOfShape` dest_shape] dest
            TC.lookupAliases dest >>= TC.consume

          pure $ map (`arrayOfShape` shape) nes_t

        checkKernelBody ts kbody

        -- One int64 bucket index per operator, then the values to write.
        let bucket_ret_t =
              replicate (length ops) (Prim int64) ++ concat op_value_ts
        unless (bucket_ret_t == ts) $
          TC.bad . TC.TypeError $
            "SegHist body has return type "
              ++ prettyTuple ts
              ++ " but should have type "
              ++ prettyTuple bucket_ret_t
  where
    -- The (lambda, neutral elements, vector shape) view of a binary
    -- operator that 'checkScanRed' consumes.
    binOpParts binop =
      (segBinOpLambda binop, segBinOpNeutral binop, segBinOpShape binop)
-- | Shared type checker for 'SegMap', 'SegRed' and 'SegScan'.
--
-- Each operator is given as its lambda, its neutral elements, and the
-- shape by which it is vectorised.  Within the scope of the segment
-- space:
--
-- * every vector dimension must be @int64@;
--
-- * the lambda is checked against two copies of the neutral elements,
--   and its return type must equal the neutral element types;
--
-- * the leading entries of @ts@ must equal the operator value types
--   (each lifted by its operator's shape), in operator order.
--
-- Finally the kernel body is checked against @ts@.
checkScanRed ::
  TC.Checkable lore =>
  SegSpace ->
  [(Lambda (Aliases lore), [SubExp], Shape)] ->
  [Type] ->
  KernelBody (Aliases lore) ->
  TC.TypeM lore ()
checkScanRed space ops ts kbody = do
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    ne_ts <- forM ops $ \(lam, nes, shape) -> do
      mapM_ (TC.require [Prim int64]) $ shapeDims shape
      nes' <- mapM TC.checkArg nes

      -- Operator type must match the type of neutral elements.
      TC.checkLambda lam $ map TC.noArgAliases $ nes' ++ nes'
      let nes_t = map TC.argType nes'
          lam_ret = lambdaReturnType lam

      -- Report both sides of the mismatch (as the SegHist check does),
      -- keeping the original phrase as a prefix.
      unless (lam_ret == nes_t) $
        TC.bad $
          TC.TypeError $
            "wrong type for operator or neutral elements: operator has return type "
              ++ prettyTuple lam_ret
              ++ " but neutral elements have type "
              ++ prettyTuple nes_t

      return $ map (`arrayOfShape` shape) nes_t

    let expecting = concat ne_ts
        got = take (length expecting) ts
    unless (expecting == got) $
      TC.bad $
        TC.TypeError $
          "Wrong return for body (does not match neutral elements; expected "
            ++ pretty expecting
            ++ "; found "
            ++ pretty got
            ++ ")"

    checkKernelBody ts kbody
-- | A bundle of monadic transformations, one per kind of component found
-- in a 'SegOp'.  @flore@ is the lore being mapped from, @tlore@ the lore
-- being mapped to, and @m@ the monad the transformations run in.
data SegOpMapper lvl flore tlore m = SegOpMapper
  { -- | Applied to subexpressions, such as segment-space dimensions and
    -- neutral elements.
    mapOnSegOpSubExp :: SubExp -> m SubExp,
    -- | Applied to operator lambdas.
    mapOnSegOpLambda :: Lambda flore -> m (Lambda tlore),
    -- | Applied to the kernel body.
    mapOnSegOpBody :: KernelBody flore -> m (KernelBody tlore),
    -- | Applied to variable names.
    mapOnSegOpVName :: VName -> m VName,
    -- | Applied to the level annotation.
    mapOnSegOpLevel :: lvl -> m lvl
  }
-- | A mapper in which every component returns its argument unchanged.
identitySegOpMapper :: Monad m => SegOpMapper lvl lore lore m
identitySegOpMapper =
  SegOpMapper
    { mapOnSegOpSubExp = pure,
      mapOnSegOpLambda = pure,
      mapOnSegOpBody = pure,
      mapOnSegOpVName = pure,
      mapOnSegOpLevel = pure
    }
-- | Apply the 'SubExp' mapper to every dimension size of a 'SegSpace'.
-- The physical index name and the per-dimension index names are kept
-- as they are.
mapOnSegSpace ::
  Monad f =>
  SegOpMapper lvl flore tlore f ->
  SegSpace ->
  f SegSpace
mapOnSegSpace tv (SegSpace phys dims) = do
  dims' <- forM dims $ \(gtid, dim) -> (,) gtid <$> mapOnSegOpSubExp tv dim
  pure $ SegSpace phys dims'
-- | Apply a 'SegOpMapper' to a 'SegBinOp'.  The lambda is mapped first,
-- then the neutral elements, then the vector shape's dimensions.  The
-- commutativity annotation is preserved.
mapSegBinOp ::
  Monad m =>
  SegOpMapper lvl flore tlore m ->
  SegBinOp flore ->
  m (SegBinOp tlore)
mapSegBinOp tv (SegBinOp comm red_op nes shape) = do
  red_op' <- mapOnSegOpLambda tv red_op
  nes' <- mapM (mapOnSegOpSubExp tv) nes
  shape_dims' <- mapM (mapOnSegOpSubExp tv) (shapeDims shape)
  pure $ SegBinOp comm red_op' nes' (Shape shape_dims')
-- | Apply a 'SegOpMapper' to every component of a 'SegOp'.
--
-- Each constructor is traversed in field order: the level, the segment
-- space, the operators (for red/scan/hist), the result types and finally the
-- kernel body.  Result types are always mapped with 'mapOnSegOpType', so
-- every constructor treats types uniformly.  In particular, the name of an
-- accumulator ('Acc') type is passed to 'mapOnSegOpVName', not only to
-- 'SegMap'.
mapSegOpM ::
  (Applicative m, Monad m) =>
  SegOpMapper lvl flore tlore m ->
  SegOp lvl flore ->
  m (SegOp lvl tlore)
mapSegOpM tv (SegMap lvl space ts body) =
  SegMap
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegRed lvl space reds ts lam) =
  SegRed
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapSegBinOp tv) reds
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv lam
mapSegOpM tv (SegScan lvl space scans ts body) =
  SegScan
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapSegBinOp tv) scans
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegHist lvl space ops ts body) =
  SegHist
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM onHistOp ops
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
  where
    -- Map every field of a histogram operator: width, race factor,
    -- destination arrays, neutral elements, operator shape and lambda.
    onHistOp (HistOp w rf arrs nes shape op) =
      HistOp
        <$> mapOnSegOpSubExp tv w
        <*> mapOnSegOpSubExp tv rf
        <*> mapM (mapOnSegOpVName tv) arrs
        <*> mapM (mapOnSegOpSubExp tv) nes
        <*> traverse (mapOnSegOpSubExp tv) shape
        <*> mapOnSegOpLambda tv op
-- | Map the sub-expressions (and accumulator name) occurring in a 'Type'.
--
-- * Primitive and memory types contain nothing to map and are returned as-is.
-- * Array shapes have each dimension passed through 'mapOnSegOpSubExp'.
-- * Accumulator types have their name passed through 'mapOnSegOpVName'.
--   Their index space is mapped with 'mapOnSegOpSubExp'.  The shape
--   dimensions of their element types are mapped the same way, one level
--   deep; this does not recurse through 'mapOnSegOpType'.
mapOnSegOpType ::
  Monad m =>
  SegOpMapper lvl flore tlore m ->
  Type ->
  m Type
mapOnSegOpType tv = onType
  where
    onSubExp = mapOnSegOpSubExp tv
    onType ty = case ty of
      Prim {} ->
        pure ty
      Mem space ->
        pure $ Mem space
      Array pt shape u ->
        Array pt <$> traverse onSubExp shape <*> pure u
      Acc acc ispace elemTypes u ->
        Acc
          <$> mapOnSegOpVName tv acc
          <*> traverse onSubExp ispace
          <*> traverse (bitraverse (traverse onSubExp) pure) elemTypes
          <*> pure u
-- | Substitute names throughout a 'SegOp' by running 'mapSegOpM' in
-- 'Identity'.  Every component is passed through 'substituteNames' with the
-- same substitution map.
instance
  (ASTLore lore, Substitute lvl) =>
  Substitute (SegOp lvl lore)
  where
  substituteNames subst op =
    runIdentity $
      mapSegOpM
        ( SegOpMapper
            { mapOnSegOpSubExp = sub,
              mapOnSegOpLambda = sub,
              mapOnSegOpBody = sub,
              mapOnSegOpVName = sub,
              mapOnSegOpLevel = sub
            }
        )
        op
    where
      -- Apply the substitution to a single component.
      sub :: Substitute a => a -> Identity a
      sub = Identity . substituteNames subst
-- | Rename a 'SegOp' by renaming each of its components via 'mapSegOpM'.
instance
  (ASTLore lore, ASTConstraints lvl) =>
  Rename (SegOp lvl lore)
  where
  rename =
    mapSegOpM
      ( SegOpMapper
          { mapOnSegOpSubExp = rename,
            mapOnSegOpLambda = rename,
            mapOnSegOpBody = rename,
            mapOnSegOpVName = rename,
            mapOnSegOpLevel = rename
          }
      )
-- | The free variables of a 'SegOp' are the union of the free variables of
-- every component visited by 'mapSegOpM'.  They are collected in a 'State'
-- accumulator that starts from 'mempty'.
instance
  (ASTLore lore, FreeIn (LParamInfo lore), FreeIn lvl) =>
  FreeIn (SegOp lvl lore)
  where
  freeIn' op = execState (mapSegOpM collector op) mempty
    where
      -- Add the component's free variables to the accumulator and return the
      -- component unchanged.
      note :: FreeIn a => a -> State FV a
      note x = x <$ modify (<> freeIn' x)

      collector =
        SegOpMapper
          { mapOnSegOpSubExp = note,
            mapOnSegOpLambda = note,
            mapOnSegOpBody = note,
            mapOnSegOpVName = note,
            mapOnSegOpLevel = note
          }
-- | Metrics for a 'SegOp'.  The operator lambdas (if any) and the kernel body
-- are recorded inside a section named after the constructor.
instance OpMetrics (Op lore) => OpMetrics (SegOp lvl lore) where
  opMetrics op = case op of
    SegMap _ _ _ body ->
      measure "SegMap" [] body
    SegRed _ _ reds _ body ->
      measure "SegRed" (map segBinOpLambda reds) body
    SegScan _ _ scans _ body ->
      measure "SegScan" (map segBinOpLambda scans) body
    SegHist _ _ ops _ body ->
      measure "SegHist" (map histOp ops) body
    where
      -- Record the lambdas first, then the kernel body, under one label.
      measure label lams body =
        inside label $ do
          mapM_ lambdaMetrics lams
          kernelBodyMetrics body
-- | Print a segment space as @(i < d, ...) (~phys)@: each dimension's index
-- name and bound, followed by the physical index name.
instance Pretty SegSpace where
  ppr (SegSpace phys dims) = dimsDoc <+> physDoc
    where
      dimsDoc = parens $ commasep [ppr i <+> "<" <+> ppr d | (i, d) <- dims]
      physDoc = parens $ text "~" <> ppr phys
-- | Print a segmented binary operator as its braced neutral elements, its
-- shape, and its lambda.  The lambda is prefixed with @commutative@ when the
-- operator is marked 'Commutative'.
instance PrettyLore lore => Pretty (SegBinOp lore) where
  ppr (SegBinOp comm lam nes shape) =
    PP.braces neutral <> PP.comma
      </> ppr shape <> PP.comma
      </> commDoc <> ppr lam
    where
      neutral = PP.commasep (map ppr nes)
      commDoc
        | Commutative <- comm = text "commutative "
        | otherwise = mempty
-- | Pretty-print a 'SegOp'.  Each form starts with its keyword and level,
-- then the segment space.  Red/scan/hist forms also print their operators
-- in parentheses, one per line.  All forms end with the result types and the
-- kernel body in braces.
instance (PrettyLore lore, PP.Pretty lvl) => PP.Pretty (SegOp lvl lore) where
  ppr segop = case segop of
    SegMap lvl space ts body ->
      text "segmap" <> ppr lvl
        </> PP.align (ppr space)
        <+> PP.colon
        <+> ppTuple' ts
        <+> bodyBlock body
    SegRed lvl space reds ts body ->
      withOperators "segred" lvl space (map ppr reds) ts body
    SegScan lvl space scans ts body ->
      withOperators "segscan" lvl space (map ppr scans) ts body
    SegHist lvl space ops ts body ->
      withOperators "seghist" lvl space (map pprHistOp ops) ts body
    where
      bodyBlock body = PP.nestedBlock "{" "}" (ppr body)

      -- Shared layout for the operator-carrying forms.
      withOperators keyword lvl space opDocs ts body =
        text keyword <> ppr lvl
          </> PP.align (ppr space)
          </> PP.parens (mconcat (intersperse (PP.comma <> PP.line) opDocs))
          </> PP.colon
          <+> ppTuple' ts
          <+> bodyBlock body

      -- A histogram operator: width, race factor, destinations, neutral
      -- elements, shape and lambda, separated by commas.
      pprHistOp (HistOp w rf dests nes shape op) =
        ppr w <> PP.comma <+> ppr rf <> PP.comma
          </> PP.braces (PP.commasep (map ppr dests)) <> PP.comma
          </> PP.braces (PP.commasep (map ppr nes)) <> PP.comma
          </> ppr shape <> PP.comma
          </> ppr op
-- | Alias annotation for segmented operations.  Both directions traverse
-- the operation with 'mapSegOpM', transforming only the lambdas and the
-- kernel body; subexpressions, names and the level are returned as-is.
instance
  ( ASTLore lore,
    ASTLore (Aliases lore),
    CanBeAliased (Op lore),
    ASTConstraints lvl
  ) =>
  CanBeAliased (SegOp lvl lore)
  where
  type OpWithAliases (SegOp lvl lore) = SegOp lvl (Aliases lore)

  -- Analyse aliases of every lambda and kernel body against the given table.
  addOpAliases aliases op =
    runIdentity $
      mapSegOpM
        ( SegOpMapper
            pure
            (pure . Alias.analyseLambda aliases)
            (pure . aliasAnalyseKernelBody aliases)
            pure
            pure
        )
        op

  -- Strip the alias annotations added above.
  removeOpAliases op =
    runIdentity $
      mapSegOpM
        ( SegOpMapper
            pure
            (pure . removeLambdaAliases)
            (pure . removeKernelBodyAliases)
            pure
            pure
        )
        op
-- | Wisdom removal for segmented operations: 'mapSegOpM' strips wisdom from
-- the lambdas and the kernel body, while subexpressions, names and the
-- level are returned unchanged.
instance
  (CanBeWise (Op lore), ASTLore lore, ASTConstraints lvl) =>
  CanBeWise (SegOp lvl lore)
  where
  type OpWithWisdom (SegOp lvl lore) = SegOp lvl (Wise lore)

  removeOpWisdom op =
    runIdentity $
      mapSegOpM
        ( SegOpMapper
            pure
            (pure . removeLambdaWisdom)
            (pure . removeKernelBodyWisdom)
            pure
            pure
        )
        op
-- | Symbolic indexing into the result of a 'SegMap'.
--
-- Only 'SegMap' is handled; every other 'SegOp' yields 'Nothing'.  The
-- @k@th kernel result must be a 'Returns' marked 'ResultMaySimplify', and
-- at least as many indices must be supplied as the segment space has
-- dimensions.  Each thread index is bound to the corresponding index
-- expression, the kernel body statements are folded over to extend that
-- table, and finally the returned variable (if it is one) is looked up.
instance ASTLore lore => ST.IndexOp (SegOp lvl lore) where
  indexOp vtable k (SegMap _ space _ kbody) is = do
    Returns ResultMaySimplify res <- maybeNth k $ kernelBodyResult kbody
    guard $ length gtids <= length is
    case res of
      Var v -> M.lookup v $ foldl' extend thread_table $ kernelBodyStms kbody
      _ -> Nothing
    where
      gtids = map fst $ unSegSpace space

      -- Indices left over once every thread index has been bound.
      excess_is = drop (length gtids) is

      -- Initial table: each thread index maps to its index expression.
      thread_table =
        M.fromList $ zip gtids $ map (ST.Indexed mempty . untyped) is

      -- Record a statement binding exactly one name, if it can be expressed
      -- as a scalar 'PrimExp' or, failing that, as an index into an array
      -- known to the symbol table.
      extend table stm =
        case patternNames $ stmPattern stm of
          [v] | Just entry <- stmEntry table stm -> M.insert v entry table
          _ -> table

      stmEntry table stm =
        case runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm of
          Just (pe, cs) -> Just $ ST.Indexed (stmCerts stm <> cs) pe
          Nothing -> arrayEntry table stm

      arrayEntry table stm
        | BasicOp (Index arr slice) <- stmExp stm,
          length (sliceDims slice) == length excess_is,
          arr `ST.elem` vtable,
          Just (slice', cs) <- asPrimExpSlice table slice =
          Just $
            ST.IndexedArray
              (stmCerts stm <> cs)
              arr
              (fixSlice (map (fmap isInt64) slice') excess_is)
        | otherwise = Nothing

      asPrimExpSlice table =
        runWriterT . mapM (traverse (primExpFromSubExpM (asPrimExp table)))

      -- Translate a variable to a 'PrimExp', accumulating the certificates
      -- of any table entry used.  A variable missing from the table is only
      -- usable when the symbol table gives it a primitive type.
      asPrimExp table v
        | Just (ST.Indexed cs e) <- M.lookup v table = tell cs >> pure e
        | Just (Prim pt) <- ST.lookupType v vtable = pure $ LeafExp v pt
        | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing
-- | Every 'SegOp' is reported as not cheap ('cheapOp' is always 'False')
-- and as safe ('safeOp' is always 'True'), regardless of its contents.
instance
  (ASTLore lore, ASTConstraints lvl) =>
  IsOp (SegOp lvl lore)
  where
  cheapOp = const False
  safeOp = const True
-- | A contiguous split has nothing to simplify; a strided split has its
-- stride simplified.
instance Engine.Simplifiable SplitOrdering where
  simplify ordering =
    case ordering of
      SplitContiguous -> pure SplitContiguous
      SplitStrided stride -> SplitStrided <$> Engine.simplify stride
-- | Simplify the extent of every dimension of a segment space.  The
-- physical index name and the per-dimension index names are left untouched.
instance Engine.Simplifiable SegSpace where
  simplify (SegSpace phys dims) =
    SegSpace phys <$> traverse simplifyDim dims
    where
      simplifyDim (gtid, extent) = (,) gtid <$> Engine.simplify extent
-- | Simplify every component of a kernel result, left to right.  The
-- 'ResultManifest' of a 'Returns' is kept unchanged.
instance Engine.Simplifiable KernelResult where
  simplify kres =
    case kres of
      Returns manifest what ->
        Returns manifest <$> Engine.simplify what
      WriteReturns ws arr res ->
        WriteReturns
          <$> Engine.simplify ws
          <*> Engine.simplify arr
          <*> Engine.simplify res
      ConcatReturns ordering w pte what ->
        ConcatReturns
          <$> Engine.simplify ordering
          <*> Engine.simplify w
          <*> Engine.simplify pte
          <*> Engine.simplify what
      TileReturns dims what ->
        TileReturns
          <$> Engine.simplify dims
          <*> Engine.simplify what
      RegTileReturns dims_n_tiles what ->
        RegTileReturns
          <$> Engine.simplify dims_n_tiles
          <*> Engine.simplify what
-- | Construct a 'KernelBody' in the 'Wise' lore.
--
-- The body decoration is taken from the 'Body' that 'mkWiseBody' builds
-- from the same statements and the subexpressions of the kernel results
-- (obtained with 'kernelResultSubExp').  The given statements and kernel
-- results are stored unchanged.
mkWiseKernelBody ::
  (ASTLore lore, CanBeWise (Op lore)) =>
  BodyDec lore ->
  Stms (Wise lore) ->
  [KernelResult] ->
  KernelBody (Wise lore)
mkWiseKernelBody dec stms kres = KernelBody wise_dec stms kres
  where
    Body wise_dec _ _ = mkWiseBody dec stms $ map kernelResultSubExp kres
-- | Construct a 'KernelBody' inside a 'MonadBinder'.
--
-- 'mkBodyM' is run on the statements and the 'SubExp' projection of the
-- kernel results; the decoration of the resulting 'Body' becomes the
-- decoration of the kernel body, whose statements and results are the
-- arguments unchanged.
mkKernelBodyM ::
  MonadBinder m =>
  Stms (Lore m) ->
  [KernelResult] ->
  m (KernelBody (Lore m))
mkKernelBodyM stms kres =
  mkBodyM stms (map kernelResultSubExp kres)
    >>= \(Body body_dec _ _) -> return (KernelBody body_dec stms kres)
-- | Simplify a kernel body that executes within the given 'SegSpace'.
--
-- Returns the simplified body together with the statements hoisted out
-- of it.  While the statements are simplified:
--
-- * arrays targeted by 'WriteReturns' results are marked consumed in the
--   symbol table;
-- * the names bound by the segment space are added (via
--   'segSpaceSymbolTable');
-- * 'ST.simplifyMemory' is switched on;
-- * 'Engine.enterLoop' is applied.
--
-- A statement is blocked from hoisting if it has free variables bound by
-- the segment space, is an 'Op', is rejected by the engine's parallel
-- hoist blocker, or is consumed.
simplifyKernelBody ::
  (Engine.SimplifiableLore lore, BodyDec lore ~ ()) =>
  SegSpace ->
  KernelBody lore ->
  Engine.SimpleM lore (KernelBody (Wise lore), Stms (Wise lore))
simplifyKernelBody space (KernelBody _ stms res) = do
  par_blocker <-
    Engine.asksEngineEnv (Engine.blockHoistPar . Engine.envHoistBlockers)
  let keep_inside =
        Engine.hasFree bound_here
          `Engine.orIf` Engine.isOp
          `Engine.orIf` par_blocker
          `Engine.orIf` Engine.isConsumed
  ((body_stms, body_res), hoisted) <-
    kernel_env $
      Engine.blockIf keep_inside $
        Engine.simplifyStms stms simplifyResults
  return (mkWiseKernelBody () body_stms body_res, hoisted)
  where
    -- Arrays written through 'WriteReturns' are consumed by the kernel.
    consumed = [arr | WriteReturns _ arr _ <- res]

    -- Outermost modification first, exactly as in a chain of nested
    -- 'Engine.localVtable' calls.
    kernel_env =
      Engine.localVtable (\vtable -> foldl' (flip ST.consume) vtable consumed)
        . Engine.localVtable (<> segSpaceSymbolTable space)
        . Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True})
        . Engine.enterLoop

    bound_here = namesFromList (M.keys (scopeOfSegSpace space))

    -- Results are simplified with 'ST.hideCertified' applied to the names
    -- bound by the body's own statements; their free variables form the
    -- usage table.
    simplifyResults = do
      res' <-
        Engine.localVtable (ST.hideCertified (namesFromList (M.keys (scopeOf stms)))) $
          mapM Engine.simplify res
      return ((res', UT.usages (freeIn res')), mempty)
-- | Symbol table for the names bound by a 'SegSpace'.
--
-- The flat index is entered as an 'Int64' 'IndexName'.  Each
-- @(gtid, dim)@ pair is then added, left to right, with
-- 'ST.insertLoopVar' at type 'Int64' and @dim@ as its bound argument.
segSpaceSymbolTable :: ASTLore lore => SegSpace -> ST.SymbolTable lore
segSpaceSymbolTable (SegSpace flat gtids_and_dims) =
  foldl' addGtid flat_table gtids_and_dims
  where
    flat_table = ST.fromScope (M.singleton flat (IndexName Int64))
    addGtid table (gtid, bound) = ST.insertLoopVar gtid Int64 bound table
-- | Simplify a 'SegBinOp'.
--
-- The operator lambda is simplified with 'ST.simplifyMemory' enabled,
-- then the shape, then the neutral elements.  Returns the simplified
-- operator together with the statements hoisted out of its lambda.
-- The commutativity flag is passed through unchanged.
simplifySegBinOp ::
  Engine.SimplifiableLore lore =>
  SegBinOp lore ->
  Engine.SimpleM lore (SegBinOp (Wise lore), Stms (Wise lore))
simplifySegBinOp (SegBinOp comm lam nes shape) =
  assemble
    <$> Engine.localVtable withMemorySimplification (Engine.simplifyLambda lam)
    <*> Engine.simplify shape
    <*> mapM Engine.simplify nes
  where
    assemble (lam', hoisted) shape' nes' =
      (SegBinOp comm lam' nes' shape', hoisted)
    withMemorySimplification vtable = vtable {ST.simplifyMemory = True}
-- | Simplify a 'SegOp'.
--
-- The level, segment space, and result types are simplified first.  The
-- operation-specific parts (reduction/scan operators or histogram
-- operations) come next, and the kernel body last.  Returns the
-- simplified operation together with every statement hoisted out of it.
simplifySegOp ::
  ( Engine.SimplifiableLore lore,
    BodyDec lore ~ (),
    Engine.Simplifiable lvl
  ) =>
  SegOp lvl lore ->
  Engine.SimpleM lore (SegOp lvl (Wise lore), Stms (Wise lore))
simplifySegOp (SegMap lvl space ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  -- Only the kernel body can contribute hoisted statements here.
  first (SegMap lvl' space' ts') <$> simplifyKernelBody space kbody
-- Reduction operators are simplified with the segment space's scope added
-- to the symbol table.  Their hoisted statements precede those hoisted
-- from the kernel body.
simplifySegOp (SegRed lvl space reds ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  simplified_reds <-
    Engine.localVtable (<> ST.fromScope (scopeOfSegSpace space)) $
      mapM simplifySegBinOp reds
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  let (reds', reds_hoisted) = unzip simplified_reds
  return
    ( SegRed lvl' space' reds' ts' kbody',
      mconcat reds_hoisted <> body_hoisted
    )
-- Scan operators are simplified with the segment space's scope added to
-- the symbol table.  Their hoisted statements precede those hoisted from
-- the kernel body.
simplifySegOp (SegScan lvl space scans ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  simplified_scans <-
    Engine.localVtable (<> space_vtable) $
      mapM simplifySegBinOp scans
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  let (scans', scans_hoisted) = unzip simplified_scans
  return
    ( SegScan lvl' space' scans' ts' kbody',
      mconcat scans_hoisted <> body_hoisted
    )
  where
    space_vtable = ST.fromScope (scopeOfSegSpace space)
-- Each histogram operation is simplified field by field.  Only its lambda
-- sees the segment space's scope and has 'ST.simplifyMemory' enabled.
-- Statements hoisted from the operations precede those hoisted from the
-- kernel body.
simplifySegOp (SegHist lvl space ops ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (ops', ops_hoisted) <- unzip <$> mapM simplifyHistOp ops
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  return
    ( SegHist lvl' space' ops' ts' kbody',
      mconcat ops_hoisted <> body_hoisted
    )
  where
    space_vtable = ST.fromScope (scopeOfSegSpace space)

    simplifyHistOp (HistOp w rf arrs nes dims lam) = do
      w' <- Engine.simplify w
      rf' <- Engine.simplify rf
      arrs' <- Engine.simplify arrs
      nes' <- Engine.simplify nes
      dims' <- Engine.simplify dims
      (lam', op_hoisted) <-
        Engine.localVtable (<> space_vtable) $
          Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True}) $
            Engine.simplifyLambda lam
      return (HistOp w' rf' arrs' nes' dims' lam', op_hoisted)
-- | Lores whose 'Op' type can be converted to and from 'SegOp'.
class HasSegOp lore where
  -- | Embed a 'SegOp' in the lore's operation type.
  segOp :: SegOp (SegOpLevel lore) lore -> Op lore

  -- | View an operation as a 'SegOp'; 'Nothing' when it is not one.
  asSegOp :: Op lore -> Maybe (SegOp (SegOpLevel lore) lore)

  -- | The level type used by this lore's 'SegOp's.
  type SegOpLevel lore
segOpRules ::
(HasSegOp lore, BinderOps lore, Bindable lore) =>
RuleBook lore
segOpRules :: forall lore.
(HasSegOp lore, BinderOps lore, Bindable lore) =>
RuleBook lore
segOpRules =
[TopDownRule lore] -> [BottomUpRule lore] -> RuleBook lore
forall m. [TopDownRule m] -> [BottomUpRule m] -> RuleBook m
ruleBook [RuleOp lore (TopDown lore) -> TopDownRule lore
forall lore a. RuleOp lore a -> SimplificationRule lore a
RuleOp RuleOp lore (TopDown lore)
forall lore.
(HasSegOp lore, BinderOps lore, Bindable lore) =>
TopDownRuleOp lore
segOpRuleTopDown] [RuleOp lore (BottomUp lore) -> BottomUpRule lore
forall lore a. RuleOp lore a -> SimplificationRule lore a
RuleOp RuleOp lore (BottomUp lore)
forall lore. (HasSegOp lore, BinderOps lore) => BottomUpRuleOp lore
segOpRuleBottomUp]
segOpRuleTopDown ::
(HasSegOp lore, BinderOps lore, Bindable lore) =>
TopDownRuleOp lore
segOpRuleTopDown :: forall lore.
(HasSegOp lore, BinderOps lore, Bindable lore) =>
TopDownRuleOp lore
segOpRuleTopDown TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
dec Op lore
op
| Just SegOp (SegOpLevel lore) lore
op' <- Op lore -> Maybe (SegOp (SegOpLevel lore) lore)
forall lore.
HasSegOp lore =>
Op lore -> Maybe (SegOp (SegOpLevel lore) lore)
asSegOp Op lore
op =
TopDown lore
-> Pattern lore
-> StmAux (ExpDec lore)
-> SegOp (SegOpLevel lore) lore
-> Rule lore
forall lore.
(HasSegOp lore, BinderOps lore, Bindable lore) =>
SymbolTable lore
-> Pattern lore
-> StmAux (ExpDec lore)
-> SegOp (SegOpLevel lore) lore
-> Rule lore
topDownSegOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
dec SegOp (SegOpLevel lore) lore
op'
| Bool
otherwise =
Rule lore
forall lore. Rule lore
Skip
segOpRuleBottomUp ::
(HasSegOp lore, BinderOps lore) =>
BottomUpRuleOp lore
segOpRuleBottomUp :: forall lore. (HasSegOp lore, BinderOps lore) => BottomUpRuleOp lore
segOpRuleBottomUp BottomUp lore
vtable Pattern lore
pat StmAux (ExpDec lore)
dec Op lore
op
| Just SegOp (SegOpLevel lore) lore
op' <- Op lore -> Maybe (SegOp (SegOpLevel lore) lore)
forall lore.
HasSegOp lore =>
Op lore -> Maybe (SegOp (SegOpLevel lore) lore)
asSegOp Op lore
op =
BottomUp lore
-> Pattern lore
-> StmAux (ExpDec lore)
-> SegOp (SegOpLevel lore) lore
-> Rule lore
forall lore.
(HasSegOp lore, BinderOps lore) =>
(SymbolTable lore, UsageTable)
-> Pattern lore
-> StmAux (ExpDec lore)
-> SegOp (SegOpLevel lore) lore
-> Rule lore
bottomUpSegOp BottomUp lore
vtable Pattern lore
pat StmAux (ExpDec lore)
dec SegOp (SegOpLevel lore) lore
op'
| Bool
otherwise =
Rule lore
forall lore. Rule lore
Skip
topDownSegOp ::
(HasSegOp lore, BinderOps lore, Bindable lore) =>
ST.SymbolTable lore ->
Pattern lore ->
StmAux (ExpDec lore) ->
SegOp (SegOpLevel lore) lore ->
Rule lore
topDownSegOp :: forall lore.
(HasSegOp lore, BinderOps lore, Bindable lore) =>
SymbolTable lore
-> Pattern lore
-> StmAux (ExpDec lore)
-> SegOp (SegOpLevel lore) lore
-> Rule lore
topDownSegOp SymbolTable lore
vtable (Pattern [] [PatElemT (LetDec lore)]
kpes) StmAux (ExpDec lore)
dec (SegMap SegOpLevel lore
lvl SegSpace
space [Type]
ts (KernelBody BodyDec lore
_ Stms lore
kstms [KernelResult]
kres)) = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
([Type]
ts', [PatElemT (LetDec lore)]
kpes', [KernelResult]
kres') <-
[(Type, PatElemT (LetDec lore), KernelResult)]
-> ([Type], [PatElemT (LetDec lore)], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Type, PatElemT (LetDec lore), KernelResult)]
-> ([Type], [PatElemT (LetDec lore)], [KernelResult]))
-> RuleM lore [(Type, PatElemT (LetDec lore), KernelResult)]
-> RuleM lore ([Type], [PatElemT (LetDec lore)], [KernelResult])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Type, PatElemT (LetDec lore), KernelResult) -> RuleM lore Bool)
-> [(Type, PatElemT (LetDec lore), KernelResult)]
-> RuleM lore [(Type, PatElemT (LetDec lore), KernelResult)]
forall (m :: * -> *) a.
Applicative m =>
(a -> m Bool) -> [a] -> m [a]
filterM (Type, PatElemT (LetDec lore), KernelResult) -> RuleM lore Bool
checkForInvarianceResult ([Type]
-> [PatElemT (LetDec lore)]
-> [KernelResult]
-> [(Type, PatElemT (LetDec lore), KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Type]
ts [PatElemT (LetDec lore)]
kpes [KernelResult]
kres)
Bool -> RuleM lore () -> RuleM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when
([KernelResult]
kres [KernelResult] -> [KernelResult] -> Bool
forall a. Eq a => a -> a -> Bool
== [KernelResult]
kres')
RuleM lore ()
forall lore a. RuleM lore a
cannotSimplify
KernelBody lore
kbody <- Stms (Lore (RuleM lore))
-> [KernelResult] -> RuleM lore (KernelBody (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> [KernelResult] -> m (KernelBody (Lore m))
mkKernelBodyM Stms lore
Stms (Lore (RuleM lore))
kstms [KernelResult]
kres'
Stm (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm (Stm (Lore (RuleM lore)) -> RuleM lore ())
-> Stm (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
PatternT (LetDec lore)
-> StmAux (ExpDec lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> PatternT (LetDec lore)
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT (LetDec lore)]
kpes') StmAux (ExpDec lore)
dec (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$
Op lore -> Exp lore
forall lore. Op lore -> ExpT lore
Op (Op lore -> Exp lore) -> Op lore -> Exp lore
forall a b. (a -> b) -> a -> b
$
SegOp (SegOpLevel lore) lore -> Op lore
forall lore.
HasSegOp lore =>
SegOp (SegOpLevel lore) lore -> Op lore
segOp (SegOp (SegOpLevel lore) lore -> Op lore)
-> SegOp (SegOpLevel lore) lore -> Op lore
forall a b. (a -> b) -> a -> b
$
SegOpLevel lore
-> SegSpace
-> [Type]
-> KernelBody lore
-> SegOp (SegOpLevel lore) lore
forall lvl lore.
lvl -> SegSpace -> [Type] -> KernelBody lore -> SegOp lvl lore
SegMap SegOpLevel lore
lvl SegSpace
space [Type]
ts' KernelBody lore
kbody
where
isInvariant :: SubExp -> Bool
isInvariant Constant {} = Bool
True
isInvariant (Var VName
v) = Maybe (Entry lore) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Entry lore) -> Bool) -> Maybe (Entry lore) -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> SymbolTable lore -> Maybe (Entry lore)
forall lore. VName -> SymbolTable lore -> Maybe (Entry lore)
ST.lookup VName
v SymbolTable lore
vtable
checkForInvarianceResult :: (Type, PatElemT (LetDec lore), KernelResult) -> RuleM lore Bool
checkForInvarianceResult (Type
_, PatElemT (LetDec lore)
pe, Returns ResultManifest
rm SubExp
se)
| ResultManifest
rm ResultManifest -> ResultManifest -> Bool
forall a. Eq a => a -> a -> Bool
== ResultManifest
ResultMaySimplify,
SubExp -> Bool
isInvariant SubExp
se = do
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> ShapeBase SubExp) -> [SubExp] -> ShapeBase SubExp
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space) SubExp
se
Bool -> RuleM lore Bool
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
False
checkForInvarianceResult (Type, PatElemT (LetDec lore), KernelResult)
_ =
Bool -> RuleM lore Bool
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
True
topDownSegOp SymbolTable lore
_ (Pattern [] [PatElemT (LetDec lore)]
pes) StmAux (ExpDec lore)
_ (SegRed SegOpLevel lore
lvl SegSpace
space [SegBinOp lore]
ops [Type]
ts KernelBody lore
kbody)
| [SegBinOp lore] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegBinOp lore]
ops Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1,
[[(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]]
op_groupings <-
((SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])
-> (SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])
-> Bool)
-> [(SegBinOp lore,
[(PatElemT (LetDec lore), Type, KernelResult)])]
-> [[(SegBinOp lore,
[(PatElemT (LetDec lore), Type, KernelResult)])]]
forall a. (a -> a -> Bool) -> [a] -> [[a]]
groupBy (SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])
-> (SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])
-> Bool
forall {lore} {b} {lore} {b}.
(SegBinOp lore, b) -> (SegBinOp lore, b) -> Bool
sameShape ([(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]
-> [[(SegBinOp lore,
[(PatElemT (LetDec lore), Type, KernelResult)])]])
-> [(SegBinOp lore,
[(PatElemT (LetDec lore), Type, KernelResult)])]
-> [[(SegBinOp lore,
[(PatElemT (LetDec lore), Type, KernelResult)])]]
forall a b. (a -> b) -> a -> b
$
[SegBinOp lore]
-> [[(PatElemT (LetDec lore), Type, KernelResult)]]
-> [(SegBinOp lore,
[(PatElemT (LetDec lore), Type, KernelResult)])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp lore]
ops ([[(PatElemT (LetDec lore), Type, KernelResult)]]
-> [(SegBinOp lore,
[(PatElemT (LetDec lore), Type, KernelResult)])])
-> [[(PatElemT (LetDec lore), Type, KernelResult)]]
-> [(SegBinOp lore,
[(PatElemT (LetDec lore), Type, KernelResult)])]
forall a b. (a -> b) -> a -> b
$
[Int]
-> [(PatElemT (LetDec lore), Type, KernelResult)]
-> [[(PatElemT (LetDec lore), Type, KernelResult)]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp lore -> Int) -> [SegBinOp lore] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp lore -> [SubExp]) -> SegBinOp lore -> Int
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. SegBinOp lore -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral) [SegBinOp lore]
ops) ([(PatElemT (LetDec lore), Type, KernelResult)]
-> [[(PatElemT (LetDec lore), Type, KernelResult)]])
-> [(PatElemT (LetDec lore), Type, KernelResult)]
-> [[(PatElemT (LetDec lore), Type, KernelResult)]]
forall a b. (a -> b) -> a -> b
$
[PatElemT (LetDec lore)]
-> [Type]
-> [KernelResult]
-> [(PatElemT (LetDec lore), Type, KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElemT (LetDec lore)]
red_pes [Type]
red_ts [KernelResult]
red_res,
([(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]
-> Bool)
-> [[(SegBinOp lore,
[(PatElemT (LetDec lore), Type, KernelResult)])]]
-> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any ((Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1) (Int -> Bool)
-> ([(SegBinOp lore,
[(PatElemT (LetDec lore), Type, KernelResult)])]
-> Int)
-> [(SegBinOp lore,
[(PatElemT (LetDec lore), Type, KernelResult)])]
-> Bool
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]
-> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length) [[(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]]
op_groupings = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
let ([SegBinOp lore]
ops', [[(PatElemT (LetDec lore), Type, KernelResult)]]
aux) = [(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]
-> ([SegBinOp lore],
[[(PatElemT (LetDec lore), Type, KernelResult)]])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]
-> ([SegBinOp lore],
[[(PatElemT (LetDec lore), Type, KernelResult)]]))
-> [(SegBinOp lore,
[(PatElemT (LetDec lore), Type, KernelResult)])]
-> ([SegBinOp lore],
[[(PatElemT (LetDec lore), Type, KernelResult)]])
forall a b. (a -> b) -> a -> b
$ ([(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]
-> Maybe
(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)]))
-> [[(SegBinOp lore,
[(PatElemT (LetDec lore), Type, KernelResult)])]]
-> [(SegBinOp lore,
[(PatElemT (LetDec lore), Type, KernelResult)])]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe [(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]
-> Maybe
(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])
forall {lore} {a}.
Bindable lore =>
[(SegBinOp lore, [a])] -> Maybe (SegBinOp lore, [a])
combineOps [[(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]]
op_groupings
([PatElemT (LetDec lore)]
red_pes', [Type]
red_ts', [KernelResult]
red_res') = [(PatElemT (LetDec lore), Type, KernelResult)]
-> ([PatElemT (LetDec lore)], [Type], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(PatElemT (LetDec lore), Type, KernelResult)]
-> ([PatElemT (LetDec lore)], [Type], [KernelResult]))
-> [(PatElemT (LetDec lore), Type, KernelResult)]
-> ([PatElemT (LetDec lore)], [Type], [KernelResult])
forall a b. (a -> b) -> a -> b
$ [[(PatElemT (LetDec lore), Type, KernelResult)]]
-> [(PatElemT (LetDec lore), Type, KernelResult)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[(PatElemT (LetDec lore), Type, KernelResult)]]
aux
pes' :: [PatElemT (LetDec lore)]
pes' = [PatElemT (LetDec lore)]
red_pes' [PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)]
forall a. [a] -> [a] -> [a]
++ [PatElemT (LetDec lore)]
map_pes
ts' :: [Type]
ts' = [Type]
red_ts' [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
map_ts
kbody' :: KernelBody lore
kbody' = KernelBody lore
kbody {kernelBodyResult :: [KernelResult]
kernelBodyResult = [KernelResult]
red_res' [KernelResult] -> [KernelResult] -> [KernelResult]
forall a. [a] -> [a] -> [a]
++ [KernelResult]
map_res}
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> PatternT (LetDec lore)
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT (LetDec lore)]
pes') (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ Op lore -> Exp lore
forall lore. Op lore -> ExpT lore
Op (Op lore -> Exp lore) -> Op lore -> Exp lore
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel lore) lore -> Op lore
forall lore.
HasSegOp lore =>
SegOp (SegOpLevel lore) lore -> Op lore
segOp (SegOp (SegOpLevel lore) lore -> Op lore)
-> SegOp (SegOpLevel lore) lore -> Op lore
forall a b. (a -> b) -> a -> b
$ SegOpLevel lore
-> SegSpace
-> [SegBinOp lore]
-> [Type]
-> KernelBody lore
-> SegOp (SegOpLevel lore) lore
forall lvl lore.
lvl
-> SegSpace
-> [SegBinOp lore]
-> [Type]
-> KernelBody lore
-> SegOp lvl lore
SegRed SegOpLevel lore
lvl SegSpace
space [SegBinOp lore]
ops' [Type]
ts' KernelBody lore
kbody'
where
([PatElemT (LetDec lore)]
red_pes, [PatElemT (LetDec lore)]
map_pes) = Int
-> [PatElemT (LetDec lore)]
-> ([PatElemT (LetDec lore)], [PatElemT (LetDec lore)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp lore] -> Int
forall lore. [SegBinOp lore] -> Int
segBinOpResults [SegBinOp lore]
ops) [PatElemT (LetDec lore)]
pes
([Type]
red_ts, [Type]
map_ts) = Int -> [Type] -> ([Type], [Type])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp lore] -> Int
forall lore. [SegBinOp lore] -> Int
segBinOpResults [SegBinOp lore]
ops) [Type]
ts
([KernelResult]
red_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp lore] -> Int
forall lore. [SegBinOp lore] -> Int
segBinOpResults [SegBinOp lore]
ops) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody lore -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody lore
kbody
sameShape :: (SegBinOp lore, b) -> (SegBinOp lore, b) -> Bool
sameShape (SegBinOp lore
op1, b
_) (SegBinOp lore
op2, b
_) = SegBinOp lore -> ShapeBase SubExp
forall lore. SegBinOp lore -> ShapeBase SubExp
segBinOpShape SegBinOp lore
op1 ShapeBase SubExp -> ShapeBase SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SegBinOp lore -> ShapeBase SubExp
forall lore. SegBinOp lore -> ShapeBase SubExp
segBinOpShape SegBinOp lore
op2
combineOps :: [(SegBinOp lore, [a])] -> Maybe (SegBinOp lore, [a])
combineOps [] = Maybe (SegBinOp lore, [a])
forall a. Maybe a
Nothing
combineOps ((SegBinOp lore, [a])
x : [(SegBinOp lore, [a])]
xs) = (SegBinOp lore, [a]) -> Maybe (SegBinOp lore, [a])
forall a. a -> Maybe a
Just ((SegBinOp lore, [a]) -> Maybe (SegBinOp lore, [a]))
-> (SegBinOp lore, [a]) -> Maybe (SegBinOp lore, [a])
forall a b. (a -> b) -> a -> b
$ ((SegBinOp lore, [a])
-> (SegBinOp lore, [a]) -> (SegBinOp lore, [a]))
-> (SegBinOp lore, [a])
-> [(SegBinOp lore, [a])]
-> (SegBinOp lore, [a])
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' (SegBinOp lore, [a])
-> (SegBinOp lore, [a]) -> (SegBinOp lore, [a])
forall {lore} {a}.
Bindable lore =>
(SegBinOp lore, [a])
-> (SegBinOp lore, [a]) -> (SegBinOp lore, [a])
combine (SegBinOp lore, [a])
x [(SegBinOp lore, [a])]
xs
combine :: (SegBinOp lore, [a])
-> (SegBinOp lore, [a]) -> (SegBinOp lore, [a])
combine (SegBinOp lore
op1, [a]
op1_aux) (SegBinOp lore
op2, [a]
op2_aux) =
let lam1 :: Lambda lore
lam1 = SegBinOp lore -> Lambda lore
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda SegBinOp lore
op1
lam2 :: Lambda lore
lam2 = SegBinOp lore -> Lambda lore
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda SegBinOp lore
op2
([Param (LParamInfo lore)]
op1_xparams, [Param (LParamInfo lore)]
op1_yparams) =
Int
-> [Param (LParamInfo lore)]
-> ([Param (LParamInfo lore)], [Param (LParamInfo lore)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (SegBinOp lore -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp lore
op1)) ([Param (LParamInfo lore)]
-> ([Param (LParamInfo lore)], [Param (LParamInfo lore)]))
-> [Param (LParamInfo lore)]
-> ([Param (LParamInfo lore)], [Param (LParamInfo lore)])
forall a b. (a -> b) -> a -> b
$ Lambda lore -> [Param (LParamInfo lore)]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda lore
lam1
([Param (LParamInfo lore)]
op2_xparams, [Param (LParamInfo lore)]
op2_yparams) =
Int
-> [Param (LParamInfo lore)]
-> ([Param (LParamInfo lore)], [Param (LParamInfo lore)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (SegBinOp lore -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp lore
op2)) ([Param (LParamInfo lore)]
-> ([Param (LParamInfo lore)], [Param (LParamInfo lore)]))
-> [Param (LParamInfo lore)]
-> ([Param (LParamInfo lore)], [Param (LParamInfo lore)])
forall a b. (a -> b) -> a -> b
$ Lambda lore -> [Param (LParamInfo lore)]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda lore
lam2
lam :: Lambda lore
lam =
Lambda :: forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda
{ lambdaParams :: [Param (LParamInfo lore)]
lambdaParams =
[Param (LParamInfo lore)]
op1_xparams [Param (LParamInfo lore)]
-> [Param (LParamInfo lore)] -> [Param (LParamInfo lore)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo lore)]
op2_xparams
[Param (LParamInfo lore)]
-> [Param (LParamInfo lore)] -> [Param (LParamInfo lore)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo lore)]
op1_yparams
[Param (LParamInfo lore)]
-> [Param (LParamInfo lore)] -> [Param (LParamInfo lore)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo lore)]
op2_yparams,
lambdaReturnType :: [Type]
lambdaReturnType = Lambda lore -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda lore
lam1 [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ Lambda lore -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda lore
lam2,
lambdaBody :: BodyT lore
lambdaBody =
Stms lore -> [SubExp] -> BodyT lore
forall lore. Bindable lore => Stms lore -> [SubExp] -> Body lore
mkBody (BodyT lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms (Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
lam1) Stms lore -> Stms lore -> Stms lore
forall a. Semigroup a => a -> a -> a
<> BodyT lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms (Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
lam2)) ([SubExp] -> BodyT lore) -> [SubExp] -> BodyT lore
forall a b. (a -> b) -> a -> b
$
BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
lam1) [SubExp] -> [SubExp] -> [SubExp]
forall a. Semigroup a => a -> a -> a
<> BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
lam2)
}
in ( SegBinOp :: forall lore.
Commutativity
-> Lambda lore -> [SubExp] -> ShapeBase SubExp -> SegBinOp lore
SegBinOp
{ segBinOpComm :: Commutativity
segBinOpComm = SegBinOp lore -> Commutativity
forall lore. SegBinOp lore -> Commutativity
segBinOpComm SegBinOp lore
op1 Commutativity -> Commutativity -> Commutativity
forall a. Semigroup a => a -> a -> a
<> SegBinOp lore -> Commutativity
forall lore. SegBinOp lore -> Commutativity
segBinOpComm SegBinOp lore
op2,
segBinOpLambda :: Lambda lore
segBinOpLambda = Lambda lore
lam,
segBinOpNeutral :: [SubExp]
segBinOpNeutral = SegBinOp lore -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp lore
op1 [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ SegBinOp lore -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp lore
op2,
segBinOpShape :: ShapeBase SubExp
segBinOpShape = SegBinOp lore -> ShapeBase SubExp
forall lore. SegBinOp lore -> ShapeBase SubExp
segBinOpShape SegBinOp lore
op1
},
[a]
op1_aux [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
op2_aux
)
topDownSegOp SymbolTable lore
_ PatternT (LetDec lore)
_ StmAux (ExpDec lore)
_ SegOp (SegOpLevel lore) lore
_ = Rule lore
forall lore. Rule lore
Skip
segOpGuts ::
SegOp (SegOpLevel lore) lore ->
( [Type],
KernelBody lore,
Int,
[Type] -> KernelBody lore -> SegOp (SegOpLevel lore) lore
)
segOpGuts :: forall lore.
SegOp (SegOpLevel lore) lore
-> ([Type], KernelBody lore, Int,
[Type] -> KernelBody lore -> SegOp (SegOpLevel lore) lore)
segOpGuts (SegMap SegOpLevel lore
lvl SegSpace
space [Type]
kts KernelBody lore
body) =
([Type]
kts, KernelBody lore
body, Int
0, SegOpLevel lore
-> SegSpace
-> [Type]
-> KernelBody lore
-> SegOp (SegOpLevel lore) lore
forall lvl lore.
lvl -> SegSpace -> [Type] -> KernelBody lore -> SegOp lvl lore
SegMap SegOpLevel lore
lvl SegSpace
space)
segOpGuts (SegScan SegOpLevel lore
lvl SegSpace
space [SegBinOp lore]
ops [Type]
kts KernelBody lore
body) =
([Type]
kts, KernelBody lore
body, [SegBinOp lore] -> Int
forall lore. [SegBinOp lore] -> Int
segBinOpResults [SegBinOp lore]
ops, SegOpLevel lore
-> SegSpace
-> [SegBinOp lore]
-> [Type]
-> KernelBody lore
-> SegOp (SegOpLevel lore) lore
forall lvl lore.
lvl
-> SegSpace
-> [SegBinOp lore]
-> [Type]
-> KernelBody lore
-> SegOp lvl lore
SegScan SegOpLevel lore
lvl SegSpace
space [SegBinOp lore]
ops)
segOpGuts (SegRed SegOpLevel lore
lvl SegSpace
space [SegBinOp lore]
ops [Type]
kts KernelBody lore
body) =
([Type]
kts, KernelBody lore
body, [SegBinOp lore] -> Int
forall lore. [SegBinOp lore] -> Int
segBinOpResults [SegBinOp lore]
ops, SegOpLevel lore
-> SegSpace
-> [SegBinOp lore]
-> [Type]
-> KernelBody lore
-> SegOp (SegOpLevel lore) lore
forall lvl lore.
lvl
-> SegSpace
-> [SegBinOp lore]
-> [Type]
-> KernelBody lore
-> SegOp lvl lore
SegRed SegOpLevel lore
lvl SegSpace
space [SegBinOp lore]
ops)
segOpGuts (SegHist SegOpLevel lore
lvl SegSpace
space [HistOp lore]
ops [Type]
kts KernelBody lore
body) =
([Type]
kts, KernelBody lore
body, [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (HistOp lore -> Int) -> [HistOp lore] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int) -> (HistOp lore -> [VName]) -> HistOp lore -> Int
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. HistOp lore -> [VName]
forall lore. HistOp lore -> [VName]
histDest) [HistOp lore]
ops, SegOpLevel lore
-> SegSpace
-> [HistOp lore]
-> [Type]
-> KernelBody lore
-> SegOp (SegOpLevel lore) lore
forall lvl lore.
lvl
-> SegSpace
-> [HistOp lore]
-> [Type]
-> KernelBody lore
-> SegOp lvl lore
SegHist SegOpLevel lore
lvl SegSpace
space [HistOp lore]
ops)
bottomUpSegOp ::
(HasSegOp lore, BinderOps lore) =>
(ST.SymbolTable lore, UT.UsageTable) ->
Pattern lore ->
StmAux (ExpDec lore) ->
SegOp (SegOpLevel lore) lore ->
Rule lore
bottomUpSegOp :: forall lore.
(HasSegOp lore, BinderOps lore) =>
(SymbolTable lore, UsageTable)
-> Pattern lore
-> StmAux (ExpDec lore)
-> SegOp (SegOpLevel lore) lore
-> Rule lore
bottomUpSegOp (SymbolTable lore
vtable, UsageTable
used) (Pattern [] [PatElemT (LetDec lore)]
kpes) StmAux (ExpDec lore)
dec SegOp (SegOpLevel lore) lore
segop = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
([PatElemT (LetDec lore)]
kpes', [Type]
kts', [KernelResult]
kres', Stms lore
kstms') <-
Scope lore
-> RuleM
lore ([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
-> RuleM
lore ([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (SegSpace -> Scope lore
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
space) (RuleM
lore ([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
-> RuleM
lore ([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore))
-> RuleM
lore ([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
-> RuleM
lore ([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
forall a b. (a -> b) -> a -> b
$
(([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
-> Stm lore
-> RuleM
lore ([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore))
-> ([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
-> Stms lore
-> RuleM
lore ([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
-> Stm lore
-> RuleM
lore ([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
distribute ([PatElemT (LetDec lore)]
kpes, [Type]
kts, [KernelResult]
kres, Stms lore
forall a. Monoid a => a
mempty) Stms lore
kstms
Bool -> RuleM lore () -> RuleM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when
([PatElemT (LetDec lore)]
kpes' [PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)] -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElemT (LetDec lore)]
kpes)
RuleM lore ()
forall lore a. RuleM lore a
cannotSimplify
KernelBody lore
kbody <-
Scope lore
-> RuleM lore (KernelBody lore) -> RuleM lore (KernelBody lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (SegSpace -> Scope lore
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
space) (RuleM lore (KernelBody lore) -> RuleM lore (KernelBody lore))
-> RuleM lore (KernelBody lore) -> RuleM lore (KernelBody lore)
forall a b. (a -> b) -> a -> b
$
Stms (Lore (RuleM lore))
-> [KernelResult] -> RuleM lore (KernelBody (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> [KernelResult] -> m (KernelBody (Lore m))
mkKernelBodyM Stms lore
Stms (Lore (RuleM lore))
kstms' [KernelResult]
kres'
Stm (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm (Stm (Lore (RuleM lore)) -> RuleM lore ())
-> Stm (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ PatternT (LetDec lore)
-> StmAux (ExpDec lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> PatternT (LetDec lore)
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT (LetDec lore)]
kpes') StmAux (ExpDec lore)
dec (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ Op lore -> Exp lore
forall lore. Op lore -> ExpT lore
Op (Op lore -> Exp lore) -> Op lore -> Exp lore
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel lore) lore -> Op lore
forall lore.
HasSegOp lore =>
SegOp (SegOpLevel lore) lore -> Op lore
segOp (SegOp (SegOpLevel lore) lore -> Op lore)
-> SegOp (SegOpLevel lore) lore -> Op lore
forall a b. (a -> b) -> a -> b
$ [Type] -> KernelBody lore -> SegOp (SegOpLevel lore) lore
mk_segop [Type]
kts' KernelBody lore
kbody
where
([Type]
kts, KernelBody BodyDec lore
_ Stms lore
kstms [KernelResult]
kres, Int
num_nonmap_results, [Type] -> KernelBody lore -> SegOp (SegOpLevel lore) lore
mk_segop) =
SegOp (SegOpLevel lore) lore
-> ([Type], KernelBody lore, Int,
[Type] -> KernelBody lore -> SegOp (SegOpLevel lore) lore)
forall lore.
SegOp (SegOpLevel lore) lore
-> ([Type], KernelBody lore, Int,
[Type] -> KernelBody lore -> SegOp (SegOpLevel lore) lore)
segOpGuts SegOp (SegOpLevel lore) lore
segop
free_in_kstms :: Names
free_in_kstms = (Stm lore -> Names) -> Stms lore -> Names
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap Stm lore -> Names
forall a. FreeIn a => a -> Names
freeIn Stms lore
kstms
space :: SegSpace
space = SegOp (SegOpLevel lore) lore -> SegSpace
forall lvl lore. SegOp lvl lore -> SegSpace
segSpace SegOp (SegOpLevel lore) lore
segop
sliceWithGtidsFixed :: Stm lore -> Maybe (Slice SubExp, VName)
sliceWithGtidsFixed Stm lore
stm
| Let PatternT (LetDec lore)
_ StmAux (ExpDec lore)
_ (BasicOp (Index VName
arr Slice SubExp
slice)) <- Stm lore
stm,
Slice SubExp
space_slice <- ((VName, SubExp) -> DimIndex SubExp)
-> [(VName, SubExp)] -> Slice SubExp
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> ((VName, SubExp) -> SubExp)
-> (VName, SubExp)
-> DimIndex SubExp
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. VName -> SubExp
Var (VName -> SubExp)
-> ((VName, SubExp) -> VName) -> (VName, SubExp) -> SubExp
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst) ([(VName, SubExp)] -> Slice SubExp)
-> [(VName, SubExp)] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space,
Slice SubExp
space_slice Slice SubExp -> Slice SubExp -> Bool
forall a. Eq a => [a] -> [a] -> Bool
`isPrefixOf` Slice SubExp
slice,
Slice SubExp
remaining_slice <- Int -> Slice SubExp -> Slice SubExp
forall a. Int -> [a] -> [a]
drop (Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
space_slice) Slice SubExp
slice,
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Maybe (Entry lore) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Entry lore) -> Bool)
-> (VName -> Maybe (Entry lore)) -> VName -> Bool
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. (VName -> SymbolTable lore -> Maybe (Entry lore))
-> SymbolTable lore -> VName -> Maybe (Entry lore)
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> SymbolTable lore -> Maybe (Entry lore)
forall lore. VName -> SymbolTable lore -> Maybe (Entry lore)
ST.lookup SymbolTable lore
vtable) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$
Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
VName -> Names
forall a. FreeIn a => a -> Names
freeIn VName
arr Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Slice SubExp -> Names
forall a. FreeIn a => a -> Names
freeIn Slice SubExp
remaining_slice =
(Slice SubExp, VName) -> Maybe (Slice SubExp, VName)
forall a. a -> Maybe a
Just (Slice SubExp
remaining_slice, VName
arr)
| Bool
otherwise =
Maybe (Slice SubExp, VName)
forall a. Maybe a
Nothing
-- distribute: accumulating step over kernel statements (presumably driven
-- by a monadic fold over the kernel body's statements; the call site is
-- outside this view). The accumulator is (remaining kernel pattern
-- elements, their types, remaining kernel results, statements kept in the
-- kernel so far).
--
-- The first equation fires when all of the following hold:
--   * the statement binds exactly one pattern element 'pe' and its context
--     list is empty;
--   * 'sliceWithGtidsFixed' returns a remaining slice and a source array;
--   * 'isResult' finds a unique kernel result returning 'pe'.
-- It then binds the matching kernel pattern element 'kpe' via
-- 'letBindNames' in 'RuleM' (presumably outside the kernel) to an 'Index'
-- of 'arr'. The slice is one entry per segment-space dimension followed by
-- the remaining slice. That element, type and result are dropped from the
-- accumulator.
--
-- The second equation appends any other statement to the kept statements.
distribute :: ([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
-> Stm lore
-> RuleM
lore ([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
distribute ([PatElemT (LetDec lore)]
kpes', [Type]
kts', [KernelResult]
kres', Stms lore
kstms') Stm lore
stm
| Let (Pattern [] [PatElemT (LetDec lore)
pe]) StmAux (ExpDec lore)
_ Exp lore
_ <- Stm lore
stm,
Just (Slice SubExp
remaining_slice, VName
arr) <- Stm lore -> Maybe (Slice SubExp, VName)
sliceWithGtidsFixed Stm lore
stm,
Just (PatElemT (LetDec lore)
kpe, [PatElemT (LetDec lore)]
kpes'', [Type]
kts'', [KernelResult]
kres'') <- [PatElemT (LetDec lore)]
-> [Type]
-> [KernelResult]
-> PatElemT (LetDec lore)
-> Maybe
(PatElemT (LetDec lore), [PatElemT (LetDec lore)], [Type],
[KernelResult])
isResult [PatElemT (LetDec lore)]
kpes' [Type]
kts' [KernelResult]
kres' PatElemT (LetDec lore)
pe = do
-- outer_slice has one 'DimSlice' (constant 0) d (constant 1) per dimension
-- d of 'space' (presumably offset 0, length d, stride 1: the whole dimension).
let outer_slice :: Slice SubExp
outer_slice =
(SubExp -> DimIndex SubExp) -> [SubExp] -> Slice SubExp
forall a b. (a -> b) -> [a] -> [b]
map
( \SubExp
d ->
SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice
(Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64))
SubExp
d
(Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
1 :: Int64))
)
([SubExp] -> Slice SubExp) -> [SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
index :: PatElemT (LetDec lore) -> RuleM lore ()
index PatElemT (LetDec lore)
kpe' =
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
kpe'] (Exp lore -> RuleM lore ())
-> (Slice SubExp -> Exp lore) -> Slice SubExp -> RuleM lore ()
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore)
-> (Slice SubExp -> BasicOp) -> Slice SubExp -> Exp lore
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> RuleM lore ()) -> Slice SubExp -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
Slice SubExp
outer_slice Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. Semigroup a => a -> a -> a
<> Slice SubExp
remaining_slice
-- If kpe's name is consumed according to the usage table 'used', index
-- into a fresh "<base>_precopy" name first and bind kpe to a 'Copy' of it.
-- Presumably this is so the consumer does not consume (an alias of) 'arr'
-- in place.
if PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
kpe VName -> UsageTable -> Bool
`UT.isConsumed` UsageTable
used
then do
VName
precopy <- String -> RuleM lore VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> RuleM lore VName) -> String -> RuleM lore VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString (PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
kpe) String -> ShowS
forall a. Semigroup a => a -> a -> a
<> String
"_precopy"
PatElemT (LetDec lore) -> RuleM lore ()
index PatElemT (LetDec lore)
kpe {patElemName :: VName
patElemName = VName
precopy}
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
kpe] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
precopy
else PatElemT (LetDec lore) -> RuleM lore ()
index PatElemT (LetDec lore)
kpe
([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
-> RuleM
lore ([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return
( [PatElemT (LetDec lore)]
kpes'',
[Type]
kts'',
[KernelResult]
kres'',
-- Keep the original statement in the kernel only when pe's name is in
-- 'free_in_kstms' (presumably: still referenced by kernel statements).
if PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe VName -> Names -> Bool
`nameIn` Names
free_in_kstms
then Stms lore
kstms' Stms lore -> Stms lore -> Stms lore
forall a. Semigroup a => a -> a -> a
<> Stm lore -> Stms lore
forall lore. Stm lore -> Stms lore
oneStm Stm lore
stm
else Stms lore
kstms'
)
-- Fallback: keep the statement in the kernel unchanged.
distribute ([PatElemT (LetDec lore)]
kpes', [Type]
kts', [KernelResult]
kres', Stms lore
kstms') Stm lore
stm =
([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
-> RuleM
lore ([PatElemT (LetDec lore)], [Type], [KernelResult], Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return ([PatElemT (LetDec lore)]
kpes', [Type]
kts', [KernelResult]
kres', Stms lore
kstms' Stms lore -> Stms lore -> Stms lore
forall a. Semigroup a => a -> a -> a
<> Stm lore -> Stms lore
forall lore. Stm lore -> Stms lore
oneStm Stm lore
stm)
-- isResult: find the unique kernel result that returns pattern element
-- 'pe'. The kernel pattern elements, types and results are zipped and split
-- by 'matches'.
--
-- It returns 'Just' only when exactly one triple matches and that element's
-- index in 'kpes' is at least 'num_nonmap_results'. The payload is the
-- matching kernel pattern element, together with the pattern elements,
-- types and results of all the other triples. Otherwise it returns
-- 'Nothing'.
--
-- NOTE(review): the index is looked up in 'kpes' (bound outside this view),
-- not in the argument kpes'. Results below 'num_nonmap_results' are
-- presumably the op's non-map (scan/reduce/histogram) results and must not
-- be hoisted; confirm.
isResult :: [PatElemT (LetDec lore)]
-> [Type]
-> [KernelResult]
-> PatElemT (LetDec lore)
-> Maybe
(PatElemT (LetDec lore), [PatElemT (LetDec lore)], [Type],
[KernelResult])
isResult [PatElemT (LetDec lore)]
kpes' [Type]
kts' [KernelResult]
kres' PatElemT (LetDec lore)
pe =
case ((PatElemT (LetDec lore), Type, KernelResult) -> Bool)
-> [(PatElemT (LetDec lore), Type, KernelResult)]
-> ([(PatElemT (LetDec lore), Type, KernelResult)],
[(PatElemT (LetDec lore), Type, KernelResult)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (PatElemT (LetDec lore), Type, KernelResult) -> Bool
matches ([(PatElemT (LetDec lore), Type, KernelResult)]
-> ([(PatElemT (LetDec lore), Type, KernelResult)],
[(PatElemT (LetDec lore), Type, KernelResult)]))
-> [(PatElemT (LetDec lore), Type, KernelResult)]
-> ([(PatElemT (LetDec lore), Type, KernelResult)],
[(PatElemT (LetDec lore), Type, KernelResult)])
forall a b. (a -> b) -> a -> b
$ [PatElemT (LetDec lore)]
-> [Type]
-> [KernelResult]
-> [(PatElemT (LetDec lore), Type, KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElemT (LetDec lore)]
kpes' [Type]
kts' [KernelResult]
kres' of
([(PatElemT (LetDec lore)
kpe, Type
_, KernelResult
_)], [(PatElemT (LetDec lore), Type, KernelResult)]
kpes_and_kres)
| Just Int
i <- PatElemT (LetDec lore) -> [PatElemT (LetDec lore)] -> Maybe Int
forall a. Eq a => a -> [a] -> Maybe Int
elemIndex PatElemT (LetDec lore)
kpe [PatElemT (LetDec lore)]
kpes,
Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
num_nonmap_results,
([PatElemT (LetDec lore)]
kpes'', [Type]
kts'', [KernelResult]
kres'') <- [(PatElemT (LetDec lore), Type, KernelResult)]
-> ([PatElemT (LetDec lore)], [Type], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(PatElemT (LetDec lore), Type, KernelResult)]
kpes_and_kres ->
(PatElemT (LetDec lore), [PatElemT (LetDec lore)], [Type],
[KernelResult])
-> Maybe
(PatElemT (LetDec lore), [PatElemT (LetDec lore)], [Type],
[KernelResult])
forall a. a -> Maybe a
Just (PatElemT (LetDec lore)
kpe, [PatElemT (LetDec lore)]
kpes'', [Type]
kts'', [KernelResult]
kres'')
([(PatElemT (LetDec lore), Type, KernelResult)],
[(PatElemT (LetDec lore), Type, KernelResult)])
_ -> Maybe
(PatElemT (LetDec lore), [PatElemT (LetDec lore)], [Type],
[KernelResult])
forall a. Maybe a
Nothing
where
-- A triple matches when its result is 'Returns' of a variable whose name
-- equals pe's name. Every other kind of result does not match.
matches :: (PatElemT (LetDec lore), Type, KernelResult) -> Bool
matches (PatElemT (LetDec lore)
_, Type
_, Returns ResultManifest
_ (Var VName
v)) = VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe
matches (PatElemT (LetDec lore), Type, KernelResult)
_ = Bool
False
-- Catch-all equation of 'bottomUpSegOp': all four arguments are ignored and
-- the result is 'Skip' (presumably: this rule does not apply).
bottomUpSegOp (SymbolTable lore, UsageTable)
_ PatternT (LetDec lore)
_ StmAux (ExpDec lore)
_ SegOp (SegOpLevel lore) lore
_ = Rule lore
forall lore. Rule lore
Skip
-- | Compute the memory-annotated return information for the results of a
-- kernel body, starting from the given per-result 'ExpReturns'.
--
-- The body's results are paired positionally with the given returns. A
-- 'WriteReturns' result takes its return information from its destination
-- array (via 'varReturns'). Every other result keeps the 'ExpReturns' it
-- was paired with. As with 'zipWithM', surplus elements of the longer list
-- are dropped.
kernelBodyReturns ::
  (Mem lore, HasScope lore m, Monad m) =>
  KernelBody lore ->
  [ExpReturns] ->
  m [ExpReturns]
kernelBodyReturns kbody rets =
  zipWithM adjust (kernelBodyResult kbody) rets
  where
    adjust res ret =
      case res of
        WriteReturns _ dest _ -> varReturns dest
        _ -> return ret
-- | The memory-annotated return information of a 'SegOp'.
--
-- For 'SegMap', 'SegRed' and 'SegScan', the op's extended types (from
-- 'opType') are converted with 'extReturns' and then adjusted against the
-- kernel body by 'kernelBodyReturns'. A 'SegHist' instead yields the
-- 'varReturns' of every destination array of its histogram operators, in
-- order.
segOpReturns ::
  (Mem lore, Monad m, HasScope lore m) =>
  SegOp lvl lore ->
  m [ExpReturns]
segOpReturns op =
  case op of
    SegMap _ _ _ kbody -> viaBody kbody
    SegRed _ _ _ _ kbody -> viaBody kbody
    SegScan _ _ _ _ kbody -> viaBody kbody
    SegHist _ _ hist_ops _ _ ->
      concat <$> mapM (mapM varReturns . histDest) hist_ops
  where
    -- Shared path for the three body-driven constructors.
    viaBody kbody = do
      ts <- opType op
      kernelBodyReturns kbody $ extReturns ts