-- | Do various kernel optimisations - mostly related to coalescing.
module Futhark.Optimise.ArrayLayout.Transform
  ( Transform,
    transformStms,
  )
where

import Control.Monad
import Control.Monad.State.Strict
import Data.Map.Strict qualified as M
import Futhark.Analysis.AccessPattern (IndexExprName, SegOpName (..))
import Futhark.Analysis.PrimExp.Table (PrimExpAnalysis)
import Futhark.Builder
import Futhark.Construct
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.IR.MC
import Futhark.Optimise.ArrayLayout.Layout (Layout, LayoutTable, Permutation)

-- | Representations whose programs can be rewritten according to a
-- 'LayoutTable'.
class (Layout rep, PrimExpAnalysis rep) => Transform rep where
  -- | Traverse an 'Op' with the given 'SOACMapper'. The instances in this
  -- module apply the mapper to embedded SOACs and return every other
  -- operation unchanged.
  onOp :: (Monad m) => SOACMapper rep rep m -> Op rep -> m (Op rep)

  -- | Transform a statement whose expression is the given 'Op'. The
  -- instances in this module add the rewritten statement to the builder
  -- and return the layout table together with the expression map, now
  -- also mapping the statement's pattern names to that statement.
  transformOp ::
    LayoutTable ->
    ExpMap rep ->
    Stm rep ->
    Op rep ->
    TransformM rep (LayoutTable, ExpMap rep)

-- | The monad in which layout transformations run: a 'Builder' for the
-- representation, so rewritten statements are emitted with 'addStm'.
type TransformM rep = Builder rep

-- | A map from a variable name to the statement that binds it. Every name
-- in a statement's pattern is mapped to that statement.
type ExpMap rep = M.Map VName (Stm rep)

-- | GPU programs: only thread-level 'SegOp's are rewritten specially; every
-- other operation is traversed generically by 'transformRestOp'.
instance Transform GPU where
  -- Apply the SOAC mapper to SOACs wrapped in 'OtherOp'; every other host
  -- operation is returned as-is.
  onOp soac_mapper (Futhark.IR.GPU.OtherOp soac) =
    Futhark.IR.GPU.OtherOp <$> mapSOACM soac_mapper soac
  onOp _ op = pure op

  transformOp perm_table expmap stm (SegOp op)
    -- TODO: handle non-segthread cases. This requires some care to
    -- avoid doing huge manifests at the block level.
    | SegThread {} <- segLevel op =
        transformSegOpGPU perm_table expmap stm op
  -- Everything else, including SegOps at other levels, is handled generically.
  transformOp perm_table expmap stm _ =
    transformRestOp perm_table expmap stm

-- | Multicore programs: 'ParOp's are rewritten by 'transformSegOpMC'; every
-- other operation is traversed generically by 'transformRestOp'.
instance Transform MC where
  -- Apply the SOAC mapper to SOACs wrapped in 'OtherOp'; every other
  -- multicore operation is returned as-is.
  onOp soac_mapper (Futhark.IR.MC.OtherOp soac) =
    Futhark.IR.MC.OtherOp <$> mapSOACM soac_mapper soac
  onOp _ op = pure op

  transformOp perm_table expmap stm (ParOp maybe_par_segop seqSegOp) =
    transformSegOpMC perm_table expmap stm maybe_par_segop seqSegOp
  transformOp perm_table expmap stm _ =
    transformRestOp perm_table expmap stm

-- | Rewrite a GPU 'SegOp' statement.
--
-- The body of the 'SegOp' is only traversed when the first name bound by
-- the statement's pattern is the name of some 'SegOp' key in the layout
-- table. Block-level ('SegBlock') bodies are handled by
-- 'transformSegGroupKernelBody'; all others by 'transformSegThreadKernelBody'.
-- Otherwise (including a pattern that binds no names) the statement is
-- added unchanged. In every case each name in the pattern is mapped to the
-- emitted statement in the returned 'ExpMap'; the layout table is returned
-- unchanged.
transformSegOpGPU ::
  LayoutTable ->
  ExpMap GPU ->
  Stm GPU ->
  SegOp SegLevel GPU ->
  TransformM GPU (LayoutTable, ExpMap GPU)
