{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Do various kernel optimisations - mostly related to coalescing.
module Futhark.Pass.KernelBabysitting (babysitKernels) where

import Control.Arrow (first)
import Control.Monad.State.Strict
import Data.Foldable
import Data.List (elemIndex, isPrefixOf, sort)
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.IR
import Futhark.IR.Kernels hiding
  ( BasicOp,
    Body,
    Exp,
    FParam,
    FunDef,
    LParam,
    Lambda,
    PatElem,
    Pattern,
    Prog,
    RetType,
    Stm,
  )
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Tools
import Futhark.Util

-- | The pass definition.  The per-statement-sequence worker 'onStms'
-- is handed to 'intraproceduralTransformation'.
babysitKernels :: Pass Kernels Kernels
babysitKernels =
  Pass
    "babysit kernels"
    "Transpose kernel input arrays for better performance."
    $ intraproceduralTransformation onStms
  where
    -- Transform one sequence of statements under the given scope,
    -- starting from an empty 'ExpMap'.  The binder is run as a pure
    -- state computation over the name source.  'runBinderT' yields a
    -- pair; only its first component (the statements returned by
    -- 'transformStms') is kept.
    onStms scope stms = do
      let m = localScope scope $ transformStms mempty stms
      fmap fst $ modifyNameSource $ runState (runBinderT m M.empty)

-- | The monad in which the babysitting transformation runs: a
-- 'Binder' over the 'Kernels' representation.
type BabysitM = Binder Kernels

-- | Transform a sequence of statements.  The 'ExpMap' is threaded
-- through the statements in order by 'transformStm'; the final map is
-- discarded and the statements emitted along the way are collected
-- and returned.
transformStms :: ExpMap -> Stms Kernels -> BabysitM (Stms Kernels)
transformStms expmap stms = collectStms_ $ foldM_ transformStm expmap stms

-- | Transform the statements of a body with 'transformStms'.  The
-- body's result is carried over unchanged.
transformBody :: ExpMap -> Body Kernels -> BabysitM (Body Kernels)
transformBody expmap (Body () stms res) = do
  stms' <- transformStms expmap stms
  return $ Body () stms' res

-- | Map from variable names to the statement that defines them.  We
-- use this to hackily determine whether something is transposed or
-- otherwise funky in memory (and we'd prefer it not to be).  If we
-- cannot find it in the map, we just assume it's all good.  HACK and
-- FIXME, I suppose.  We really should do this at the memory level.
type ExpMap = M.Map VName (Stm Kernels)

-- | Inspect the statement that defines @name@ (if it is in the map)
-- and report a permutation for it:
--
-- * 'Opaque' of a variable and 'Reshape': recurse on the source array.
--
-- * 'Rearrange': the inverse of the rearrangement's permutation.
--
-- * 'Manifest': the permutation stored in the statement, as is.
--
-- * 'SegMap': find the pattern element called @name@, paired with the
--   corresponding entry of the 'SegMap' type list, and delegate to
--   @nonlinear@ below.
--
-- Returns 'Nothing' when @name@ is not in the map or is defined by any
-- other kind of statement.  Every success in this function has the
-- form @Just (Just perm)@.
nonlinearInMemory :: VName -> ExpMap -> Maybe (Maybe [Int])
nonlinearInMemory name m =
  case M.lookup name m of
    Just (Let _ _ (BasicOp (Opaque (Var arr)))) -> nonlinearInMemory arr m
    Just (Let _ _ (BasicOp (Rearrange perm _))) -> Just $ Just $ rearrangeInverse perm
    Just (Let _ _ (BasicOp (Reshape _ arr))) -> nonlinearInMemory arr m
    Just (Let _ _ (BasicOp (Manifest perm _))) -> Just $ Just perm
    Just (Let pat _ (Op (SegOp (SegMap _ _ ts _)))) ->
      nonlinear
        =<< find
          ((== name) . patElemName . fst)
          (zip (patternElements pat) ts)
    _ -> Nothing
  where
    -- @t@ is the type listed by the SegMap for this result and @pe@ is
    -- the pattern element it is bound to.  Only results whose listed
    -- type has nonzero rank produce a permutation: the inverse of
    -- "the outer_r trailing indices first, then the inner_r leading
    -- ones", where outer_r is the rank difference between the pattern
    -- element's type and @t@.
    nonlinear (pe, t)
      | inner_r <- arrayRank t,
        inner_r > 0 = do
        let outer_r = arrayRank (patElemType pe) - inner_r
        return $ Just $ rearrangeInverse $ [inner_r .. inner_r + outer_r - 1] ++ [0 .. inner_r - 1]
      | otherwise = Nothing

-- | Transform a single statement, emit the result with 'addStm', and
-- return the 'ExpMap' extended with an entry for every name bound by
-- the statement's pattern (each mapping to the emitted statement).
-- The new entries are on the left of '<>', so they take precedence
-- over existing entries for the same name.
--
-- A 'SegOp' whose level is 'SegThread' has its kernel body rewritten
-- by 'transformKernelBody'.  Every other statement (including SegOps
-- at other levels, which fall through the guard) has its nested bodies
-- transformed via 'transform'.
transformStm :: ExpMap -> Stm Kernels -> BabysitM ExpMap
transformStm expmap (Let pat aux (Op (SegOp op)))
  -- FIXME: We only make coalescing optimisations for SegThread
  -- SegOps, because that's what the analysis assumes.  For SegGroup
  -- we should probably look at the component SegThreads, but it
  -- apparently hasn't come up in practice yet.
  | SegThread {} <- segLevel op = do
    let mapper =
          identitySegOpMapper
            { mapOnSegOpBody =
                transformKernelBody expmap (segLevel op) (segSpace op)
            }
    op' <- mapSegOpM mapper op
    let stm' = Let pat aux $ Op $ SegOp op'
    addStm stm'
    return $ M.fromList [(name, stm') | name <- patternNames pat] <> expmap
transformStm expmap (Let pat aux e) = do
  e' <- mapExpM (transform expmap) e
  let bnd' = Let pat aux e'
  addStm bnd'
  return $ M.fromList [(name, bnd') | name <- patternNames pat] <> expmap

-- | An expression mapper that is the identity except on bodies: each
-- nested body is transformed with 'transformBody' under the scope
-- supplied to 'mapOnBody', using the given 'ExpMap'.
transform :: ExpMap -> Mapper Kernels Kernels BabysitM
transform expmap =
  identityMapper {mapOnBody = \scope -> localScope scope . transformBody expmap}

-- | Rewrite the array index operations in the body of a SegOp by
-- running 'traverseKernelBodyArrayIndexes' with
-- 'ensureCoalescedAccess' as the index transformation.
--
-- A @num_threads@ value (number of groups times group size of the
-- given level, as a 64-bit multiplication) is bound with 'letSubExp'
-- in the surrounding binder before the traversal starts.  The
-- 'Replacements' state starts out empty and is dropped afterwards.
transformKernelBody ::
  ExpMap ->
  SegLevel ->
  SegSpace ->
  KernelBody Kernels ->
  BabysitM (KernelBody Kernels)
transformKernelBody expmap lvl space kbody = do
  -- Go spelunking for accesses to arrays that are defined outside the
  -- kernel body and where the indices are kernel thread indices.
  scope <- askScope
  let -- The per-dimension index variables of the segment space.
      thread_gids = map fst $ unSegSpace space
      -- Those, plus the flat index of the space.
      thread_local = namesFromList $ segFlat space : thread_gids
      -- Free variables of the body, minus the names bound by the space.
      free_ker_vars = freeIn kbody `namesSubtract` getKerVariantIds space
  num_threads <-
    letSubExp "num_threads" $
      BasicOp $
        BinOp
          (Mul Int64 OverflowUndef)
          (unCount $ segNumGroups lvl)
          (unCount $ segGroupSize lvl)
  evalStateT
    ( traverseKernelBodyArrayIndexes
        free_ker_vars
        thread_local
        (scope <> scopeOfSegSpace space)
        (ensureCoalescedAccess expmap (unSegSpace space) num_threads)
        kbody
    )
    mempty
  where
    -- All names that the segment space brings into scope.
    getKerVariantIds = namesFromList . M.keys . scopeOfSegSpace

-- | A function that is offered each array index operation found by
-- 'traverseKernelBodyArrayIndexes'.  Returning 'Nothing' keeps the
-- original index; returning @Just (arr', slice')@ replaces it with an
-- index of @arr'@ by @slice'@.
type ArrayIndexTransform m =
  Names -> -- the @free_ker_vars@ set handed to the traversal
  (VName -> Bool) -> -- thread local?
  (VName -> SubExp -> Bool) -> -- variant to a certain gid (given as first param)?
  (SubExp -> Maybe SubExp) -> -- split substitution?
  Scope Kernels -> -- type environment
  VName -> -- the array being indexed
  Slice SubExp -> -- the slice it is indexed with
  m (Maybe (VName, Slice SubExp))

traverseKernelBodyArrayIndexes ::
  (Applicative f, Monad f) =>
  Names ->
  Names ->
  Scope Kernels ->
  ArrayIndexTransform f ->
  KernelBody Kernels ->
  f (KernelBody Kernels)
traverseKernelBodyArrayIndexes :: Names
-> Names
-> Scope Kernels
-> ArrayIndexTransform f
-> KernelBody Kernels
-> f (KernelBody Kernels)
traverseKernelBodyArrayIndexes Names
free_ker_vars Names
thread_variant Scope Kernels
outer_scope ArrayIndexTransform f
f (KernelBody () Stms Kernels
kstms [KernelResult]
kres) =
  BodyDec Kernels
-> Stms Kernels -> [KernelResult] -> KernelBody Kernels
forall lore.
BodyDec lore -> Stms lore -> [KernelResult] -> KernelBody lore
KernelBody () (Stms Kernels -> [KernelResult] -> KernelBody Kernels)
-> ([Stm Kernels] -> Stms Kernels)
-> [Stm Kernels]
-> [KernelResult]
-> KernelBody Kernels
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Stm Kernels] -> Stms Kernels
forall lore. [Stm lore] -> Stms lore
stmsFromList
    ([Stm Kernels] -> [KernelResult] -> KernelBody Kernels)
-> f [Stm Kernels] -> f ([KernelResult] -> KernelBody Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Stm Kernels -> f (Stm Kernels))
-> [Stm Kernels] -> f [Stm Kernels]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM
      ( (VarianceTable, Map VName SubExp, Scope Kernels)
-> Stm Kernels -> f (Stm Kernels)
onStm
          ( VarianceTable -> Stms Kernels -> VarianceTable
varianceInStms VarianceTable
forall a. Monoid a => a
mempty Stms Kernels
kstms,
            Stms Kernels -> Map VName SubExp
mkSizeSubsts Stms Kernels
kstms,
            Scope Kernels
outer_scope
          )
      )
      (Stms Kernels -> [Stm Kernels]
forall lore. Stms lore -> [Stm lore]
stmsToList Stms Kernels
kstms)
    f ([KernelResult] -> KernelBody Kernels)
-> f [KernelResult] -> f (KernelBody Kernels)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [KernelResult] -> f [KernelResult]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [KernelResult]
kres
  where
    onLambda :: (VarianceTable, Map VName SubExp, Scope Kernels)
-> LambdaT Kernels -> f (LambdaT Kernels)
onLambda (VarianceTable
variance, Map VName SubExp
szsubst, Scope Kernels
scope) LambdaT Kernels
lam =
      (\Body Kernels
body' -> LambdaT Kernels
lam {lambdaBody :: Body Kernels
lambdaBody = Body Kernels
body'})
        (Body Kernels -> LambdaT Kernels)
-> f (Body Kernels) -> f (LambdaT Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (VarianceTable, Map VName SubExp, Scope Kernels)
-> Body Kernels -> f (Body Kernels)
onBody (VarianceTable
variance, Map VName SubExp
szsubst, Scope Kernels
scope') (LambdaT Kernels -> Body Kernels
forall lore. LambdaT lore -> BodyT lore
lambdaBody LambdaT Kernels
lam)
      where
        scope' :: Scope Kernels
scope' = Scope Kernels
scope Scope Kernels -> Scope Kernels -> Scope Kernels
forall a. Semigroup a => a -> a -> a
<> [Param Type] -> Scope Kernels
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams (LambdaT Kernels -> [LParam Kernels]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT Kernels
lam)

    onBody :: (VarianceTable, Map VName SubExp, Scope Kernels)
-> Body Kernels -> f (Body Kernels)
onBody (VarianceTable
variance, Map VName SubExp
szsubst, Scope Kernels
scope) (Body BodyDec Kernels
bdec Stms Kernels
stms Result
bres) = do
      Stms Kernels
stms' <- [Stm Kernels] -> Stms Kernels
forall lore. [Stm lore] -> Stms lore
stmsFromList ([Stm Kernels] -> Stms Kernels)
-> f [Stm Kernels] -> f (Stms Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Stm Kernels -> f (Stm Kernels))
-> [Stm Kernels] -> f [Stm Kernels]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((VarianceTable, Map VName SubExp, Scope Kernels)
-> Stm Kernels -> f (Stm Kernels)
onStm (VarianceTable
variance', Map VName SubExp
szsubst', Scope Kernels
scope')) (Stms Kernels -> [Stm Kernels]
forall lore. Stms lore -> [Stm lore]
stmsToList Stms Kernels
stms)
      Body Kernels -> f (Body Kernels)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Body Kernels -> f (Body Kernels))
-> Body Kernels -> f (Body Kernels)
forall a b. (a -> b) -> a -> b
$ BodyDec Kernels -> Stms Kernels -> Result -> Body Kernels
forall lore. BodyDec lore -> Stms lore -> Result -> BodyT lore
Body BodyDec Kernels
bdec Stms Kernels
stms' Result
bres
      where
        variance' :: VarianceTable
variance' = VarianceTable -> Stms Kernels -> VarianceTable
varianceInStms VarianceTable
variance Stms Kernels
stms
        szsubst' :: Map VName SubExp
szsubst' = Stms Kernels -> Map VName SubExp
mkSizeSubsts Stms Kernels
stms Map VName SubExp -> Map VName SubExp -> Map VName SubExp
forall a. Semigroup a => a -> a -> a
<> Map VName SubExp
szsubst
        scope' :: Scope Kernels
scope' = Scope Kernels
scope Scope Kernels -> Scope Kernels -> Scope Kernels
forall a. Semigroup a => a -> a -> a
<> Stms Kernels -> Scope Kernels
forall lore a. Scoped lore a => a -> Scope lore
scopeOf Stms Kernels
stms

    onStm :: (VarianceTable, Map VName SubExp, Scope Kernels)
-> Stm Kernels -> f (Stm Kernels)
onStm (VarianceTable
variance, Map VName SubExp
szsubst, Scope Kernels
_) (Let Pattern Kernels
pat StmAux (ExpDec Kernels)
dec (BasicOp (Index VName
arr Slice SubExp
is))) =
      Pattern Kernels
-> StmAux (ExpDec Kernels) -> ExpT Kernels -> Stm Kernels
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let Pattern Kernels
pat StmAux (ExpDec Kernels)
dec (ExpT Kernels -> Stm Kernels)
-> (Maybe (VName, Slice SubExp) -> ExpT Kernels)
-> Maybe (VName, Slice SubExp)
-> Stm Kernels
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Maybe (VName, Slice SubExp) -> ExpT Kernels
oldOrNew (Maybe (VName, Slice SubExp) -> Stm Kernels)
-> f (Maybe (VName, Slice SubExp)) -> f (Stm Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ArrayIndexTransform f
f Names
free_ker_vars VName -> Bool
isThreadLocal VName -> SubExp -> Bool
isGidVariant SubExp -> Maybe SubExp
sizeSubst Scope Kernels
outer_scope VName
arr Slice SubExp
is
      where
        oldOrNew :: Maybe (VName, Slice SubExp) -> ExpT Kernels
oldOrNew Maybe (VName, Slice SubExp)
Nothing =
          BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr Slice SubExp
is
        oldOrNew (Just (VName
arr', Slice SubExp
is')) =
          BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr' Slice SubExp
is'

        isGidVariant :: VName -> SubExp -> Bool
isGidVariant VName
gid (Var VName
v) =
          VName
gid VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v Bool -> Bool -> Bool
|| VName -> Names -> Bool
nameIn VName
gid (Names -> VName -> VarianceTable -> Names
forall k a. Ord k => a -> k -> Map k a -> a
M.findWithDefault (VName -> Names
oneName VName
v) VName
v VarianceTable
variance)
        isGidVariant VName
_ SubExp
_ = Bool
False

        isThreadLocal :: VName -> Bool
isThreadLocal VName
v =
          Names
thread_variant
            Names -> Names -> Bool
`namesIntersect` Names -> VName -> VarianceTable -> Names
forall k a. Ord k => a -> k -> Map k a -> a
M.findWithDefault (VName -> Names
oneName VName
v) VName
v VarianceTable
variance

        sizeSubst :: SubExp -> Maybe SubExp
sizeSubst (Constant PrimValue
v) = SubExp -> Maybe SubExp
forall a. a -> Maybe a
Just (SubExp -> Maybe SubExp) -> SubExp -> Maybe SubExp
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant PrimValue
v
        sizeSubst (Var VName
v)
          | VName
v VName -> Scope Kernels -> Bool
forall k a. Ord k => k -> Map k a -> Bool
`M.member` Scope Kernels
outer_scope = SubExp -> Maybe SubExp
forall a. a -> Maybe a
Just (SubExp -> Maybe SubExp) -> SubExp -> Maybe SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
          | Just SubExp
v' <- VName -> Map VName SubExp -> Maybe SubExp
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
v Map VName SubExp
szsubst = SubExp -> Maybe SubExp
sizeSubst SubExp
v'
          | Bool
otherwise = Maybe SubExp
forall a. Maybe a
Nothing
    onStm (VarianceTable
variance, Map VName SubExp
szsubst, Scope Kernels
scope) (Let Pattern Kernels
pat StmAux (ExpDec Kernels)
dec ExpT Kernels
e) =
      Pattern Kernels
-> StmAux (ExpDec Kernels) -> ExpT Kernels -> Stm Kernels
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let Pattern Kernels
pat StmAux (ExpDec Kernels)
dec (ExpT Kernels -> Stm Kernels)
-> f (ExpT Kernels) -> f (Stm Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Mapper Kernels Kernels f -> ExpT Kernels -> f (ExpT Kernels)
forall (m :: * -> *) flore tlore.
(Applicative m, Monad m) =>
Mapper flore tlore m -> Exp flore -> m (Exp tlore)
mapExpM ((VarianceTable, Map VName SubExp, Scope Kernels)
-> Mapper Kernels Kernels f
mapper (VarianceTable
variance, Map VName SubExp
szsubst, Scope Kernels
scope)) ExpT Kernels
e

    onOp :: (VarianceTable, Map VName SubExp, Scope Kernels)
-> HostOp Kernels (SOAC Kernels)
-> f (HostOp Kernels (SOAC Kernels))
onOp (VarianceTable, Map VName SubExp, Scope Kernels)
ctx (OtherOp SOAC Kernels
soac) =
      SOAC Kernels -> HostOp Kernels (SOAC Kernels)
forall lore op. op -> HostOp lore op
OtherOp (SOAC Kernels -> HostOp Kernels (SOAC Kernels))
-> f (SOAC Kernels) -> f (HostOp Kernels (SOAC Kernels))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper Kernels Kernels f -> SOAC Kernels -> f (SOAC Kernels)
forall (m :: * -> *) flore tlore.
(Applicative m, Monad m) =>
SOACMapper flore tlore m -> SOAC flore -> m (SOAC tlore)
mapSOACM SOACMapper Any Any f
forall (m :: * -> *) lore. Monad m => SOACMapper lore lore m
identitySOACMapper {mapOnSOACLambda :: LambdaT Kernels -> f (LambdaT Kernels)
mapOnSOACLambda = (VarianceTable, Map VName SubExp, Scope Kernels)
-> LambdaT Kernels -> f (LambdaT Kernels)
onLambda (VarianceTable, Map VName SubExp, Scope Kernels)
ctx} SOAC Kernels
soac
    onOp (VarianceTable, Map VName SubExp, Scope Kernels)
_ HostOp Kernels (SOAC Kernels)
op = HostOp Kernels (SOAC Kernels) -> f (HostOp Kernels (SOAC Kernels))
forall (m :: * -> *) a. Monad m => a -> m a
return HostOp Kernels (SOAC Kernels)
op

    -- The expression mapper used when descending into a statement that
    -- is not handled directly: nested bodies go through 'onBody' and
    -- operations through 'onOp', both under the given context.  The
    -- scope argument that the mapper passes to 'mapOnBody' is ignored.
    mapper ctx =
      identityMapper
        { mapOnBody = \_scope body -> onBody ctx body,
          mapOnOp = onOp ctx
        }

    -- Collect size substitutions from a sequence of statements: every
    -- statement that binds exactly one name (and no context names) to a
    -- 'SplitSpace' size operation contributes a mapping from that name
    -- to the last operand of the 'SplitSpace'.  All other statements
    -- contribute nothing.
    mkSizeSubsts = foldMap mkStmSizeSubst
      where
        mkStmSizeSubst (Let (Pattern [] [pe]) _ (Op (SizeOp (SplitSpace _ _ _ elems_per_i)))) =
          M.singleton (patElemName pe) elems_per_i
        mkStmSizeSubst _ = mempty

-- | Array accesses that have already been redirected: maps an
-- (array, slice) pair to the name of the array that replaces the
-- original array for that access.
type Replacements = M.Map (VName, Slice SubExp) VName

-- | Decide whether the access @arr[slice]@ should be redirected to a
-- rearranged copy of @arr@.
--
-- The result is @Just (arr', slice)@ when the access should read from
-- @arr'@ instead of @arr@, and @Nothing@ when the access is left
-- alone.  Every redirection is recorded in the 'Replacements' state
-- under the key @(arr, slice)@, and that state is consulted first, so
-- an access that has already been handled reuses its replacement.
--
-- Only arrays that are not thread-local and whose type is found in
-- the outer scope are considered; everything else yields @Nothing@.
ensureCoalescedAccess ::
  MonadBinder m =>
  ExpMap ->
  [(VName, SubExp)] ->
  SubExp ->
  ArrayIndexTransform (StateT Replacements m)
ensureCoalescedAccess
  expmap
  thread_space
  num_threads
  free_ker_vars
  isThreadLocal
  isGidVariant
  sizeSubst
  outer_scope
  arr
  slice = do
    seen <- gets $ M.lookup (arr, slice)

    case (seen, isThreadLocal arr, typeOf <$> M.lookup arr outer_scope) of
      -- Already took care of this case elsewhere.
      (Just arr', _, _) ->
        pure $ Just (arr', slice)
      (Nothing, False, Just t)
        -- We are fully indexing the array with thread IDs, but the
        -- indices are in a permuted order.
        | Just is <- sliceIndices slice,
          length is == arrayRank t,
          Just is' <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is,
          Just perm <- is' `isPermutationOf` is ->
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)
        -- Check whether the access is already coalesced because of a
        -- previous rearrange being applied to the current array:
        -- 1. get the permutation of the source-array rearrange
        -- 2. apply it to the slice
        -- 3. check that the innermost index is actually the gid
        --    of the innermost kernel dimension.
        -- If so, the access is already coalesced, nothing to do!
        -- (Cosmin's Heuristic.)
        | Just (Let _ _ (BasicOp (Rearrange perm _))) <- M.lookup arr expmap,
          -- The two emptiness checks make the uses of 'last' below safe.
          not $ null perm,
          not $ null thread_gids,
          inner_gid <- last thread_gids,
          -- Guards the '!!' lookups below against running off the slice.
          length slice >= length perm,
          slice' <- map (slice !!) perm,
          DimFix inner_ind <- last slice',
          isGidVariant inner_gid inner_ind ->
          return Nothing
        -- We are not fully indexing an array, but the remaining slice
        -- is invariant to the innermost-kernel dimension. We assume
        -- the remaining slice will be sequentially streamed, hence
        -- tiling will be applied later and will solve coalescing.
        -- Hence nothing to do at this point. (Cosmin's Heuristic.)
        | (is, rem_slice) <- splitSlice slice,
          not $ null rem_slice,
          allDimAreSlice rem_slice,
          Nothing <- M.lookup arr expmap,
          not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
          is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
          not (null thread_gids || null is),
          not (last thread_gids `nameIn` (freeIn is <> freeIn rem_slice)) ->
          return Nothing
        -- We are not fully indexing the array, and the indices are not
        -- a proper prefix of the thread indices, and some indices are
        -- thread local, so we assume (HEURISTIC!)  that the remaining
        -- dimensions will be traversed sequentially.
        | (is, rem_slice) <- splitSlice slice,
          not $ null rem_slice,
          not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
          is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
          any isThreadLocal (namesToList $ freeIn is) -> do
          let perm = coalescingPermutation (length is) $ arrayRank t
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)

        -- We are taking a slice of the array with a unit stride.  We
        -- assume that the slice will be traversed sequentially.
        --
        -- We will really want to treat the sliced dimension like two
        -- dimensions so we can transpose them.  This may require
        -- padding.
        | (is, rem_slice) <- splitSlice slice,
          and $ zipWith (==) is $ map Var thread_gids,
          DimSlice offset len (Constant stride) : _ <- rem_slice,
          isThreadLocalSubExp offset,
          Just {} <- sizeSubst len,
          oneIsh stride -> do
          -- NOTE(review): the two branches build the chunk count at
          -- different integer widths ('pe32' vs 'pe64') -- confirm that
          -- num_threads really is a 32-bit value here.
          let num_chunks =
                if null is
                  then untyped $ pe32 num_threads
                  else
                    untyped $
                      product $
                        map pe64 $
                          drop (length is) thread_gdims
          replace =<< lift (rearrangeSlice (length is) (arraySize (length is) t) num_chunks arr)

        -- Everything is fine... assuming that the array is in row-major
        -- order!  Make sure that is the case.
        | Just {} <- nonlinearInMemory arr expmap ->
          case sliceIndices slice of
            Just is
              | Just _ <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is ->
                replace =<< lift (rowMajorArray arr)
              | otherwise ->
                return Nothing
            _ -> replace =<< lift (rowMajorArray arr)
      _ -> return Nothing
    where
      -- The thread identifiers and the matching dimension sizes of the
      -- kernel's thread space, in the order given by the caller.
      (thread_gids, thread_gdims) = unzip thread_space

      -- Remember that this access now goes through arr', and report
      -- the redirected access to the caller.
      replace arr' = do
        modify $ M.insert (arr, slice) arr'
        return $ Just (arr', slice)

      -- Constants are never thread-local; variables defer to the
      -- caller-supplied predicate.
      isThreadLocalSubExp (Var v) = isThreadLocal v
      isThreadLocalSubExp Constant {} = False

-- | Heuristic for avoiding rearranging too small arrays.
--
-- Starting from the element size in bytes (@bs@), the dimensions of
-- the slice are multiplied in one at a time.  The slice counts as
-- "too small" only if every dimension is an integer constant and the
-- running byte count stays below 4 after each multiplication.  An
-- empty slice is considered too small; any non-constant dimension
-- makes the answer 'False'.
--
-- Both 32-bit and 64-bit integer constants are recognised, and the
-- byte count is accumulated as an 'Integer' so that large constant
-- dimensions cannot wrap around and spuriously look small.
tooSmallSlice :: Int32 -> Slice SubExp -> Bool
tooSmallSlice bs = fst . foldl' comb (True, toInteger bs) . sliceDims
  where
    -- Once the flag has become False the first clause no longer
    -- applies, so the verdict sticks for the rest of the fold.
    comb (True, x) (Constant (IntValue v))
      | Just d <- constantDim v =
        let bytes = d * x in (bytes < 4, bytes)
    comb (_, x) _ = (False, x)

    constantDim (Int32Value d) = Just (toInteger d)
    constantDim (Int64Value d) = Just (toInteger d)
    constantDim _ = Nothing

-- | Split a slice into the indices of its leading run of 'DimFix'
-- components and everything that follows.  The remainder starts at the
-- first component that is not a 'DimFix' (and is empty if there is
-- none); later 'DimFix' components stay in the remainder.
splitSlice :: Slice SubExp -> ([SubExp], Slice SubExp)
splitSlice (DimFix i : rest) = first (i :) $ splitSlice rest
splitSlice rest = ([], rest)

-- | Is no component of the slice a fixed index ('DimFix')?  Holds
-- trivially for the empty slice.
allDimAreSlice :: Slice SubExp -> Bool
allDimAreSlice = all notFixed
  where
    notFixed DimFix {} = False
    notFixed _ = True

-- | Try to move thread indexes into their proper position.
--
-- Given the thread identifiers @tgids@ (outermost first) and the
-- indices @is@ of a full array access, produce a reordering of @is@,
-- or 'Nothing' when the access should be left alone.
coalescedIndexes :: Names -> (VName -> SubExp -> Bool) -> [SubExp] -> [SubExp] -> Maybe [SubExp]
coalescedIndexes free_ker_vars isGidVariant tgids is
  -- 1. Give up if any of the indices is a constant or a kernel free
  --    variable (because it would transpose a bigger array than
  --    needed -- big overhead).
  | any isCt is =
    Nothing
  | any (`nameIn` free_ker_vars) (subExpVars is) =
    Nothing
  -- 2. Give up if the indexes are a prefix of the thread indexes,
  --    because that means multiple threads will be accessing the same
  --    element.
  | is `isPrefixOf` tgids =
    Nothing
  -- 3. Keep the indices as they are if the innermost index is variant
  --    to the innermost-thread gid (because access is likely to be
  --    already coalesced).  The emptiness checks make 'last' safe.
  | not (null tgids),
    not (null is),
    Var innergid <- last tgids,
    isGidVariant innergid (last is) =
    Just is
  -- 4. Otherwise try fix coalescing.  Working on the reversed lists,
  --    the innermost thread index has position 0, the next one
  --    position 1, and so on.
  | otherwise =
    Just $ reverse $ foldl' move (reverse is) $ zip [0 ..] (reverse tgids)
  where
    num_is = length is

    move is_rev (i, tgid)
      -- If tgid is in is_rev anywhere but at position i, and
      -- position i exists, we move it to position i instead.
      | Just j <- elemIndex tgid is_rev,
        i /= j,
        i < num_is =
        swap i j is_rev
      | otherwise =
        is_rev

    -- Exchange the elements at positions i and j of the list.
    swap i j l
      | Just ix <- maybeNth i l,
        Just jx <- maybeNth j l =
        update i jx $ update j ix l
      | otherwise =
        error $ "coalescedIndexes swap: invalid indices" ++ show (i, j, l)

    -- Replace the element at the given position of the list.
    update 0 x (_ : ys) = x : ys
    update i x (y : ys) = y : update (i - 1) x ys
    update _ _ [] = error "coalescedIndexes: update"

    isCt :: SubExp -> Bool
    isCt (Constant _) = True
    isCt (Var _) = False

-- | The permutation of @[0 .. rank - 1]@ that moves the first @num_is@
-- positions to the end, keeping the relative order within both
-- groups.  For example, @coalescingPermutation 2 5@ is @[2,3,4,0,1]@.
coalescingPermutation :: Int -> Int -> [Int]
coalescingPermutation num_is rank =
  [num_is .. rank - 1] ++ [0 .. num_is - 1]

-- | Produce an array that holds @arr@ in the layout described by
-- @perm@, binding new arrays only when that appears necessary.
--
-- The first argument describes what is known about the current
-- representation of @arr@ (at the call sites in this module it is the
-- result of @nonlinearInMemory@).  NOTE(review): @Just (Just p)@ is
-- treated as "currently laid out by permutation @p@" and @Nothing@ as
-- "representation unknown"; the exact meaning of @Just Nothing@ is
-- defined by @nonlinearInMemory@ -- confirm there.
rearrangeInput ::
  MonadBinder m =>
  Maybe (Maybe [Int]) ->
  [Int] ->
  VName ->
  m VName
rearrangeInput manifest perm arr
  -- Already has desired representation.
  | Just (Just current_perm) <- manifest,
    current_perm == perm =
    return arr
  -- We don't know the current representation, but the indexing is
  -- linear, so let's hope the array is too.
  | Nothing <- manifest,
    identity_perm =
    return arr
  -- We just want a row-major array, no tricks.
  | Just Just {} <- manifest,
    identity_perm =
    rowMajorArray arr
  | otherwise = do
    -- We may first manifest the array to ensure that it is flat in
    -- memory.  This is sometimes unnecessary, in which case the copy
    -- will hopefully be removed by the simplifier.
    manifested <- if isJust manifest then rowMajorArray arr else return arr
    letExp (baseString arr ++ "_coalesced") $
      BasicOp $ Manifest perm manifested
  where
    -- A permutation that is already sorted leaves every dimension in
    -- place.
    identity_perm = sort perm == perm

-- | Bind a 'Manifest' of @arr@ under the identity permutation
-- @[0 .. rank - 1]@ and return the name of the new array, which is
-- derived from the base name of @arr@ with a @_rowmajor@ suffix.
rowMajorArray ::
  MonadBinder m =>
  VName ->
  m VName
rowMajorArray arr = do
  rank <- arrayRank <$> lookupType arr
  let identity_perm = [0 .. rank - 1]
  letExp (baseString arr ++ "_rowmajor") $
    BasicOp $ Manifest identity_perm arr

-- | Build a rearranged copy of @arr@ along dimension @d@ (whose size is
-- @w@), split into @num_chunks@ chunks.
--
-- The statements emitted, in order:
--
--   1. pad dimension @d@ with a 'Scratch' array so that its size becomes
--      the padded size computed by 'paddedScanReduceInput';
--   2. reshape dimension @d@ into two dimensions
--      @[num_chunks, per_chunk]@;
--   3. 'Manifest' with a permutation that moves the @num_chunks@
--      dimension to the innermost position;
--   4. reshape back so dimension @d@ has the padded size again;
--   5. slice dimension @d@ from offset 0 with length @w@.
--
-- Returns the name bound to the final slice.
rearrangeSlice ::
  MonadBinder m =>
  Int ->
  SubExp ->
  PrimExp VName ->
  VName ->
  m VName
rearrangeSlice d w num_chunks arr = do
  chunks <- toSubExp "num_chunks" num_chunks
  (padded_w, pad_w) <- paddedScanReduceInput w chunks
  chunk_size <-
    letSubExp "per_chunk" . BasicOp $
      BinOp (SQuot Int64 Unsafe) padded_w chunks
  t <- lookupType arr
  padded <- appendPadding padded_w pad_w t
  chunkAndTranspose chunks padded_w chunk_size padded t
  where
    -- All intermediate arrays are named after the input array.
    base = baseString arr

    -- Concatenate a scratch array of extent @pad_w@ (in dimension @d@)
    -- onto @arr@, yielding an array of extent @padded_w@ in that dimension.
    appendPadding padded_w pad_w t = do
      let pad_dims = shapeDims (setDim d (arrayShape t) pad_w)
      filler <-
        letExp (base <> "_padding") . BasicOp $
          Scratch (elemType t) pad_dims
      letExp (base <> "_padded") . BasicOp $
        Concat d arr [filler] padded_w

    -- Split dimension @d@ into (chunks, chunk_size), manifest with the
    -- chunk dimension moved last, then reinterpret the result with the
    -- padded shape and cut it back down to @w@ elements along @d@.
    chunkAndTranspose chunks padded_w chunk_size padded t = do
      let dims = arrayDims t
          outer = take d dims
          inner = drop (d + 1) dims
          chunked_shape = Shape (outer ++ [chunks, chunk_size] ++ inner)
          -- Leading @d@ dimensions stay put; among the rest, index 0
          -- (the chunk dimension) is rotated to the end.
          perm =
            [0 .. d - 1]
              ++ map (+ d) (1 : [2 .. shapeRank chunked_shape - 1 - d] ++ [0])
      reshaped <-
        letExp (base <> "_extradim") . BasicOp $
          Reshape (map DimNew (shapeDims chunked_shape)) padded
      transposed <-
        letExp (base <> "_extradim_tr") . BasicOp $
          Manifest perm reshaped
      flattened <-
        letExp (base <> "_inv_tr") . BasicOp $
          Reshape
            (map DimCoercion outer ++ map DimNew (padded_w : inner))
            transposed
      eSliceArray d flattened (eSubExp (constant (0 :: Int64))) (eSubExp w)
        >>= letExp (base <> "_inv_tr_init")

-- | Given a size @w@ and a @stride@, bind and return two values:
--
--   * @padded_size@: @w@ rounded to a multiple of @stride@ (as computed
--     by 'eRoundToMultipleOf' at 'Int64');
--   * @padding@: the difference @padded_size - w@.
--
-- The result is the pair @(padded_size, padding)@.
paddedScanReduceInput ::
  MonadBinder m =>
  SubExp ->
  SubExp ->
  m (SubExp, SubExp)
paddedScanReduceInput w stride = do
  rounded <- eRoundToMultipleOf Int64 (eSubExp w) (eSubExp stride)
  padded_size <- letSubExp "padded_size" rounded
  pad_amount <-
    letSubExp "padding" . BasicOp $
      BinOp (Sub Int64 OverflowUndef) padded_size w
  pure (padded_size, pad_amount)

--- Computing variance.

-- | Maps a name to the set of names recorded for it by 'varianceInStm':
-- the free variables of the statement that binds it, together with
-- whatever the table already held for those variables.
type VarianceTable = M.Map VName Names

-- | Extend the variance table @t@ with every statement of a statement
-- sequence, processed in order via 'varianceInStm'.
--
-- A strict left fold is used so that the table is forced after each
-- statement rather than building a chain of suspended updates.
varianceInStms :: VarianceTable -> Stms Kernels -> VarianceTable
varianceInStms t = foldl' varianceInStm t . stmsToList

-- | Record the variance of a single statement in the table.
--
-- Every name bound by the statement's pattern is mapped to the same set:
-- each free variable of the statement, plus the entry the incoming table
-- holds for that variable (nothing if it has no entry).
varianceInStm :: VarianceTable -> Stm Kernels -> VarianceTable
varianceInStm variance bnd =
  foldl' record variance (patternNames (stmPattern bnd))
  where
    -- Associate one bound name with the statement's dependency set.
    record table name = M.insert name deps table

    -- Dependency set shared by all names this statement binds; it is
    -- computed against the table as it was before this statement.
    deps = foldMap withRecorded (namesToList (freeIn bnd))

    -- A free variable together with whatever the table knows about it.
    withRecorded v = oneName v <> M.findWithDefault mempty v variance