{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.IR.Mem.Simplify
  ( simplifyProgGeneric,
    simplifyStmsGeneric,
    simpleGeneric,
    SimplifyMemory,
  )
where

import Control.Monad
import Data.List (find, foldl')
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Construct
import Futhark.IR.Mem
import qualified Futhark.IR.Mem.IxFun as IxFun
import qualified Futhark.IR.Syntax as AST
import qualified Futhark.Optimise.Simplify as Simplify
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Lore
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (simplifiable)
import Futhark.Util

-- | Construct the 'Simplify.SimpleOps' for a memory-annotated lore.
--
-- This is 'simplifiable' from "Futhark.Pass.ExplicitAllocations",
-- re-exported at a type phrased in terms of 'SimplifyMemory'.
--
-- * The first argument computes a usage table for an inner operation.
--
-- * The second argument simplifies an inner operation.  It may also
--   produce additional statements.
simpleGeneric ::
  (SimplifyMemory lore, Op lore ~ MemOp inner) =>
  (OpWithWisdom inner -> UT.UsageTable) ->
  Simplify.SimplifyOp lore inner ->
  Simplify.SimpleOps lore
simpleGeneric = simplifiable

-- | Simplify a whole program in a memory-annotated representation.
--
-- The simplifier runs with the given operations, the memory rule book
-- 'callKernelRules', and the hoisting blockers 'blockers'.  The
-- branch-hoisting blocker is additionally replaced by @blockAllocs@,
-- defined below.
simplifyProgGeneric ::
  (SimplifyMemory lore, Op lore ~ MemOp inner) =>
  Simplify.SimpleOps lore ->
  Prog lore ->
  PassM (Prog lore)
simplifyProgGeneric ops =
  Simplify.simplifyProg
    ops
    callKernelRules
    blockers {Engine.blockHoistBranch = blockAllocs}
  where
    -- For an 'Alloc' statement, the result is the negation of
    -- 'ST.simplifyMemory' for the current symbol table.
    blockAllocs vtable _ (Let _ _ (Op Alloc {})) =
      not $ ST.simplifyMemory vtable
    -- Do not hoist statements that produce arrays.  This is
    -- because in the KernelsMem representation, multiple
    -- arrays can be located in the same memory block, and moving
    -- their creation out of a branch can thus cause memory
    -- corruption.  At this point in the compiler we have probably
    -- already moved all the array creations that matter.
    blockAllocs _ _ (Let pat _ _) =
      not $ all primType $ patternTypes pat

-- | Simplify a sequence of statements in a memory-annotated
-- representation.
--
-- The statements are simplified in the scope obtained from the
-- monad's 'askScope'.  The simplifier uses 'callKernelRules' and
-- 'blockers'.  The result is whatever 'Simplify.simplifyStms'
-- returns: a symbol table for the wise lore paired with the
-- simplified statements.
simplifyStmsGeneric ::
  ( HasScope lore m,
    MonadFreshNames m,
    SimplifyMemory lore,
    Op lore ~ MemOp inner
  ) =>
  Simplify.SimpleOps lore ->
  Stms lore ->
  m (ST.SymbolTable (Wise lore), Stms lore)
simplifyStmsGeneric ops stms = do
  scope <- askScope
  Simplify.simplifyStms ops callKernelRules blockers scope stms

-- | Holds for a statement that meets all three conditions:
--
-- * It binds exactly one value and no context.
--
-- * Its expression is an 'Alloc' operation.
--
-- * The usage table reports the bound name as being in the result
--   ('UT.isInResult').
--
-- Every other statement yields 'False'.
isResultAlloc :: Op lore ~ MemOp op => Engine.BlockPred lore
isResultAlloc _ usage (Let (AST.Pattern [] [bindee]) _ (Op Alloc {})) =
  UT.isInResult (patElemName bindee) usage
isResultAlloc _ _ _ = False

-- | Holds for any statement whose expression is an 'Alloc' operation,
-- regardless of the symbol table, usage table, or pattern.
isAlloc :: Op lore ~ MemOp op => Engine.BlockPred lore
isAlloc _ _ (Let _ _ (Op Alloc {})) = True
isAlloc _ _ _ = False

-- | The hoisting blockers shared by the memory simplifiers.
--
-- These start from 'Engine.noExtraHoistBlockers' and override three
-- fields:
--
-- * 'Engine.blockHoistPar' is 'isAlloc'.
--
-- * 'Engine.blockHoistSeq' is 'isResultAlloc'.
--
-- * 'Engine.isAllocation' recognises 'Alloc' statements.
blockers ::
  (Op lore ~ MemOp inner) =>
  Simplify.HoistBlockers lore
blockers =
  Engine.noExtraHoistBlockers
    { Engine.blockHoistPar = isAlloc,
      Engine.blockHoistSeq = isResultAlloc,
      -- 'isAlloc' ignores its symbol-table and usage-table arguments,
      -- so empty ones suffice to turn it into a plain statement test.
      Engine.isAllocation = isAlloc mempty mempty
    }

-- | The constraints a lore must satisfy for the memory simplification
-- rules in this module to work.
type SimplifyMemory lore =
  ( -- Memory-representation requirement from "Futhark.IR.Mem".
    Mem lore,
    -- General simplifier requirements, for the lore and its wise form.
    Simplify.SimplifiableLore lore,
    CanBeWise (Op lore),
    BinderOps (Wise lore),
    -- Expressions and bodies carry no extra decorations.
    ExpDec lore ~ (),
    BodyDec lore ~ (),
    -- Rules in this module build new 'Alloc' operations via 'allocOp'.
    AllocOp (Op (Wise lore))
  )

-- | The rule book used by the memory simplifiers.
--
-- It consists of 'standardRules' extended with four memory-specific
-- top-down rules.  No additional bottom-up rules are added.
callKernelRules :: SimplifyMemory lore => RuleBook (Wise lore)
callKernelRules =
  standardRules
    <> ruleBook
      [ RuleBasicOp copyCopyToCopy,
        RuleBasicOp removeIdentityCopy,
        RuleIf unExistentialiseMemory,
        RuleOp decertifySafeAlloc
      ]
      []

-- | If a branch is returning some existential memory, but the size of
-- the array is not existential, and the index function of the array
-- does not refer to any names in the pattern, then we can create a
-- block of the proper size and always return there.
--
-- The transformation has three steps:
--
-- 1. A fresh block of the computed size is allocated before the 'If'
--    for every such array.
--
-- 2. Each branch copies the array into that block.
--
-- 3. Each branch returns the new block in place of the old existential
--    one.
--
-- The rule only fires when 'ST.simplifyMemory' holds for the symbol
-- table and at least one array qualifies.
unExistentialiseMemory :: SimplifyMemory lore => TopDownRuleIf (Wise lore)
unExistentialiseMemory vtable pat _ (cond, tbranch, fbranch, ifdec)
  | ST.simplifyMemory vtable,
    -- Strict fold: the accumulator is consumed in full right away.
    fixable <- foldl' hasConcretisableMemory mempty $ patternElements pat,
    not $ null fixable = Simplify $ do
    -- Create non-existential memory blocks big enough to hold the
    -- arrays.
    (arr_to_mem, oldmem_to_mem) <-
      fmap unzip $
        forM fixable $ \(arr_pe, mem_size, oldmem, space) -> do
          size <- toSubExp "size" mem_size
          mem <- letExp "mem" $ Op $ allocOp size space
          return ((patElemName arr_pe, mem), (oldmem, mem))

    -- Update the branches to contain Copy expressions putting the
    -- arrays where they are expected.
    let updateBody body = insertStmsM $ do
          res <- bodyBind body
          resultBodyM
            =<< zipWithM updateResult (patternElements pat) res
        -- An array result is copied into its new block.  The old
        -- memory result is replaced by the new block.  Anything else is
        -- returned unchanged.
        updateResult pat_elem (Var v)
          | Just mem <- lookup (patElemName pat_elem) arr_to_mem,
            (_, MemArray pt shape u (ArrayIn _ ixfun)) <- patElemDec pat_elem = do
            v_copy <- newVName $ baseString v <> "_nonext_copy"
            let v_pat =
                  Pattern
                    []
                    [ PatElem v_copy $
                        MemArray pt shape u $ ArrayIn mem ixfun
                    ]
            addStm $ mkWiseLetStm v_pat (defAux ()) $ BasicOp (Copy v)
            return $ Var v_copy
          | Just mem <- lookup (patElemName pat_elem) oldmem_to_mem =
            return $ Var mem
        updateResult _ se =
          return se
    tbranch' <- updateBody tbranch
    fbranch' <- updateBody fbranch
    letBind pat $ If cond tbranch' fbranch' ifdec
  where
    -- True when 'name' is not free in any value pattern element other
    -- than the one called 'here'.
    onlyUsedIn name here =
      not $
        any ((name `nameIn`) . freeIn) $
          filter ((/= here) . patElemName) $
            patternValueElements pat

    -- A dimension is known if it is a constant or a variable that is not
    -- bound in the context part of the pattern.
    knownSize Constant {} = True
    knownSize (Var v) = not $ inContext v
    inContext = (`elem` patternContextNames pat)

    -- Fold step: a pattern element is concretisable when all of the
    -- following hold.  If so, it is prepended to the accumulator
    -- together with the required block size (in bytes), its current
    -- memory, and that memory's space.
    --
    --   * It is an array whose memory is itself bound in the pattern.
    --   * Both branches return something at that memory's position.
    --   * No other value in the pattern refers to that memory.
    --   * Every dimension is known outside the pattern context.
    --   * Its index function uses no pattern names.
    --   * The two branches return different memory.
    hasConcretisableMemory fixable pat_elem
      | (_, MemArray pt shape _ (ArrayIn mem ixfun)) <- patElemDec pat_elem,
        Just (j, Mem space) <-
          fmap patElemType
            <$> find
              ((mem ==) . patElemName . snd)
              (zip [(0 :: Int) ..] $ patternElements pat),
        Just tse <- maybeNth j $ bodyResult tbranch,
        Just fse <- maybeNth j $ bodyResult fbranch,
        mem `onlyUsedIn` patElemName pat_elem,
        all knownSize (shapeDims shape),
        not $ freeIn ixfun `namesIntersect` namesFromList (patternNames pat),
        fse /= tse =
        let mem_size =
              untyped $ product $ primByteSize pt : map sExt64 (IxFun.base ixfun)
         in (pat_elem, mem_size, mem, space) : fixable
      | otherwise =
        fixable
unExistentialiseMemory _ _ _ _ = Skip

-- | Short-circuit chains of copies.
--
-- * @pat = copy v1@ where @v1 = copy v2@ becomes @pat = copy v2@.
--   This only fires when @v1@'s memory block and the destination's
--   memory block are in the same space and share an index function.
--   The certificates attached to the definition of @v1@ are kept.
--
-- * @pat = copy v0@ where @v0 = rearrange perm v1@ and @v1 = copy v2@
--   becomes @pat = copy (rearrange perm v2)@.  The certificates of
--   both intermediate definitions are kept on the new rearrange.
copyCopyToCopy ::
  ( BinderOps lore,
    LetDec lore ~ (VarWisdom, MemBound u)
  ) =>
  TopDownRuleBasicOp lore
copyCopyToCopy vtable pat@(Pattern [] [pat_elem]) _ (Copy v1)
  | Just (BasicOp (Copy v2), v1_cs) <- ST.lookupExp v1 vtable,
    Just (_, MemArray _ _ _ (ArrayIn srcmem src_ixfun)) <-
      ST.lookup v1 vtable >>= ST.entryLetBoundDec,
    Just (Mem src_space) <- ST.lookupType srcmem vtable,
    (_, MemArray _ _ _ (ArrayIn destmem dest_ixfun)) <- patElemDec pat_elem,
    Just (Mem dest_space) <- ST.lookupType destmem vtable,
    src_space == dest_space,
    dest_ixfun == src_ixfun =
    -- Copy straight from the original source, skipping the intermediate.
    Simplify . certifying v1_cs . letBind pat . BasicOp $ Copy v2
copyCopyToCopy vtable pat _ (Copy v0)
  | Just (BasicOp (Rearrange perm v1), v0_cs) <- ST.lookupExp v0 vtable,
    Just (BasicOp (Copy v2), v1_cs) <- ST.lookupExp v1 vtable =
    Simplify $ do
      let both_cs = v0_cs <> v1_cs
      -- Apply the same permutation directly to the pre-copy array.
      rearranged <-
        certifying both_cs . letExp "rearrange_v0" . BasicOp $
          Rearrange perm v2
      letBind pat . BasicOp $ Copy rearranged
copyCopyToCopy _ _ _ _ = Skip

-- | Drop a copy whose destination already is its source.
--
-- The rule fires when the single pattern element of @copy v@ is placed
-- in the same memory block with the same index function as @v@ itself.
-- In that case the copy is replaced by a plain reference to @v@.
removeIdentityCopy ::
  ( BinderOps lore,
    LetDec lore ~ (VarWisdom, MemBound u)
  ) =>
  TopDownRuleBasicOp lore
removeIdentityCopy vtable pat@(Pattern [] [pe]) _ (Copy v)
  | (_, MemArray _ _ _ (ArrayIn dest_mem dest_ixfun)) <- patElemDec pe,
    Just (_, MemArray _ _ _ (ArrayIn src_mem src_ixfun)) <-
      ST.lookup v vtable >>= ST.entryLetBoundDec,
    (dest_mem, dest_ixfun) == (src_mem, src_ixfun) =
    Simplify . letBind pat . BasicOp . SubExp $ Var v
removeIdentityCopy _ _ _ _ = Skip

-- | Remove the certificates from an operation that is statically known
-- to be safe and that binds a single memory block, such as an allocation.
--
-- The statement is rebound with its attributes but without its
-- certificates.  Statements without certificates can be hoisted more
-- easily, so this helps move work out of loops or branches it would
-- otherwise be stuck in.
decertifySafeAlloc :: SimplifyMemory lore => TopDownRuleOp (Wise lore)
decertifySafeAlloc _ pat (StmAux cs attrs _) op
  | cs /= mempty,
    [Mem _] <- patternTypes pat,
    safeOp op =
    Simplify . attributing attrs . letBind pat $ Op op
decertifySafeAlloc _ _ _ _ = Skip