{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Futhark.IR.Kernels.Kernel
(
SizeOp (..),
HostOp (..),
typeCheckHostOp,
SegLevel (..),
module Futhark.IR.Kernels.Sizes,
module Futhark.IR.SegOp,
)
where
import Futhark.Analysis.Metrics
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.IR
import Futhark.IR.Aliases (Aliases)
import Futhark.IR.Kernels.Sizes
import Futhark.IR.Prop.Aliases
import Futhark.IR.SegOp
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Lore
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import qualified Futhark.TypeCheck as TC
import Futhark.Util.Pretty
( commasep,
parens,
ppr,
text,
(<+>),
)
import qualified Futhark.Util.Pretty as PP
import Prelude hiding (id, (.))
-- | The level annotation carried by the 'SegOp's of this representation
-- (it is the first type argument of the 'SegOp' wrapped by 'HostOp').
-- Both constructors carry the same three fields; only the tag differs.
data SegLevel
  = -- | Pretty-printed with the tag @thread@.
    SegThread
      { -- | The number of groups, as a sub-expression.
        segNumGroups :: Count NumGroups SubExp,
        -- | The group size, as a sub-expression.
        segGroupSize :: Count GroupSize SubExp,
        -- | The virtualisation setting ('SegNoVirt', 'SegNoVirtFull' or
        -- 'SegVirt').
        segVirt :: SegVirt
      }
  | -- | Pretty-printed with the tag @group@.
    SegGroup
      { segNumGroups :: Count NumGroups SubExp,
        segGroupSize :: Count GroupSize SubExp,
        segVirt :: SegVirt
      }
  deriving (Eq, Ord, Show)
-- | Prints a parenthesised, semicolon-separated description: the level
-- tag (@thread@ or @group@), @#groups=@ followed by the group count,
-- @groupsize=@ followed by the group size, and finally an optional
-- virtualisation marker.
instance PP.Pretty SegLevel where
  ppr lvl =
    PP.parens
      ( lvl' <> PP.semi
          <+> text "#groups=" <> ppr (segNumGroups lvl) <> PP.semi
          <+> text "groupsize=" <> ppr (segGroupSize lvl) <> virt
      )
    where
      -- Tag naming the constructor (a string literal used as a document
      -- via OverloadedStrings).
      lvl' = case lvl of
        SegThread {} -> "thread"
        SegGroup {} -> "group"
      -- Trailing marker; nothing at all is printed for 'SegNoVirt'.
      virt = case segVirt lvl of
        SegNoVirt -> mempty
        SegNoVirtFull -> PP.semi <+> text "full"
        SegVirt -> PP.semi <+> text "virtualise"
-- | Simplifies the group-count and group-size sub-expressions with
-- 'Engine.simplify'; the 'SegVirt' field is passed through unchanged and
-- the constructor is preserved.
instance Engine.Simplifiable SegLevel where
  simplify (SegThread num_groups group_size virt) =
    SegThread <$> traverse Engine.simplify num_groups
      <*> traverse Engine.simplify group_size
      <*> pure virt
  simplify (SegGroup num_groups group_size virt) =
    SegGroup <$> traverse Engine.simplify num_groups
      <*> traverse Engine.simplify group_size
      <*> pure virt
-- | Applies the name substitution to the group count and the group size.
-- The 'SegVirt' field contains nothing to substitute and is kept as is.
instance Substitute SegLevel where
  substituteNames substs (SegThread num_groups group_size virt) =
    SegThread
      (substituteNames substs num_groups)
      (substituteNames substs group_size)
      virt
  substituteNames substs (SegGroup num_groups group_size virt) =
    SegGroup
      (substituteNames substs num_groups)
      (substituteNames substs group_size)
      virt
-- | Renaming is delegated to 'substituteRename', i.e. it is defined in
-- terms of the 'Substitute' instance for 'SegLevel'.
instance Rename SegLevel where
  rename = substituteRename
-- | The free variables of a level are those of its group count and its
-- group size; the 'SegVirt' field is ignored.
instance FreeIn SegLevel where
  freeIn' (SegThread num_groups group_size _) =
    freeIn' num_groups <> freeIn' group_size
  freeIn' (SegGroup num_groups group_size _) =
    freeIn' num_groups <> freeIn' group_size
-- | Operations concerning sizes.  The instances in this module treat all
-- of them as safe, cheap, alias-free and non-consuming.
data SizeOp
  = -- | @SplitSpace o w i elems_per_thread@.  The result is typed as
    -- @int64@; the type checker requires @w@, @i@ and @elems_per_thread@
    -- (and the stride of a 'SplitStrided' ordering) to be @int64@.
    SplitSpace SplitOrdering SubExp SubExp SubExp
  | -- | @GetSize name class@, printed as @get_size(name, class)@.  The
    -- result is typed as @int64@.
    GetSize Name SizeClass
  | -- | @GetSizeMax class@, printed as @get_size_max(class)@.  The result
    -- is typed as @int64@.
    GetSizeMax SizeClass
  | -- | @CmpSizeLe name class x@, printed as a @cmp_size(name, class)@
    -- compared with @<=@ against @x@.  The result is typed as @bool@ and
    -- @x@ must be @int64@.
    CmpSizeLe Name SizeClass SubExp
  | -- | @CalcNumGroups w max_num_groups group_size@.  The result is typed
    -- as @int64@; @w@ and @group_size@ must be @int64@.
    CalcNumGroups SubExp Name SubExp
  deriving (Eq, Ord, Show)
-- | Applies the name substitution to every 'SubExp' (and the
-- 'SplitOrdering') held by the operation.  'Name' and 'SizeClass'
-- components are left untouched, so operations that consist only of
-- those ('GetSize', 'GetSizeMax') are returned unchanged by the final
-- catch-all clause.
instance Substitute SizeOp where
  substituteNames substs (SplitSpace o w i elems_per_thread) =
    SplitSpace
      (substituteNames substs o)
      (substituteNames substs w)
      (substituteNames substs i)
      (substituteNames substs elems_per_thread)
  substituteNames substs (CmpSizeLe name sclass x) =
    CmpSizeLe name sclass (substituteNames substs x)
  substituteNames substs (CalcNumGroups w max_num_groups group_size) =
    CalcNumGroups
      (substituteNames substs w)
      max_num_groups
      (substituteNames substs group_size)
  substituteNames _ op = op
-- | Renames every 'SubExp' (and the 'SplitOrdering') held by the
-- operation.  'Name' and 'SizeClass' components are kept as they are;
-- operations without any renameable component fall through to the last
-- clause and are returned unchanged.
instance Rename SizeOp where
  rename (SplitSpace o w i elems_per_thread) =
    SplitSpace
      <$> rename o
      <*> rename w
      <*> rename i
      <*> rename elems_per_thread
  rename (CmpSizeLe name sclass x) =
    CmpSizeLe name sclass <$> rename x
  rename (CalcNumGroups w max_num_groups group_size) =
    CalcNumGroups <$> rename w <*> pure max_num_groups <*> rename group_size
  rename x = pure x
-- | Every size operation is reported as both safe and cheap.
instance IsOp SizeOp where
  safeOp _ = True
  cheapOp _ = True
-- | Every size operation produces exactly one primitive value: a @bool@
-- for 'CmpSizeLe' and an @int64@ for all other operations.
instance TypedOp SizeOp where
  opType SplitSpace {} = pure [Prim int64]
  opType (GetSize _ _) = pure [Prim int64]
  opType (GetSizeMax _) = pure [Prim int64]
  opType CmpSizeLe {} = pure [Prim Bool]
  opType CalcNumGroups {} = pure [Prim int64]
-- | Size operations have a single result that aliases nothing, and they
-- consume nothing.
instance AliasedOp SizeOp where
  opAliases _ = [mempty]
  consumedInOp _ = mempty
-- | The free variables are those of the 'SubExp' operands (and of the
-- 'SplitOrdering' for 'SplitSpace').  Operations without such operands
-- ('GetSize', 'GetSizeMax') have no free variables.
instance FreeIn SizeOp where
  freeIn' (SplitSpace o w i elems_per_thread) =
    freeIn' o <> freeIn' [w, i, elems_per_thread]
  freeIn' (CmpSizeLe _ _ x) = freeIn' x
  freeIn' (CalcNumGroups w _ group_size) =
    freeIn' w <> freeIn' group_size
  freeIn' _ = mempty
-- | Prints each operation in a function-call-like form: an operation name
-- followed by its comma-separated operands in parentheses.  A strided
-- 'SplitSpace' uses the name @split_space_strided@ and lists the stride
-- first; 'CmpSizeLe' additionally appends @<=@ and the compared operand.
instance PP.Pretty SizeOp where
  ppr (SplitSpace SplitContiguous w i elems_per_thread) =
    text "split_space"
      <> parens (commasep [ppr w, ppr i, ppr elems_per_thread])
  ppr (SplitSpace (SplitStrided stride) w i elems_per_thread) =
    text "split_space_strided"
      <> parens (commasep [ppr stride, ppr w, ppr i, ppr elems_per_thread])
  ppr (GetSize name size_class) =
    text "get_size" <> parens (commasep [ppr name, ppr size_class])
  ppr (GetSizeMax size_class) =
    text "get_size_max" <> parens (commasep [ppr size_class])
  ppr (CmpSizeLe name size_class x) =
    text "cmp_size" <> parens (commasep [ppr name, ppr size_class])
      <+> text "<="
      <+> ppr x
  ppr (CalcNumGroups w max_num_groups group_size) =
    text "calc_num_groups" <> parens (commasep [ppr w, ppr max_num_groups, ppr group_size])
-- | Records one occurrence ('seen') of the operation, keyed by the name
-- of its constructor.
instance OpMetrics SizeOp where
  opMetrics SplitSpace {} = seen "SplitSpace"
  opMetrics GetSize {} = seen "GetSize"
  opMetrics GetSizeMax {} = seen "GetSizeMax"
  opMetrics CmpSizeLe {} = seen "CmpSizeLe"
  opMetrics CalcNumGroups {} = seen "CalcNumGroups"
-- | Type-check a 'SizeOp'.  Every 'SubExp' operand must have type @int64@
-- (checked with 'TC.require'); 'GetSize' and 'GetSizeMax' have no
-- 'SubExp' operands and are accepted unconditionally.
typeCheckSizeOp :: TC.Checkable lore => SizeOp -> TC.TypeM lore ()
typeCheckSizeOp (SplitSpace o w i elems_per_thread) = do
  -- Only a strided ordering carries an operand (the stride) to check.
  case o of
    SplitContiguous -> return ()
    SplitStrided stride -> TC.require [Prim int64] stride
  mapM_ (TC.require [Prim int64]) [w, i, elems_per_thread]
typeCheckSizeOp GetSize {} = return ()
typeCheckSizeOp GetSizeMax {} = return ()
typeCheckSizeOp (CmpSizeLe _ _ x) = TC.require [Prim int64] x
typeCheckSizeOp (CalcNumGroups w _ group_size) = do
  TC.require [Prim int64] w
  TC.require [Prim int64] group_size
-- | An operation that is either a 'SegOp' at some 'SegLevel', a 'SizeOp',
-- or some other operation of the type given by the @op@ parameter.  The
-- class instances in this module dispatch each case to the instance of
-- the wrapped type.
data HostOp lore op
  = -- | A 'SegOp' annotated with a 'SegLevel'.
    SegOp (SegOp SegLevel lore)
  | -- | A size operation.
    SizeOp SizeOp
  | -- | An operation of the parameter type @op@.
    OtherOp op
  deriving (Eq, Ord, Show)
-- | Applies the substitution to the wrapped operation, keeping the
-- constructor.
instance (ASTLore lore, Substitute op) => Substitute (HostOp lore op) where
  substituteNames substs (SegOp op) =
    SegOp $ substituteNames substs op
  substituteNames substs (OtherOp op) =
    OtherOp $ substituteNames substs op
  substituteNames substs (SizeOp op) =
    SizeOp $ substituteNames substs op
-- | Renames the wrapped operation, keeping the constructor.
instance (ASTLore lore, Rename op) => Rename (HostOp lore op) where
  rename (SegOp op) = SegOp <$> rename op
  rename (OtherOp op) = OtherOp <$> rename op
  rename (SizeOp op) = SizeOp <$> rename op
-- | Safety and cheapness are those of the wrapped operation.
instance (ASTLore lore, IsOp op) => IsOp (HostOp lore op) where
  safeOp (SegOp op) = safeOp op
  safeOp (OtherOp op) = safeOp op
  safeOp (SizeOp op) = safeOp op

  cheapOp (SegOp op) = cheapOp op
  cheapOp (OtherOp op) = cheapOp op
  cheapOp (SizeOp op) = cheapOp op
-- | The result types are those of the wrapped operation.
instance TypedOp op => TypedOp (HostOp lore op) where
  opType (SegOp op) = opType op
  opType (OtherOp op) = opType op
  opType (SizeOp op) = opType op
-- | Aliases and consumption are those of the wrapped operation.
instance (Aliased lore, AliasedOp op, ASTLore lore) => AliasedOp (HostOp lore op) where
  opAliases (SegOp op) = opAliases op
  opAliases (OtherOp op) = opAliases op
  opAliases (SizeOp op) = opAliases op

  consumedInOp (SegOp op) = consumedInOp op
  consumedInOp (OtherOp op) = consumedInOp op
  consumedInOp (SizeOp op) = consumedInOp op
-- | The free variables are those of the wrapped operation.
instance (ASTLore lore, FreeIn op) => FreeIn (HostOp lore op) where
  freeIn' (SegOp op) = freeIn' op
  freeIn' (OtherOp op) = freeIn' op
  freeIn' (SizeOp op) = freeIn' op
-- | Adding or removing alias annotations on a host-level operation.
--
-- 'SegOp' and 'OtherOp' payloads are converted recursively with the
-- payload's own 'addOpAliases' / 'removeOpAliases'.  A 'SizeOp' payload
-- is rewrapped as-is in both directions, and the alias table is ignored
-- for it.
instance (CanBeAliased (Op lore), CanBeAliased op, ASTLore lore) => CanBeAliased (HostOp lore op) where
  type OpWithAliases (HostOp lore op) = HostOp (Aliases lore) (OpWithAliases op)

  addOpAliases aliases (SegOp op) = SegOp $ addOpAliases aliases op
  addOpAliases aliases (OtherOp op) = OtherOp $ addOpAliases aliases op
  addOpAliases _ (SizeOp op) = SizeOp op

  removeOpAliases (SegOp op) = SegOp $ removeOpAliases op
  removeOpAliases (OtherOp op) = OtherOp $ removeOpAliases op
  removeOpAliases (SizeOp op) = SizeOp op
-- | Stripping simplifier annotations ("wisdom") from a host-level
-- operation.
--
-- 'SegOp' and 'OtherOp' payloads are stripped with their own
-- 'removeOpWisdom'; a 'SizeOp' payload is rewrapped unchanged.
instance (CanBeWise (Op lore), CanBeWise op, ASTLore lore) => CanBeWise (HostOp lore op) where
  type OpWithWisdom (HostOp lore op) = HostOp (Wise lore) (OpWithWisdom op)

  removeOpWisdom (SegOp op) = SegOp $ removeOpWisdom op
  removeOpWisdom (OtherOp op) = OtherOp $ removeOpWisdom op
  removeOpWisdom (SizeOp op) = SizeOp op
-- | Symbolic indexing into the result of a host-level operation.
--
-- 'SegOp' and 'OtherOp' forward the query to the wrapped operation.
-- Every remaining case (that is, 'SizeOp') answers 'Nothing'.
instance (ASTLore lore, ST.IndexOp op) => ST.IndexOp (HostOp lore op) where
  indexOp vtable k (SegOp op) is = ST.indexOp vtable k op is
  indexOp vtable k (OtherOp op) is = ST.indexOp vtable k op is
  -- Catch-all: no index information is produced for other constructors.
  indexOp _ _ _ _ = Nothing
-- | A host-level operation is pretty-printed exactly like the operation
-- it wraps; the 'HostOp' constructor itself adds nothing to the output.
instance (PrettyLore lore, PP.Pretty op) => PP.Pretty (HostOp lore op) where
  ppr (SegOp op) = ppr op
  ppr (OtherOp op) = ppr op
  ppr (SizeOp op) = ppr op
-- | Metrics for a host-level operation are the metrics of the wrapped
-- operation.
instance (OpMetrics (Op lore), OpMetrics op) => OpMetrics (HostOp lore op) where
  opMetrics (SegOp op) = opMetrics op
  opMetrics (OtherOp op) = opMetrics op
  opMetrics (SizeOp op) = opMetrics op
-- | Type-check a 'SegLevel' against the level it is nested in.
--
-- The first argument is the enclosing (parent) level, if there is one;
-- the second is the level being checked.
--
-- * With no parent, the number of groups and the group size must both
--   have type @int64@.
--
-- * A parent at 'SegThread' level is always an error.
--
-- * Otherwise the level must differ from its parent, yet agree with it
--   on both 'segNumGroups' and 'segGroupSize'.
checkSegLevel ::
  TC.Checkable lore =>
  Maybe SegLevel ->
  SegLevel ->
  TC.TypeM lore ()
checkSegLevel Nothing lvl = do
  TC.require [Prim int64] $ unCount $ segNumGroups lvl
  TC.require [Prim int64] $ unCount $ segGroupSize lvl
checkSegLevel (Just SegThread {}) _ =
  TC.bad $ TC.TypeError "SegOps cannot occur when already at thread level."
checkSegLevel (Just x) y
  -- Nesting a level inside an identical level is rejected.
  | x == y = TC.bad $ TC.TypeError $ "Already at level " ++ pretty x
  -- The child has to reuse the parent's group count and group size.
  | segNumGroups x /= segNumGroups y || segGroupSize x /= segGroupSize y =
    TC.bad $ TC.TypeError "Physical layout for SegLevel does not match parent SegLevel."
  | otherwise =
    return ()
-- | Type-check a host-level operation.
--
-- Arguments, in order:
--
-- 1. A checker for operations found inside a 'SegOp'; it is given the
--    level of that 'SegOp' and installed with 'TC.checkOpWith' while the
--    'SegOp' is being checked.
--
-- 2. The enclosing 'SegLevel', if any, which 'checkSegLevel' validates
--    the level of a 'SegOp' against.
--
-- 3. The checker applied to an 'OtherOp' payload.
--
-- 4. The operation to check.
--
-- A 'SizeOp' payload is handed to 'typeCheckSizeOp'; none of the first
-- three arguments is used in that case.
typeCheckHostOp ::
  TC.Checkable lore =>
  (SegLevel -> OpWithAliases (Op lore) -> TC.TypeM lore ()) ->
  Maybe SegLevel ->
  (op -> TC.TypeM lore ()) ->
  HostOp (Aliases lore) op ->
  TC.TypeM lore ()
typeCheckHostOp checker lvl _ (SegOp op) =
  TC.checkOpWith (checker $ segLevel op) $
    typeCheckSegOp (checkSegLevel lvl) op
typeCheckHostOp _ _ f (OtherOp op) = f op
typeCheckHostOp _ _ _ (SizeOp op) = typeCheckSizeOp op