{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Simplify.Rules
( standardRules,
removeUnnecessaryCopy,
)
where
import Control.Monad
import Data.Either
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.BasicOp
import Futhark.Optimise.Simplify.Rules.Index
import Futhark.Optimise.Simplify.Rules.Loop
import Futhark.Util
topDownRules :: BinderOps lore => [TopDownRule lore]
topDownRules :: [TopDownRule lore]
topDownRules =
[ RuleGeneric lore (TopDown lore) -> TopDownRule lore
forall lore a. RuleGeneric lore a -> SimplificationRule lore a
RuleGeneric RuleGeneric lore (TopDown lore)
forall lore. BinderOps lore => TopDownRuleGeneric lore
constantFoldPrimFun,
RuleIf lore (TopDown lore) -> TopDownRule lore
forall lore a. RuleIf lore a -> SimplificationRule lore a
RuleIf RuleIf lore (TopDown lore)
forall lore. BinderOps lore => TopDownRuleIf lore
ruleIf,
RuleIf lore (TopDown lore) -> TopDownRule lore
forall lore a. RuleIf lore a -> SimplificationRule lore a
RuleIf RuleIf lore (TopDown lore)
forall lore. BinderOps lore => TopDownRuleIf lore
hoistBranchInvariant
]
bottomUpRules :: BinderOps lore => [BottomUpRule lore]
bottomUpRules :: [BottomUpRule lore]
bottomUpRules =
[ RuleIf lore (BottomUp lore) -> BottomUpRule lore
forall lore a. RuleIf lore a -> SimplificationRule lore a
RuleIf RuleIf lore (BottomUp lore)
forall lore. BinderOps lore => BottomUpRuleIf lore
removeDeadBranchResult,
RuleBasicOp lore (BottomUp lore) -> BottomUpRule lore
forall lore a. RuleBasicOp lore a -> SimplificationRule lore a
RuleBasicOp RuleBasicOp lore (BottomUp lore)
forall lore. BinderOps lore => BottomUpRuleBasicOp lore
simplifyIndex
]
standardRules :: (BinderOps lore, Aliased lore) => RuleBook lore
standardRules :: RuleBook lore
standardRules = [TopDownRule lore] -> [BottomUpRule lore] -> RuleBook lore
forall m. [TopDownRule m] -> [BottomUpRule m] -> RuleBook m
ruleBook [TopDownRule lore]
forall lore. BinderOps lore => [TopDownRule lore]
topDownRules [BottomUpRule lore]
forall lore. BinderOps lore => [BottomUpRule lore]
bottomUpRules RuleBook lore -> RuleBook lore -> RuleBook lore
forall a. Semigroup a => a -> a -> a
<> RuleBook lore
forall lore. (BinderOps lore, Aliased lore) => RuleBook lore
loopRules RuleBook lore -> RuleBook lore -> RuleBook lore
forall a. Semigroup a => a -> a -> a
<> RuleBook lore
forall lore. (BinderOps lore, Aliased lore) => RuleBook lore
basicOpRules
removeUnnecessaryCopy :: BinderOps lore => BottomUpRuleBasicOp lore
removeUnnecessaryCopy :: BottomUpRuleBasicOp lore
removeUnnecessaryCopy (SymbolTable lore
vtable, UsageTable
used) (Pattern [] [PatElemT (LetDec lore)
d]) StmAux (ExpDec lore)
_ (Copy VName
v)
| Bool -> Bool
not (VName
v VName -> UsageTable -> Bool
`UT.isConsumed` UsageTable
used),
(Bool -> Bool
not (VName
v VName -> UsageTable -> Bool
`UT.used` UsageTable
used) Bool -> Bool -> Bool
&& Bool
consumable) Bool -> Bool -> Bool
|| Bool -> Bool
not (PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
d VName -> UsageTable -> Bool
`UT.isConsumed` UsageTable
used) =
RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
d] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
where
consumable :: Bool
consumable = case VName -> Map VName (NameInfo lore) -> Maybe (NameInfo lore)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
v (Map VName (NameInfo lore) -> Maybe (NameInfo lore))
-> Map VName (NameInfo lore) -> Maybe (NameInfo lore)
forall a b. (a -> b) -> a -> b
$ SymbolTable lore -> Map VName (NameInfo lore)
forall lore. SymbolTable lore -> Scope lore
ST.toScope SymbolTable lore
vtable of
Just (FParamName FParamInfo lore
info) -> TypeBase Shape Uniqueness -> Bool
forall shape. TypeBase shape Uniqueness -> Bool
unique (TypeBase Shape Uniqueness -> Bool)
-> TypeBase Shape Uniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ FParamInfo lore -> TypeBase Shape Uniqueness
forall t. DeclTyped t => t -> TypeBase Shape Uniqueness
declTypeOf FParamInfo lore
info
Maybe (NameInfo lore)
_ -> Bool
False
removeUnnecessaryCopy (SymbolTable lore, UsageTable)
_ PatternT (LetDec lore)
_ StmAux (ExpDec lore)
_ BasicOp
_ = Rule lore
forall lore. Rule lore
Skip
constantFoldPrimFun :: BinderOps lore => TopDownRuleGeneric lore
constantFoldPrimFun :: TopDownRuleGeneric lore
constantFoldPrimFun TopDown lore
_ (Let Pattern lore
pat (StmAux Certificates
cs Attrs
attrs ExpDec lore
_) (Apply Name
fname [(SubExp, Diet)]
args [RetType lore]
_ (Safety, SrcLoc, [SrcLoc])
_))
| Just [PrimValue]
args' <- ((SubExp, Diet) -> Maybe PrimValue)
-> [(SubExp, Diet)] -> Maybe [PrimValue]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SubExp -> Maybe PrimValue
isConst (SubExp -> Maybe PrimValue)
-> ((SubExp, Diet) -> SubExp) -> (SubExp, Diet) -> Maybe PrimValue
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExp, Diet) -> SubExp
forall a b. (a, b) -> a
fst) [(SubExp, Diet)]
args,
Just ([PrimType]
_, PrimType
_, [PrimValue] -> Maybe PrimValue
fun) <- String
-> Map
String ([PrimType], PrimType, [PrimValue] -> Maybe PrimValue)
-> Maybe ([PrimType], PrimType, [PrimValue] -> Maybe PrimValue)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (Name -> String
nameToString Name
fname) Map String ([PrimType], PrimType, [PrimValue] -> Maybe PrimValue)
primFuns,
Just PrimValue
result <- [PrimValue] -> Maybe PrimValue
fun [PrimValue]
args' =
RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$
Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
Attrs -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Attrs -> m a -> m a
attributing Attrs
attrs (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant PrimValue
result
where
isConst :: SubExp -> Maybe PrimValue
isConst (Constant PrimValue
v) = PrimValue -> Maybe PrimValue
forall a. a -> Maybe a
Just PrimValue
v
isConst SubExp
_ = Maybe PrimValue
forall a. Maybe a
Nothing
constantFoldPrimFun TopDown lore
_ Stm lore
_ = Rule lore
forall lore. Rule lore
Skip
simplifyIndex :: BinderOps lore => BottomUpRuleBasicOp lore
simplifyIndex :: BottomUpRuleBasicOp lore
simplifyIndex (SymbolTable lore
vtable, UsageTable
used) pat :: Pattern lore
pat@(Pattern [] [PatElemT (LetDec lore)
pe]) (StmAux Certificates
cs Attrs
attrs ExpDec lore
_) (Index VName
idd Slice SubExp
inds)
| Just RuleM lore IndexResult
m <- SymbolTable (Lore (RuleM lore))
-> TypeLookup
-> VName
-> Slice SubExp
-> Bool
-> Maybe (RuleM lore IndexResult)
forall (m :: * -> *).
MonadBinder m =>
SymbolTable (Lore m)
-> TypeLookup
-> VName
-> Slice SubExp
-> Bool
-> Maybe (m IndexResult)
simplifyIndexing SymbolTable lore
SymbolTable (Lore (RuleM lore))
vtable TypeLookup
seType VName
idd Slice SubExp
inds Bool
consumed = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
IndexResult
res <- RuleM lore IndexResult
m
Attrs -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Attrs -> m a -> m a
attributing Attrs
attrs (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ case IndexResult
res of
SubExpResult Certificates
cs' SubExp
se ->
Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (Certificates
cs Certificates -> Certificates -> Certificates
forall a. Semigroup a => a -> a -> a
<> Certificates
cs') (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames (Pattern lore -> [VName]
forall dec. PatternT dec -> [VName]
patternNames Pattern lore
pat) (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
IndexResult Certificates
extra_cs VName
idd' Slice SubExp
inds' ->
Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (Certificates
cs Certificates -> Certificates -> Certificates
forall a. Semigroup a => a -> a -> a
<> Certificates
extra_cs) (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames (Pattern lore -> [VName]
forall dec. PatternT dec -> [VName]
patternNames Pattern lore
pat) (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
idd' Slice SubExp
inds'
where
consumed :: Bool
consumed = PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe VName -> UsageTable -> Bool
`UT.isConsumed` UsageTable
used
seType :: TypeLookup
seType (Var VName
v) = VName -> SymbolTable lore -> Maybe Type
forall lore.
ASTLore lore =>
VName -> SymbolTable lore -> Maybe Type
ST.lookupType VName
v SymbolTable lore
vtable
seType (Constant PrimValue
v) = Type -> Maybe Type
forall a. a -> Maybe a
Just (Type -> Maybe Type) -> Type -> Maybe Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> Type) -> PrimType -> Type
forall a b. (a -> b) -> a -> b
$ PrimValue -> PrimType
primValueType PrimValue
v
simplifyIndex (SymbolTable lore, UsageTable)
_ Pattern lore
_ StmAux (ExpDec lore)
_ BasicOp
_ = Rule lore
forall lore. Rule lore
Skip
ruleIf :: BinderOps lore => TopDownRuleIf lore
ruleIf :: TopDownRuleIf lore
ruleIf TopDown lore
_ Pattern lore
pat StmAux (ExpDec lore)
_ (SubExp
e1, BodyT lore
tb, BodyT lore
fb, IfDec [BranchType lore]
_ IfSort
ifsort)
| Just BodyT lore
branch <- Maybe (BodyT lore)
checkBranch,
IfSort
ifsort IfSort -> IfSort -> Bool
forall a. Eq a => a -> a -> Bool
/= IfSort
IfFallback Bool -> Bool -> Bool
|| SubExp -> Bool
isCt1 SubExp
e1 = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
let ses :: Result
ses = BodyT lore -> Result
forall lore. BodyT lore -> Result
bodyResult BodyT lore
branch
Stms (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms (Lore (RuleM lore)) -> RuleM lore ())
-> Stms (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BodyT lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
branch
[RuleM lore ()] -> RuleM lore ()
forall (t :: * -> *) (m :: * -> *) a.
(Foldable t, Monad m) =>
t (m a) -> m ()
sequence_
[ [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
p] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
| (PatElemT (LetDec lore)
p, SubExp
se) <- [PatElemT (LetDec lore)]
-> Result -> [(PatElemT (LetDec lore), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
pat) Result
ses
]
where
checkBranch :: Maybe (BodyT lore)
checkBranch
| SubExp -> Bool
isCt1 SubExp
e1 = BodyT lore -> Maybe (BodyT lore)
forall a. a -> Maybe a
Just BodyT lore
tb
| SubExp -> Bool
isCt0 SubExp
e1 = BodyT lore -> Maybe (BodyT lore)
forall a. a -> Maybe a
Just BodyT lore
fb
| Bool
otherwise = Maybe (BodyT lore)
forall a. Maybe a
Nothing
ruleIf
TopDown lore
_
Pattern lore
pat
StmAux (ExpDec lore)
_
( SubExp
cond,
Body BodyDec lore
_ Stms lore
tstms [Constant (BoolValue Bool
True)],
Body BodyDec lore
_ Stms lore
fstms [SubExp
se],
IfDec [BranchType lore]
ts IfSort
_
)
| Stms lore -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null Stms lore
tstms,
Stms lore -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null Stms lore
fstms,
[Prim PrimType
Bool] <- (BranchType lore -> ExtType) -> [BranchType lore] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf [BranchType lore]
ts =
RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp BinOp
LogOr SubExp
cond SubExp
se
ruleIf TopDown lore
_ Pattern lore
pat StmAux (ExpDec lore)
_ (SubExp
cond, BodyT lore
tb, BodyT lore
fb, IfDec [BranchType lore]
ts IfSort
_)
| Body BodyDec lore
_ Stms lore
tstms [SubExp
tres] <- BodyT lore
tb,
Body BodyDec lore
_ Stms lore
fstms [SubExp
fres] <- BodyT lore
fb,
(Stm lore -> Bool) -> Stms lore -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (ExpT lore -> Bool
forall lore. IsOp (Op lore) => Exp lore -> Bool
safeExp (ExpT lore -> Bool) -> (Stm lore -> ExpT lore) -> Stm lore -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm lore -> ExpT lore
forall lore. Stm lore -> Exp lore
stmExp) (Stms lore -> Bool) -> Stms lore -> Bool
forall a b. (a -> b) -> a -> b
$ Stms lore
tstms Stms lore -> Stms lore -> Stms lore
forall a. Semigroup a => a -> a -> a
<> Stms lore
fstms,
(BranchType lore -> Bool) -> [BranchType lore] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all ((ExtType -> ExtType -> Bool
forall a. Eq a => a -> a -> Bool
== PrimType -> ExtType
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Bool) (ExtType -> Bool)
-> (BranchType lore -> ExtType) -> BranchType lore -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BranchType lore -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf) [BranchType lore]
ts = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
Stms (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms Stms lore
Stms (Lore (RuleM lore))
tstms
Stms (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms Stms lore
Stms (Lore (RuleM lore))
fstms
ExpT lore
e <-
BinOp
-> RuleM lore (Exp (Lore (RuleM lore)))
-> RuleM lore (Exp (Lore (RuleM lore)))
-> RuleM lore (Exp (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
BinOp -> m (Exp (Lore m)) -> m (Exp (Lore m)) -> m (Exp (Lore m))
eBinOp
BinOp
LogOr
(ExpT lore -> RuleM lore (ExpT lore)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpT lore -> RuleM lore (ExpT lore))
-> ExpT lore -> RuleM lore (ExpT lore)
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp BinOp
LogAnd SubExp
cond SubExp
tres)
( BinOp
-> RuleM lore (Exp (Lore (RuleM lore)))
-> RuleM lore (Exp (Lore (RuleM lore)))
-> RuleM lore (Exp (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
BinOp -> m (Exp (Lore m)) -> m (Exp (Lore m)) -> m (Exp (Lore m))
eBinOp
BinOp
LogAnd
(ExpT lore -> RuleM lore (ExpT lore)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpT lore -> RuleM lore (ExpT lore))
-> ExpT lore -> RuleM lore (ExpT lore)
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ UnOp -> SubExp -> BasicOp
UnOp UnOp
Not SubExp
cond)
(ExpT lore -> RuleM lore (ExpT lore)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpT lore -> RuleM lore (ExpT lore))
-> ExpT lore -> RuleM lore (ExpT lore)
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
fres)
)
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat ExpT lore
Exp (Lore (RuleM lore))
e
ruleIf TopDown lore
_ Pattern lore
pat StmAux (ExpDec lore)
_ (SubExp
_, BodyT lore
tbranch, BodyT lore
_, IfDec [BranchType lore]
_ IfSort
IfFallback)
| [VName] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ Pattern lore -> [VName]
forall dec. PatternT dec -> [VName]
patternContextNames Pattern lore
pat,
(Stm lore -> Bool) -> Stms lore -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (ExpT lore -> Bool
forall lore. IsOp (Op lore) => Exp lore -> Bool
safeExp (ExpT lore -> Bool) -> (Stm lore -> ExpT lore) -> Stm lore -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm lore -> ExpT lore
forall lore. Stm lore -> Exp lore
stmExp) (Stms lore -> Bool) -> Stms lore -> Bool
forall a b. (a -> b) -> a -> b
$ BodyT lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
tbranch = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
let ses :: Result
ses = BodyT lore -> Result
forall lore. BodyT lore -> Result
bodyResult BodyT lore
tbranch
Stms (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms (Lore (RuleM lore)) -> RuleM lore ())
-> Stms (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BodyT lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
tbranch
[RuleM lore ()] -> RuleM lore ()
forall (t :: * -> *) (m :: * -> *) a.
(Foldable t, Monad m) =>
t (m a) -> m ()
sequence_
[ [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
p] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
| (PatElemT (LetDec lore)
p, SubExp
se) <- [PatElemT (LetDec lore)]
-> Result -> [(PatElemT (LetDec lore), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
pat) Result
ses
]
ruleIf TopDown lore
_ Pattern lore
pat StmAux (ExpDec lore)
_ (SubExp
cond, BodyT lore
tb, BodyT lore
fb, IfDec (BranchType lore)
_)
| Body BodyDec lore
_ Stms lore
_ [Constant (IntValue IntValue
t)] <- BodyT lore
tb,
Body BodyDec lore
_ Stms lore
_ [Constant (IntValue IntValue
f)] <- BodyT lore
fb =
if IntValue -> Bool
oneIshInt IntValue
t Bool -> Bool -> Bool
&& IntValue -> Bool
zeroIshInt IntValue
f
then
RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ ConvOp -> SubExp -> BasicOp
ConvOp (IntType -> ConvOp
BToI (IntValue -> IntType
intValueType IntValue
t)) SubExp
cond
else
if IntValue -> Bool
zeroIshInt IntValue
t Bool -> Bool -> Bool
&& IntValue -> Bool
oneIshInt IntValue
f
then RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
SubExp
cond_neg <- String -> Exp (Lore (RuleM lore)) -> RuleM lore SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"cond_neg" (Exp (Lore (RuleM lore)) -> RuleM lore SubExp)
-> Exp (Lore (RuleM lore)) -> RuleM lore SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ UnOp -> SubExp -> BasicOp
UnOp UnOp
Not SubExp
cond
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ ConvOp -> SubExp -> BasicOp
ConvOp (IntType -> ConvOp
BToI (IntValue -> IntType
intValueType IntValue
t)) SubExp
cond_neg
else Rule lore
forall lore. Rule lore
Skip
ruleIf TopDown lore
_ Pattern lore
_ StmAux (ExpDec lore)
_ (SubExp, BodyT lore, BodyT lore, IfDec (BranchType lore))
_ = Rule lore
forall lore. Rule lore
Skip
hoistBranchInvariant :: BinderOps lore => TopDownRuleIf lore
hoistBranchInvariant :: TopDownRuleIf lore
hoistBranchInvariant TopDown lore
_ Pattern lore
pat StmAux (ExpDec lore)
_ (SubExp
cond, BodyT lore
tb, BodyT lore
fb, IfDec [BranchType lore]
ret IfSort
ifsort) = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
let tses :: Result
tses = BodyT lore -> Result
forall lore. BodyT lore -> Result
bodyResult BodyT lore
tb
fses :: Result
fses = BodyT lore -> Result
forall lore. BodyT lore -> Result
bodyResult BodyT lore
fb
([Maybe (Int, SubExp)]
hoistings, ([PatElemT (LetDec lore)]
pes, [Either Int (BranchType lore)]
ts, [(SubExp, SubExp)]
res)) <-
([Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)])))
-> RuleM
lore
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> RuleM
lore
([Maybe (Int, SubExp)],
([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (([(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)]))
-> ([Maybe (Int, SubExp)],
[(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))])
-> ([Maybe (Int, SubExp)],
([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 (([Maybe (Int, SubExp)],
[(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))])
-> ([Maybe (Int, SubExp)],
([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)])))
-> ([Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
[(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]))
-> [Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)]))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
[(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))])
forall a b. [Either a b] -> ([a], [b])
partitionEithers) (RuleM
lore
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> RuleM
lore
([Maybe (Int, SubExp)],
([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)])))
-> RuleM
lore
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> RuleM
lore
([Maybe (Int, SubExp)],
([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)]))
forall a b. (a -> b) -> a -> b
$
((PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))))
-> [(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> RuleM
lore
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp)))
branchInvariant ([(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> RuleM
lore
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))])
-> [(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> RuleM
lore
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
forall a b. (a -> b) -> a -> b
$
[PatElemT (LetDec lore)]
-> [Either Int (BranchType lore)]
-> [(SubExp, SubExp)]
-> [(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3
(Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
pat)
((Int -> Either Int (BranchType lore))
-> [Int] -> [Either Int (BranchType lore)]
forall a b. (a -> b) -> [a] -> [b]
map Int -> Either Int (BranchType lore)
forall a b. a -> Either a b
Left [Int
0 .. Int
num_ctx Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] [Either Int (BranchType lore)]
-> [Either Int (BranchType lore)] -> [Either Int (BranchType lore)]
forall a. [a] -> [a] -> [a]
++ (BranchType lore -> Either Int (BranchType lore))
-> [BranchType lore] -> [Either Int (BranchType lore)]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> Either Int (BranchType lore)
forall a b. b -> Either a b
Right [BranchType lore]
ret)
(Result -> Result -> [(SubExp, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip Result
tses Result
fses)
let ctx_fixes :: [(Int, SubExp)]
ctx_fixes = [Maybe (Int, SubExp)] -> [(Int, SubExp)]
forall a. [Maybe a] -> [a]
catMaybes [Maybe (Int, SubExp)]
hoistings
(Result
tses', Result
fses') = [(SubExp, SubExp)] -> (Result, Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(SubExp, SubExp)]
res
tb' :: BodyT lore
tb' = BodyT lore
tb {bodyResult :: Result
bodyResult = Result
tses'}
fb' :: BodyT lore
fb' = BodyT lore
fb {bodyResult :: Result
bodyResult = Result
fses'}
ret' :: [BranchType lore]
ret' = ((Int, SubExp) -> [BranchType lore] -> [BranchType lore])
-> [BranchType lore] -> [(Int, SubExp)] -> [BranchType lore]
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((Int -> SubExp -> [BranchType lore] -> [BranchType lore])
-> (Int, SubExp) -> [BranchType lore] -> [BranchType lore]
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry Int -> SubExp -> [BranchType lore] -> [BranchType lore]
forall t. FixExt t => Int -> SubExp -> t -> t
fixExt) ([Either Int (BranchType lore)] -> [BranchType lore]
forall a b. [Either a b] -> [b]
rights [Either Int (BranchType lore)]
ts) [(Int, SubExp)]
ctx_fixes
([PatElemT (LetDec lore)]
ctx_pes, [PatElemT (LetDec lore)]
val_pes) = Int
-> [PatElemT (LetDec lore)]
-> ([PatElemT (LetDec lore)], [PatElemT (LetDec lore)])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([BranchType lore] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [BranchType lore]
ret') [PatElemT (LetDec lore)]
pes
if Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Maybe (Int, SubExp)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Maybe (Int, SubExp)]
hoistings
then do
BodyT lore
tb'' <- BodyT (Lore (RuleM lore))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT lore
BodyT (Lore (RuleM lore))
tb' ([ExtType] -> RuleM lore (BodyT (Lore (RuleM lore))))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$ (BranchType lore -> ExtType) -> [BranchType lore] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf [BranchType lore]
ret'
BodyT lore
fb'' <- BodyT (Lore (RuleM lore))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT lore
BodyT (Lore (RuleM lore))
fb' ([ExtType] -> RuleM lore (BodyT (Lore (RuleM lore))))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$ (BranchType lore -> ExtType) -> [BranchType lore] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf [BranchType lore]
ret'
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [PatElemT (LetDec lore)]
ctx_pes [PatElemT (LetDec lore)]
val_pes) (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond BodyT lore
tb'' BodyT lore
fb'' ([BranchType lore] -> IfSort -> IfDec (BranchType lore)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType lore]
ret' IfSort
ifsort)
else RuleM lore ()
forall lore a. RuleM lore a
cannotSimplify
where
num_ctx :: Int
num_ctx = [PatElemT (LetDec lore)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([PatElemT (LetDec lore)] -> Int)
-> [PatElemT (LetDec lore)] -> Int
forall a b. (a -> b) -> a -> b
$ Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements Pattern lore
pat
bound_in_branches :: Names
bound_in_branches =
[VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$
(Stm lore -> [VName]) -> Seq (Stm lore) -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Pattern lore -> [VName]
forall dec. PatternT dec -> [VName]
patternNames (Pattern lore -> [VName])
-> (Stm lore -> Pattern lore) -> Stm lore -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm lore -> Pattern lore
forall lore. Stm lore -> Pattern lore
stmPattern) (Seq (Stm lore) -> [VName]) -> Seq (Stm lore) -> [VName]
forall a b. (a -> b) -> a -> b
$
BodyT lore -> Seq (Stm lore)
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
tb Seq (Stm lore) -> Seq (Stm lore) -> Seq (Stm lore)
forall a. Semigroup a => a -> a -> a
<> BodyT lore -> Seq (Stm lore)
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
fb
mem_sizes :: Names
mem_sizes = [PatElemT (LetDec lore)] -> Names
forall a. FreeIn a => a -> Names
freeIn ([PatElemT (LetDec lore)] -> Names)
-> [PatElemT (LetDec lore)] -> Names
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec lore) -> Bool)
-> [PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)]
forall a. (a -> Bool) -> [a] -> [a]
filter (Type -> Bool
forall shape u. TypeBase shape u -> Bool
isMem (Type -> Bool)
-> (PatElemT (LetDec lore) -> Type)
-> PatElemT (LetDec lore)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElemT (LetDec lore) -> Type
forall dec. Typed dec => PatElemT dec -> Type
patElemType) ([PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)])
-> [PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> a -> b
$ Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
pat
invariant :: SubExp -> Bool
invariant Constant {} = Bool
True
invariant (Var VName
v) = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName
v VName -> Names -> Bool
`nameIn` Names
bound_in_branches
isMem :: TypeBase shape u -> Bool
isMem Mem {} = Bool
True
isMem TypeBase shape u
_ = Bool
False
sizeOfMem :: VName -> Bool
sizeOfMem VName
v = VName
v VName -> Names -> Bool
`nameIn` Names
mem_sizes
branchInvariant :: (PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp)))
branchInvariant (PatElemT (LetDec lore)
pe, Either Int (BranchType lore)
t, (SubExp
tse, SubExp
fse))
| SubExp
tse SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
fse = do
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
tse
PatElemT (LetDec lore)
-> Either Int (BranchType lore)
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp)))
forall (m :: * -> *) dec a b b.
Monad m =>
PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT (LetDec lore)
pe Either Int (BranchType lore)
t
| SubExp -> Bool
invariant SubExp
tse,
SubExp -> Bool
invariant SubExp
fse,
Pattern lore -> Int
forall dec. PatternT dec -> Int
patternSize Pattern lore
pat Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1,
Prim PrimType
_ <- PatElemT (LetDec lore) -> Type
forall dec. Typed dec => PatElemT dec -> Type
patElemType PatElemT (LetDec lore)
pe,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> Bool
sizeOfMem (VName -> Bool) -> VName -> Bool
forall a b. (a -> b) -> a -> b
$ PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe = do
[BranchType lore]
bt <- Pattern lore -> RuleM lore [BranchType lore]
forall lore (m :: * -> *).
(ASTLore lore, HasScope lore m, Monad m) =>
Pattern lore -> m [BranchType lore]
expTypesFromPattern (Pattern lore -> RuleM lore [BranchType lore])
-> Pattern lore -> RuleM lore [BranchType lore]
forall a b. (a -> b) -> a -> b
$ [PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT (LetDec lore)
pe]
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe]
(ExpT lore -> RuleM lore ())
-> RuleM lore (ExpT lore) -> RuleM lore ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ( SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond (BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore)
-> RuleM lore (BodyT lore)
-> RuleM lore (BodyT lore -> IfDec (BranchType lore) -> ExpT lore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Result -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *). MonadBinder m => Result -> m (Body (Lore m))
resultBodyM [SubExp
tse]
RuleM lore (BodyT lore -> IfDec (BranchType lore) -> ExpT lore)
-> RuleM lore (BodyT lore)
-> RuleM lore (IfDec (BranchType lore) -> ExpT lore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *). MonadBinder m => Result -> m (Body (Lore m))
resultBodyM [SubExp
fse]
RuleM lore (IfDec (BranchType lore) -> ExpT lore)
-> RuleM lore (IfDec (BranchType lore)) -> RuleM lore (ExpT lore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> IfDec (BranchType lore) -> RuleM lore (IfDec (BranchType lore))
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([BranchType lore] -> IfSort -> IfDec (BranchType lore)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType lore]
bt IfSort
ifsort)
)
PatElemT (LetDec lore)
-> Either Int (BranchType lore)
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp)))
forall (m :: * -> *) dec a b b.
Monad m =>
PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT (LetDec lore)
pe Either Int (BranchType lore)
t
| Bool
otherwise =
Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp)))
forall (m :: * -> *) a. Monad m => a -> m a
return (Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))))
-> Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp)))
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
-> Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
forall a b. b -> Either a b
Right (PatElemT (LetDec lore)
pe, Either Int (BranchType lore)
t, (SubExp
tse, SubExp
fse))
hoisted :: PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT dec
pe (Left a
i) = Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b))
-> Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall a b. (a -> b) -> a -> b
$ Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. a -> Either a b
Left (Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b)
-> Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. (a -> b) -> a -> b
$ (a, SubExp) -> Maybe (a, SubExp)
forall a. a -> Maybe a
Just (a
i, VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT dec -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT dec
pe)
hoisted PatElemT dec
_ Right {} = Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b))
-> Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall a b. (a -> b) -> a -> b
$ Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. a -> Either a b
Left Maybe (a, SubExp)
forall a. Maybe a
Nothing
reshapeBodyResults :: BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT (Lore m)
body [ExtType]
rets = m (BodyT (Lore m)) -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
m (Body (Lore m)) -> m (Body (Lore m))
insertStmsM (m (BodyT (Lore m)) -> m (BodyT (Lore m)))
-> m (BodyT (Lore m)) -> m (BodyT (Lore m))
forall a b. (a -> b) -> a -> b
$ do
Result
ses <- BodyT (Lore m) -> m Result
forall (m :: * -> *). MonadBinder m => Body (Lore m) -> m Result
bodyBind BodyT (Lore m)
body
let (Result
ctx_ses, Result
val_ses) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([ExtType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [ExtType]
rets) Result
ses
Result -> m (BodyT (Lore m))
forall (m :: * -> *). MonadBinder m => Result -> m (Body (Lore m))
resultBodyM (Result -> m (BodyT (Lore m)))
-> (Result -> Result) -> Result -> m (BodyT (Lore m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Result
ctx_ses Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++) (Result -> m (BodyT (Lore m))) -> m Result -> m (BodyT (Lore m))
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (SubExp -> ExtType -> m SubExp) -> Result -> [ExtType] -> m Result
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM SubExp -> ExtType -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> ExtType -> m SubExp
reshapeResult Result
val_ses [ExtType]
rets
reshapeResult :: SubExp -> ExtType -> m SubExp
reshapeResult (Var VName
v) t :: ExtType
t@Array {} = do
Type
v_t <- VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
v
let newshape :: Result
newshape = Type -> Result
forall u. TypeBase Shape u -> Result
arrayDims (Type -> Result) -> Type -> Result
forall a b. (a -> b) -> a -> b
$ ExtType -> Type -> Type
removeExistentials ExtType
t Type
v_t
if Result
newshape Result -> Result -> Bool
forall a. Eq a => a -> a -> Bool
/= Type -> Result
forall u. TypeBase Shape u -> Result
arrayDims Type
v_t
then String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"branch_ctx_reshaped" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ Result -> VName -> Exp (Lore m)
forall lore. Result -> VName -> Exp lore
shapeCoerce Result
newshape VName
v
else SubExp -> m SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp -> m SubExp) -> SubExp -> m SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
reshapeResult SubExp
se ExtType
_ =
SubExp -> m SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return SubExp
se
removeDeadBranchResult :: BinderOps lore => BottomUpRuleIf lore
removeDeadBranchResult :: BottomUpRuleIf lore
removeDeadBranchResult (SymbolTable lore
_, UsageTable
used) Pattern lore
pat StmAux (ExpDec lore)
_ (SubExp
e1, BodyT lore
tb, BodyT lore
fb, IfDec [BranchType lore]
rettype IfSort
ifsort)
|
Pattern lore -> Int
forall dec. PatternT dec -> Int
patternSize Pattern lore
pat Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [BranchType lore] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [BranchType lore]
rettype,
[Bool]
patused <- (VName -> Bool) -> [VName] -> [Bool]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> UsageTable -> Bool
`UT.isUsedDirectly` UsageTable
used) ([VName] -> [Bool]) -> [VName] -> [Bool]
forall a b. (a -> b) -> a -> b
$ Pattern lore -> [VName]
forall dec. PatternT dec -> [VName]
patternNames Pattern lore
pat,
Bool -> Bool
not ([Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
and [Bool]
patused) =
let tses :: Result
tses = BodyT lore -> Result
forall lore. BodyT lore -> Result
bodyResult BodyT lore
tb
fses :: Result
fses = BodyT lore -> Result
forall lore. BodyT lore -> Result
bodyResult BodyT lore
fb
pick :: [a] -> [a]
pick :: [a] -> [a]
pick = ((Bool, a) -> a) -> [(Bool, a)] -> [a]
forall a b. (a -> b) -> [a] -> [b]
map (Bool, a) -> a
forall a b. (a, b) -> b
snd ([(Bool, a)] -> [a]) -> ([a] -> [(Bool, a)]) -> [a] -> [a]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((Bool, a) -> Bool) -> [(Bool, a)] -> [(Bool, a)]
forall a. (a -> Bool) -> [a] -> [a]
filter (Bool, a) -> Bool
forall a b. (a, b) -> a
fst ([(Bool, a)] -> [(Bool, a)])
-> ([a] -> [(Bool, a)]) -> [a] -> [(Bool, a)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Bool] -> [a] -> [(Bool, a)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Bool]
patused
tb' :: BodyT lore
tb' = BodyT lore
tb {bodyResult :: Result
bodyResult = Result -> Result
forall a. [a] -> [a]
pick Result
tses}
fb' :: BodyT lore
fb' = BodyT lore
fb {bodyResult :: Result
bodyResult = Result -> Result
forall a. [a] -> [a]
pick Result
fses}
pat' :: [PatElemT (LetDec lore)]
pat' = [PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)]
forall a. [a] -> [a]
pick ([PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)])
-> [PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> a -> b
$ Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
pat
rettype' :: [BranchType lore]
rettype' = [BranchType lore] -> [BranchType lore]
forall a. [a] -> [a]
pick [BranchType lore]
rettype
in RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT (LetDec lore)]
pat') (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
e1 BodyT lore
tb' BodyT lore
fb' (IfDec (BranchType lore) -> ExpT lore)
-> IfDec (BranchType lore) -> ExpT lore
forall a b. (a -> b) -> a -> b
$ [BranchType lore] -> IfSort -> IfDec (BranchType lore)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType lore]
rettype' IfSort
ifsort
| Bool
otherwise = Rule lore
forall lore. Rule lore
Skip
isCt1 :: SubExp -> Bool
isCt1 :: SubExp -> Bool
isCt1 (Constant PrimValue
v) = PrimValue -> Bool
oneIsh PrimValue
v
isCt1 SubExp
_ = Bool
False
isCt0 :: SubExp -> Bool
isCt0 :: SubExp -> Bool
isCt0 (Constant PrimValue
v) = PrimValue -> Bool
zeroIsh PrimValue
v
isCt0 SubExp
_ = Bool
False