transformSegOpGPU perm_table expmap stm@(Let pat aux _) op =
  case patNames pat of
    -- Optimization: only traverse the body of the SegOp if it is
    -- represented in the layout table.
    seg_name : _
      | inLayoutTable seg_name -> do
          op' <- mapSegOpM (mapper seg_name) op
          emit $ Let pat aux $ Op $ SegOp op'
    _ -> emit stm
  where
    -- Is this name the name of some SegOp key in the layout table? A linear
    -- scan of the keys avoids rebuilding the whole table with 'M.mapKeys'
    -- just to answer a membership query.
    inLayoutTable name = any ((== name) . vnameFromSegOp) (M.keys perm_table)

    -- Mapper that rewrites the SegOp body according to its level.
    mapper seg_name =
      identitySegOpMapper
        { mapOnSegOpBody = case segLevel op of
            SegBlock {} -> transformSegGroupKernelBody perm_table expmap
            _ -> transformSegThreadKernelBody perm_table seg_name
        }

    -- Add the statement and record it as the binding of every pattern name.
    emit stm' = do
      addStm stm'
      pure (perm_table, M.fromList [(name, stm') | name <- patNames pat] <> expmap)

-- | Rewrite a multicore 'ParOp' statement.
--
-- The sequential 'SegOp' is always traversed with 'transformKernelBody'.
-- The optional parallel 'SegOp' is only traversed when the first name bound
-- by the statement's pattern is the name of some 'SegOp' key in the layout
-- table; otherwise it is kept as is. The rebuilt 'ParOp' statement is added
-- to the builder, and every name in its pattern is mapped to it in the
-- returned 'ExpMap'. The layout table is returned unchanged.
transformSegOpMC ::
  LayoutTable ->
  ExpMap MC ->
  Stm MC ->
  Maybe (SegOp () MC) ->
  SegOp () MC ->
  TransformM MC (LayoutTable, ExpMap MC)
transformSegOpMC perm_table expmap (Let pat aux _) maybe_par_segop seqSegOp =
  case maybe_par_segop of
    Nothing -> add Nothing
    Just par_segop
      -- Optimization: only traverse the body of the parallel SegOp if it
      -- is represented in the layout table.
      | inLayoutTable patternName -> add . Just =<< mapSegOpM mapper par_segop
      | otherwise -> add $ Just par_segop
  where
    -- Rebuild the ParOp around the given parallel part, add it, and record
    -- it as the binding of every pattern name.
    add maybe_par_segop' = do
      -- Map the sequential part of the ParOp
      seqSegOp' <- mapSegOpM mapper seqSegOp
      let stm' = Let pat aux $ Op $ ParOp maybe_par_segop' seqSegOp'
      addStm stm'
      pure (perm_table, M.fromList [(name, stm') | name <- patNames pat] <> expmap)

    -- Is this name the name of some SegOp key in the layout table? A linear
    -- scan of the keys avoids rebuilding the whole table with 'M.mapKeys'.
    inLayoutTable name = any ((== name) . vnameFromSegOp) (M.keys perm_table)

    mapper = identitySegOpMapper {mapOnSegOpBody = transformKernelBody perm_table expmap patternName}

    patternName = case patNames pat of
      name : _ -> name
      [] -> error "transformSegOpMC: ParOp statement binds no names"

-- | Transform a statement that 'transformOp' does not treat specially. Its
-- expression is traversed with the 'transform' mapper, which recursively
-- transforms nested bodies. The rebuilt statement is added to the builder,
-- and every name in its pattern is mapped to it in the returned 'ExpMap'.
-- The layout table is returned unchanged.
transformRestOp ::
  (Transform rep, BuilderOps rep) =>
  LayoutTable ->
  ExpMap rep ->
  Stm rep ->
  TransformM rep (LayoutTable, ExpMap rep)
transformRestOp perm_table expmap (Let pat aux e) = do
  stm' <- Let pat aux <$> mapExpM (transform perm_table expmap) e
  addStm stm'
  pure (perm_table, M.fromList [(name, stm') | name <- patNames pat] <> expmap)

-- | An expression 'Mapper' that is the identity except on nested bodies.
-- Each nested body is transformed with 'transformBody' inside the scope
-- that the mapper supplies for it.
transform :: (Transform rep, BuilderOps rep) => LayoutTable -> ExpMap rep -> Mapper rep rep (TransformM rep)
transform perm_table expmap =
  identityMapper {mapOnBody = onBody}
  where
    onBody scope = localScope scope . transformBody perm_table expmap

-- | Recursively transform the statements in a body with 'transformStms',
-- keeping its decoration and result unchanged.
transformBody :: (Transform rep, BuilderOps rep) => LayoutTable -> ExpMap rep -> Body rep -> TransformM rep (Body rep)
transformBody perm_table expmap (Body dec stms res) = do
  stms' <- transformStms perm_table expmap stms
  pure $ Body dec stms' res

-- | Recursively transform the statements in the body of a block-level
-- ('SegBlock') kernel with 'transformStms', keeping its decoration and
-- kernel results unchanged.
transformSegGroupKernelBody ::
  (Transform rep, BuilderOps rep) =>
  LayoutTable ->
  ExpMap rep ->
  KernelBody rep ->
  TransformM rep (KernelBody rep)
transformSegGroupKernelBody perm_table expmap (KernelBody dec stms res) = do
  stms' <- transformStms perm_table expmap stms
  pure $ KernelBody dec stms' res

-- | Transform the statements in the body of a SegThread kernel. Every array
-- indexing statement (including nested ones) is passed to
-- 'ensureTransformedAccess' via 'traverseKernelBodyArrayIndexes'. Unlike
-- 'transformKernelBody', the statements are not first run through
-- 'transformStms'.
transformSegThreadKernelBody ::
  (Transform rep, BuilderOps rep) =>
  LayoutTable ->
  VName ->
  KernelBody rep ->
  TransformM rep (KernelBody rep)
transformSegThreadKernelBody perm_table seg_name kbody =
  -- The replacement cache starts empty for every kernel body, so
  -- replacements are not shared between kernel bodies.
  evalStateT rewrite mempty
  where
    rewrite = traverseKernelBodyArrayIndexes seg_name (ensureTransformedAccess perm_table) kbody

-- | Transform a kernel body in two steps. First its statements are
-- recursively transformed with 'transformStms'. Then its array indexing
-- statements are rewritten exactly as 'transformSegThreadKernelBody' does.
-- In this module it is used for the multicore SegOps of a 'ParOp'.
transformKernelBody ::
  (Transform rep, BuilderOps rep) =>
  LayoutTable ->
  ExpMap rep ->
  VName ->
  KernelBody rep ->
  TransformM rep (KernelBody rep)
transformKernelBody perm_table expmap seg_name (KernelBody dec stms res) = do
  stms' <- transformStms perm_table expmap stms
  transformSegThreadKernelBody perm_table seg_name (KernelBody dec stms' res)

-- | Apply an 'ArrayIndexTransform' to every array indexing statement in a
-- kernel body. This includes indexing statements in nested bodies and SOAC
-- lambdas, which are reached through 'mapExpM' and 'onOp'.
--
-- An indexing statement is rebuilt to index the array and slice returned
-- by the transform, or kept as is when the transform returns 'Nothing'.
-- The body's decoration and kernel results are unchanged.
traverseKernelBodyArrayIndexes ::
  forall m rep.
  (Monad m, Transform rep) =>
  VName -> -- seg_name
  ArrayIndexTransform m ->
  KernelBody rep ->
  m (KernelBody rep)
traverseKernelBodyArrayIndexes seg_name coalesce (KernelBody dec kstms kres) = do
  kstms' <- onStms kstms
  pure $ KernelBody dec kstms' kres
  where
    -- Shared by the kernel body and every nested body.
    onStms :: Stms rep -> m (Stms rep)
    onStms = fmap stmsFromList . mapM onStm . stmsToList

    onLambda :: Lambda rep -> m (Lambda rep)
    onLambda lam = do
      body' <- onBody (lambdaBody lam)
      pure $ lam {lambdaBody = body'}

    onBody :: Body rep -> m (Body rep)
    onBody (Body bdec stms bres) = do
      stms' <- onStms stms
      pure $ Body bdec stms' bres

    onStm :: Stm rep -> m (Stm rep)
    onStm (Let pat aux (BasicOp (Index arr slice))) = do
      new_access <- coalesce seg_name idx_name arr slice
      pure . Let pat aux . BasicOp $ case new_access of
        Nothing -> Index arr slice
        Just (arr', slice') -> Index arr' slice'
      where
        -- Name bound by the indexing statement (first pattern element).
        idx_name = patElemName . head $ patElems pat
    onStm (Let pat aux e) =
      Let pat aux <$> mapExpM mapper e

    soac_mapper :: SOACMapper rep rep m
    soac_mapper = identitySOACMapper {mapOnSOACLambda = onLambda}

    -- The explicit type application pins the representation for the
    -- record update, whose other fields would otherwise leave it ambiguous.
    mapper :: Mapper rep rep m
    mapper =
      (identityMapper @rep)
        { mapOnBody = const onBody,
          mapOnOp = onOp soac_mapper
        }

-- | Used to keep track of which pairs of arrays and permutations we have
-- already created manifests for, in order to avoid duplicates. Each
-- (array, permutation) pair maps to the name bound to the manifest that
-- was already created for it.
type Replacements = M.Map (VName, Permutation) VName

-- | A monadic rewrite of a single array access.
--
-- It takes the name of the enclosing segmented operation, the name of the
-- index expression, the indexed array and its slice. It yields the array and
-- slice the access should use instead.
-- NOTE(review): 'ensureTransformedAccess' always returns 'Just'. The meaning
-- of 'Nothing' is decided by callers outside this view (presumably "keep the
-- original expression") — confirm.
type ArrayIndexTransform m =
  VName -> -- seg_name (name of the SegThread expression's pattern)
  VName -> -- idx_name (name of the Index expression's pattern)
  VName -> -- arr (name of the array)
  Slice SubExp -> -- slice
  m (Maybe (VName, Slice SubExp))

-- | Redirect an array access to a copy of the array stored in the layout that
-- the layout table prescribes for it.
--
-- * If 'lookupPermutation' finds no permutation for the access, the array and
--   slice are returned unchanged.
-- * Otherwise the access is pointed at a 'Manifest' of the array with that
--   permutation, bound under the name @<array>_coalesced@.
-- * Manifests are cached in the 'Replacements' state, so each
--   (array, permutation) pair is materialised at most once.
--
-- The slice is never modified, and the result is always 'Just'.
ensureTransformedAccess ::
  (MonadBuilder m) =>
  LayoutTable ->
  ArrayIndexTransform (StateT Replacements m)
ensureTransformedAccess perm_table seg_name idx_name arr slice =
  maybe (pure $ Just (arr, slice)) redirect $
    lookupPermutation perm_table seg_name idx_name arr
  where
    -- The access prefers the layout given by 'perm': point it at a manifest
    -- of 'arr' with that permutation.
    redirect perm = do
      target <- manifestFor perm
      pure $ Just (target, slice)

    -- Reuse the manifest recorded for (arr, perm) when one exists. Otherwise
    -- emit a new one and remember its name for later accesses.
    manifestFor perm = do
      known <- gets (M.lookup (arr, perm))
      case known of
        Just existing -> pure existing
        Nothing -> do
          fresh <-
            lift $
              letExp (baseString arr ++ "_coalesced") $
                BasicOp $
                  Manifest perm arr
          modify (M.insert (arr, perm) fresh)
          pure fresh

-- | Find the permutation, if any, that the layout table prescribes for array
-- @arr_name@ when it is indexed by index expression @idx_name@ inside
-- segmented operation @seg_name@.
--
-- The outer table is re-keyed by the plain 'VName' of each 'SegOpName'. The
-- per-operation array map is re-keyed by the first component of each
-- array-name triple. If several keys collapse to the same 'VName',
-- 'M.mapKeys' keeps the entry stored under the greatest original key.
-- Returns 'Nothing' as soon as any of the three lookups fails.
lookupPermutation :: LayoutTable -> VName -> IndexExprName -> VName -> Maybe Permutation
lookupPermutation perm_table seg_name idx_name arr_name = do
  arrays <- M.lookup seg_name $ M.mapKeys vnameFromSegOp perm_table
  -- Only the name component of an array key identifies the array here.
  accesses <- M.lookup arr_name $ M.mapKeys (\(name, _, _) -> name) arrays
  M.lookup idx_name accesses

-- | Transform a single statement and add the result to the builder.
--
-- 'Op' statements are handed to the representation-specific 'transformOp'.
-- Any other expression is handled here:
--
-- * it is rewritten with the 'transform' mapper and the resulting statement
--   is emitted;
-- * every name bound by its pattern is mapped to the new statement in the
--   returned 'ExpMap', with new bindings taking precedence over existing
--   entries;
-- * the layout table is returned unchanged.
transformStm ::
  (Transform rep, BuilderOps rep) =>
  (LayoutTable, ExpMap rep) ->
  Stm rep ->
  TransformM rep (LayoutTable, ExpMap rep)
transformStm (perm_table, expmap) stm@(Let pat aux e) =
  case e of
    Op op -> transformOp perm_table expmap stm op
    _ -> do
      stm' <- Let pat aux <$> mapExpM (transform perm_table expmap) e
      addStm stm'
      -- Left-biased union: the freshly bound names shadow older entries.
      let bound = M.fromList $ zip (patNames pat) (repeat stm')
      pure (perm_table, bound <> expmap)

-- | Transform a sequence of statements in order and return the statements
-- that were emitted while doing so.
--
-- The layout table and expression map produced by each statement are fed
-- into the next one. The final pair is discarded.
transformStms ::
  (Transform rep, BuilderOps rep) =>
  LayoutTable ->
  ExpMap rep ->
  Stms rep ->
  TransformM rep (Stms rep)
transformStms perm_table expmap stms = collectStms_ walk
  where
    walk = foldM_ transformStm (perm_table, expmap) stms