{-# LANGUAGE FlexibleContexts #-}
module Futhark.Optimise.Simplify.Rules.ClosedForm
( foldClosedForm,
loopClosedForm,
)
where
import Control.Monad
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.Simple (VarLookup)
import Futhark.Transform.Rename
-- | @foldClosedForm look pat lam accs arrs@ tries to replace a fold of
-- @lam@ over the arrays @arrs@, starting from the accumulators @accs@,
-- with a closed-form expression bound to @pat@.
--
-- The rule only applies when @pat@ consists of exactly one element of
-- a primitive type. It also requires 'checkResults' to find a closed
-- form for the lambda result. Otherwise it calls 'cannotSimplify'.
--
-- The generated code is an @if@. When dimension 0 of the input arrays
-- is zero, the result is @accs@ unchanged. Otherwise it is the closed
-- form computed by 'checkResults'.
foldClosedForm ::
  (ASTLore lore, BinderOps lore) =>
  VarLookup lore ->
  Pattern lore ->
  Lambda lore ->
  [SubExp] ->
  [VName] ->
  RuleM lore ()
foldClosedForm look pat lam accs arrs = do
  -- Iteration count of the fold: size of dimension 0 of the inputs.
  inputsize <- arraysSize 0 <$> mapM lookupType arrs

  t <- case patternTypes pat of
    [Prim pt] -> return pt
    _ -> cannotSimplify

  closedBody <-
    checkResults
      (patternNames pat)
      inputsize
      mempty
      Int64
      knownBnds
      (map paramName (lambdaParams lam))
      (lambdaBody lam)
      accs

  isEmpty <- newVName "fold_input_is_empty"
  letBindNames [isEmpty] $
    BasicOp $ CmpOp (CmpEq int64) inputsize (intConst Int64 0)

  -- 'closedBody' binds the pattern names themselves, which 'letBind'
  -- below binds as well, so it is renamed before being used as a branch
  -- (presumably to keep every binding name unique).
  letBind pat
    =<< ( If (Var isEmpty)
            <$> resultBodyM accs
            <*> renameBody closedBody
            <*> pure (IfDec [primBodyType t] IfNormal)
        )
  where
    knownBnds = determineKnownBindings look lam accs arrs
-- | @loopClosedForm pat merge i it bound body@ tries to replace a loop
-- running @bound@ times over the merge parameters @merge@ with a
-- closed-form expression bound to @pat@.
--
-- @i@ is passed to 'checkResults' as a set of names that must not be
-- treated as free in @body@ (presumably the loop index; confirm at the
-- call site). @it@ is the integer type of @bound@.
--
-- The rule only applies when @pat@ consists of exactly one element of a
-- primitive type and 'checkResults' succeeds; otherwise it calls
-- 'cannotSimplify'.
loopClosedForm ::
  (ASTLore lore, BinderOps lore) =>
  Pattern lore ->
  [(FParam lore, SubExp)] ->
  Names ->
  IntType ->
  SubExp ->
  Body lore ->
  RuleM lore ()
loopClosedForm pat merge i it bound body = do
  t <- case patternTypes pat of
    [Prim pt] -> return pt
    _ -> cannotSimplify

  closedBody <-
    checkResults
      mergenames
      bound
      i
      it
      knownBnds
      mergenames
      body
      mergeexp

  -- The loop executes no iterations whenever bound <= 0. In that case
  -- the result must be exactly the initial merge values. Testing only
  -- bound < 0 would send bound == 0 through the closed form. That is
  -- wrong for 'LogAnd' (it yields @acc && el@ instead of @acc@) and for
  -- 'FAdd' when @el@ is NaN or infinite.
  isEmpty <- newVName "bound_is_zero"
  letBindNames [isEmpty] $
    BasicOp $ CmpOp (CmpSle it) bound (intConst it 0)

  -- 'closedBody' binds the merge-parameter names, so it is renamed
  -- before being spliced into a branch (presumably to keep every binding
  -- name unique).
  letBind pat
    =<< ( If (Var isEmpty)
            <$> resultBodyM mergeexp
            <*> renameBody closedBody
            <*> pure (IfDec [primBodyType t] IfNormal)
        )
  where
    (mergepat, mergeexp) = unzip merge
    mergenames = map paramName mergepat
    -- Each merge parameter is known to start at its initial value.
    knownBnds = M.fromList $ zip mergenames mergeexp
-- | Try to compute every result of a fold/loop body in closed form.
--
-- Each name in @pat@ is paired with the corresponding body result and
-- with the corresponding accumulator parameter and initial value.
--
-- A result qualifies only if it is a variable bound in @body@ by a
-- single 'BinOp'. One operand must be exactly that accumulator
-- parameter. The other operand must be free in the body, or have a
-- known value in @knownBnds@, and must not itself be the accumulator.
--
-- With @acc@ the initial accumulator value, @el@ the other operand
-- and @size@ the iteration count, the supported closed forms are:
--
--   * 'LogAnd': @acc && el@
--   * 'Add':    @acc + el * size@ (size converted from @it@ if needed)
--   * 'FAdd':   @acc + el * size@ (size converted to floating point)
--
-- Any other shape calls 'cannotSimplify'. On success the returned body
-- binds every name in @pat@ and returns them as its result.
checkResults ::
  BinderOps lore =>
  -- | Names to bind the closed-form results to.
  [VName] ->
  -- | Iteration count, of integer type @it@.
  SubExp ->
  -- | Extra names that must not be treated as free in the body.
  Names ->
  -- | Integer type of the iteration count.
  IntType ->
  -- | Names whose values are known up front.
  M.Map VName SubExp ->
  -- | Body parameters; the first @length accs@ are the accumulators.
  [VName] ->
  Body lore ->
  -- | Initial accumulator values.
  [SubExp] ->
  RuleM lore (Body lore)
checkResults pat size untouchable it knownBnds params body accs = do
  ((), bnds) <-
    collectStms $
      zipWithM_ checkResult (zip pat res) (zip accparams accs)
  mkBodyM bnds $ map Var pat
  where
    bndMap = makeBindMap body
    accparams = take (length accs) params
    res = bodyResult body

    -- Names that are bound by the body itself, are body parameters, or
    -- were explicitly declared untouchable.
    nonFree = boundInBody body <> namesFromList params <> untouchable

    checkResult (p, Var v) (accparam, acc)
      | Just (BasicOp (BinOp bop x y)) <- M.lookup v bndMap = do
        -- One operand must be *this* accumulator, and the other must be
        -- free in the body. The other operand may not be the accumulator
        -- as well: 'knownBnds' maps accumulator parameters to their
        -- *initial* value, which would e.g. turn @acc + acc@ into
        -- @acc + acc * size@ instead of @acc * 2^size@.
        let isThisAccum = (== Var accparam)
        (this, el) <- liftMaybe $
          case ( (asFreeSubExp x, isThisAccum y),
                 (asFreeSubExp y, isThisAccum x)
               ) of
            ((Just free, True), _)
              | not (isThisAccum x) -> Just (acc, free)
            (_, (Just free, True))
              | not (isThisAccum y) -> Just (acc, free)
            _ -> Nothing

        case bop of
          LogAnd ->
            letBindNames [p] $ BasicOp $ BinOp LogAnd this el
          Add t w -> do
            size' <- sizeAsInt t
            letBindNames [p]
              =<< eBinOp
                (Add t w)
                (eSubExp this)
                (pure $ BasicOp $ BinOp (Mul t w) el size')
          FAdd t -> do
            size' <- sizeAsFloat t
            letBindNames [p]
              =<< eBinOp
                (FAdd t)
                (eSubExp this)
                (pure $ BasicOp $ BinOp (FMul t) el size')
          _ -> cannotSimplify
    checkResult _ _ = cannotSimplify

    -- A non-free variable is only usable if its value is known; any
    -- other subexpression (free variable or constant) is usable as is.
    asFreeSubExp :: SubExp -> Maybe SubExp
    asFreeSubExp (Var v)
      | v `nameIn` nonFree = M.lookup v knownBnds
    asFreeSubExp se = Just se

    -- 'size' has type 'it'. Only reuse it directly when the accumulator
    -- has that same type; otherwise convert it (the old code assumed
    -- 'it' was always Int64).
    sizeAsInt t
      | t == it = return size
      | otherwise =
        letSubExp "converted_size" $
          BasicOp $ ConvOp (SExt it t) size

    sizeAsFloat t =
      letSubExp "converted_size" $
        BasicOp $ ConvOp (SIToFP it t) size
-- | Map lambda parameters whose value is known before the fold starts.
--
-- The first @length accs@ lambda parameters (the accumulators) are
-- mapped to the corresponding initial value in @accs@.
--
-- The remaining parameters are paired with @arrs@. A parameter is
-- mapped to @ve@ when its array is bound, according to @look@, by
-- @Replicate _ ve@ with no certificates. Presumably this is because
-- every element of such an array equals @ve@.
determineKnownBindings ::
  VarLookup lore ->
  Lambda lore ->
  [SubExp] ->
  [VName] ->
  M.Map VName SubExp
determineKnownBindings look lam accs arrs =
  accBnds <> arrBnds
  where
    (accparams, arrparams) =
      splitAt (length accs) $ lambdaParams lam

    accBnds =
      M.fromList $ zip (map paramName accparams) accs

    arrBnds =
      M.fromList $
        mapMaybe isReplicate $
          zip (map paramName arrparams) arrs

    -- Only certificate-free replicates are used.
    isReplicate (p, v)
      | Just (BasicOp (Replicate _ ve), cs) <- look v,
        cs == mempty =
        Just (p, ve)
    isReplicate _ = Nothing
-- | Map each variable bound by a single-name statement among the
-- body's top-level statements to the expression that binds it.
-- Statements whose pattern binds zero or several names are skipped.
makeBindMap :: Body lore -> M.Map VName (Exp lore)
makeBindMap = M.fromList . mapMaybe isSingletonStm . stmsToList . bodyStms
  where
    isSingletonStm (Let pat _ e) =
      case patternNames pat of
        [v] -> Just (v, e)
        _ -> Nothing