module Futhark.Optimise.ArrayLayout.Layout
  ( layoutTableFromIndexTable,
    Layout,
    Permutation,
    LayoutTable,

    -- * Exposed for testing
    commonPermutationEliminators,
  )
where

import Control.Monad (join)
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.AccessPattern
import Futhark.Analysis.PrimExp.Table (PrimExpTable)
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.IR.MC
import Futhark.IR.MCMem
import Futhark.Util (mininum)

-- | A permutation of an array's dimensions, represented as a list of
-- dimension indices.
type Permutation = [Int]

-- | Chosen permutations, keyed first by 'SegOpName', then by
-- 'ArrayName', and finally by the 'IndexExprName' of the index
-- expression.
type LayoutTable =
  M.Map
    SegOpName
    ( M.Map
        ArrayName
        (M.Map IndexExprName Permutation)
    )

-- | A representation @rep@ for which a backend-specific layout
-- permutation can be computed from its array accesses.
class Layout rep where
  -- | Produce a coalescing permutation that will be used to create a
  -- manifest of the array. Returns Nothing if the array is already in
  -- the optimal layout or if the array access is too complex to
  -- confidently determine the optimal layout. Map each list of
  -- 'DimAccess' in the IndexTable to a permutation in a generic way
  -- that can be handled uniquely by each backend.
  permutationFromDimAccess ::
    PrimExpTable ->
    SegOpName ->
    ArrayName ->
    IndexExprName ->
    [DimAccess rep] ->
    Maybe Permutation

-- | Structural complexity check for an index expression. Leaves and
-- constants are always fine. Binary and unary operators are fine when
-- all their operands are. Any other constructor makes the whole
-- expression inscrutable.
isInscrutableExp :: PrimExp VName -> Bool
isInscrutableExp e =
  case e of
    LeafExp {} -> False
    ValueExp {} -> False
    BinOpExp _ x y -> any isInscrutableExp [x, y]
    UnOpExp _ x -> isInscrutableExp x
    _ -> True

-- | Decide whether an index expression is too complex to reason about.
--
-- The 'Bool' argument says whether the index depends on a counter. In
-- that case, a binary-operator expression is first reduced to an
-- affine stride/offset pair with 'reduceStrideAndOffset', and the
-- expression is inscrutable when the stride is too large. In all other
-- cases (and when no stride can be computed) it falls back to the
-- structural check 'isInscrutableExp'.
isInscrutable :: PrimExp VName -> Bool -> Bool
isInscrutable op@(BinOpExp {}) True =
  -- Calculate stride and offset for loop-counters and thread-IDs
  case reduceStrideAndOffset op of
    -- Only the magnitude of the stride matters for how scattered the
    -- accesses are. A stride of -16 is as bad as a stride of 16.
    Just (s, _) -> abs s > maxStride
    Nothing -> isInscrutableExp op
  where
    -- Maximum allowable stride, might need tuning.
    maxStride :: Int
    maxStride = 8
isInscrutable op _ = isInscrutableExp op

-- | Try to express an integer index expression as an affine function
-- @stride * x + offset@ of a single leaf @x@, returning
-- @Just (stride, offset)@.
--
-- A bare leaf is @(1, 0)@. A binary operation is understood only when
-- one operand is an integer constant and the operator is an addition,
-- subtraction or multiplication. Anything else gives 'Nothing'.
reduceStrideAndOffset :: PrimExp l -> Maybe (Int, Int)
reduceStrideAndOffset (LeafExp _ _) = Just (1, 0)
reduceStrideAndOffset (BinOpExp oper a b) = case (a, b) of
  (ValueExp (IntValue v), _) -> reduce True v b
  (_, ValueExp (IntValue v)) -> reduce False v a
  _ -> Nothing
  where
    -- @reduce constOnLeft v e@ is the stride and offset of
    -- @v `oper` e@ when @constOnLeft@ holds, and of @e `oper` v@
    -- otherwise.
    reduce constOnLeft v e = case e of
      LeafExp {} -> combine constOnLeft (valueIntegral v) (1, 0)
      BinOpExp {} ->
        combine constOnLeft (valueIntegral v) =<< reduceStrideAndOffset e
      -- Conversions (such as sign extension) are looked through, but the
      -- surrounding operation with the constant must still be applied:
      -- @4 * sext(i)@ has stride 4, not 1.
      ConvOpExp _ sub_op -> reduce constOnLeft v sub_op
      UnOpExp Not _ -> Nothing
      UnOpExp (Complement _) _ -> Nothing
      UnOpExp (Abs _) _ -> Nothing
      -- Heuristic: the remaining unary operators are looked through and
      -- the surrounding operation is ignored.
      UnOpExp _ sub_op -> reduceStrideAndOffset sub_op
      _ -> Nothing

    -- Apply @oper@ with the constant @c@ to the affine form @s * x + o@.
    -- Subtraction is not commutative, so the side of the constant
    -- matters.
    combine constOnLeft c (s, o) = case oper of
      Add _ _ -> Just (s, o + c)
      Sub _ _
        | constOnLeft -> Just (negate s, c - o) -- c - (s * x + o)
        | otherwise -> Just (s, o - c) -- (s * x + o) - c
      Mul _ _ -> Just (s * c, o * c)
      _ -> Nothing
reduceStrideAndOffset _ = Nothing

-- | Reasons common to all backends to not manifest an array. Returns
-- 'True' when the candidate permutation @perm@ should be rejected for
-- an array defined within the body nest @nest@.
commonPermutationEliminators :: [Int] -> [BodyType] -> Bool
commonPermutationEliminators perm nest =
  -- 'or' short-circuits left to right. This keeps 'last' below from
  -- ever seeing an empty permutation, because the empty permutation is
  -- already caught by the identity check.
  or
    [ -- Don't manifest if the permutation is invalid,
      not (L.sort perm `L.isPrefixOf` [0 ..]),
      -- or is the identity permutation,
      perm `L.isPrefixOf` [0 ..],
      -- or is not a transpose,
      isNothing (isMapTranspose perm),
      -- or if the last idx remains last,
      last perm == length perm - 1,
      -- or if the array is defined inside a segOp.
      any isSegOp nest
    ]
  where
    isSegOp :: BodyType -> Bool
    isSegOp (SegOpName _) = True
    isSegOp _ = False

-- | Sort (original layout position, access) pairs for the multicore
-- backend.
--
-- Each access is summarised by its greatest dependency, tagged with
-- the access's original position. Accesses without dependencies come
-- first. Dependencies are ordered by level and then by original
-- position, except that a thread-ID dependency compares less than any
-- other kind of dependency.
sortMC :: [(Int, DimAccess rep)] -> [(Int, DimAccess rep)]
sortMC = L.sortBy byStrongestDep
  where
    byStrongestDep x y = cmpIdxPat (strongest x) (strongest y)

    -- The greatest dependency of an access under 'cmpIdxPat'.
    strongest (og, acc) =
      L.foldl'
        max'
        Nothing
        [Just (varType d, lvl d, og) | d <- M.elems (dependencies acc)]

    max' lhs rhs
      | cmpIdxPat lhs rhs == LT = rhs
      | otherwise = lhs

    cmpIdxPat :: Maybe (VarType, Int, Int) -> Maybe (VarType, Int, Int) -> Ordering
    cmpIdxPat (Just (vtL, lvlL, ogL)) (Just (vtR, lvlR, ogR))
      | vtL == ThreadID && vtR /= ThreadID = LT
      | vtR == ThreadID && vtL /= ThreadID = GT
      | otherwise = compare (lvlL, ogL) (lvlR, ogR)
    -- At least one side is missing: the missing side is smaller.
    cmpIdxPat lhs rhs = compare (isJust lhs) (isJust rhs)

-- | Multicore layout strategy.
--
-- Proposes the permutation obtained by sorting the array's dimensions
-- with 'sortMC'. Returns 'Nothing' when the array should keep its
-- current layout, which happens when:
--
-- * the last index is invariant,
-- * some index expression is too complex to reason about, or
-- * one of the 'commonPermutationEliminators' applies.
multicorePermutation :: PrimExpTable -> SegOpName -> ArrayName -> IndexExprName -> [DimAccess rep] -> Maybe Permutation
multicorePermutation primExpTable _segOpName (_arr_name, nest, arr_layout) _idx_name dimAccesses
  | lastIdxIsInvariant || inscrutable || commonPermutationEliminators perm nest =
      Nothing
  | otherwise =
      Just perm
  where
    -- Dont accept indices where the last index is invariant. An empty
    -- access list has no last index. It is rejected here instead of
    -- crashing in 'last'.
    lastIdxIsInvariant =
      null dimAccesses || isInvariant (last dimAccesses)

    -- Check if any of the dependencies are too complex to reason about.
    dimAccesses' = filter (isJust . originalVar) dimAccesses
    deps = mapMaybe originalVar dimAccesses'
    -- Exactly one flag per element of dimAccesses' (and therefore per
    -- element of primExps): does this index depend on a counter? A list
    -- with one entry per dependency would misalign the zip below
    -- whenever an index has zero or several dependencies.
    counters = map (any (isCounter . varType) . dependencies) dimAccesses'
    primExps = mapM (join . (`M.lookup` primExpTable)) deps
    inscrutable =
      maybe True (\pes -> or (zipWith isInscrutable pes counters)) primExps

    -- Create a candidate permutation
    perm = map fst $ sortMC $ zip arr_layout dimAccesses

-- | The multicore backend picks layouts with 'multicorePermutation'.
instance Layout MC where
  permutationFromDimAccess ::
    PrimExpTable ->
    SegOpName ->
    ArrayName ->
    IndexExprName ->
    [DimAccess MC] ->
    Maybe Permutation
  permutationFromDimAccess = multicorePermutation

-- | Sort (original layout position, access) pairs for the GPU backend.
--
-- Each access is summarised by its greatest dependency, tagged with
-- the access's original position. Accesses without dependencies come
-- first. Dependencies are ordered by level and then by original
-- position, except that a thread-ID dependency compares greater than
-- any other kind of dependency. As a result, thread-ID-indexed
-- dimensions are placed last in the sorted result.
sortGPU :: [(Int, DimAccess rep)] -> [(Int, DimAccess rep)]
sortGPU = L.sortBy compareAccesses
  where
    compareAccesses (ia, a) (ib, b) =
      cmpIdxPat (summary ia a) (summary ib b)

    -- The greatest dependency of an access under 'cmpIdxPat'.
    summary og acc =
      L.foldl' max' Nothing $
        map (\(Dependency l vt) -> Just (vt, l, og)) (M.elems (dependencies acc))

    max' lhs rhs = case cmpIdxPat lhs rhs of
      LT -> rhs
      _ -> lhs

    cmpIdxPat :: Maybe (VarType, Int, Int) -> Maybe (VarType, Int, Int) -> Ordering
    cmpIdxPat Nothing Nothing = EQ
    cmpIdxPat (Just _) Nothing = GT
    cmpIdxPat Nothing (Just _) = LT
    cmpIdxPat (Just (vtL, lvlL, ogL)) (Just (vtR, lvlR, ogR)) =
      case (vtL == ThreadID, vtR == ThreadID) of
        (True, False) -> GT
        (False, True) -> LT
        _ -> compare (lvlL, ogL) (lvlR, ogR)

gpuPermutation :: PrimExpTable -> SegOpName -> ArrayName -> IndexExprName -> [DimAccess rep] -> Maybe Permutation
gpuPermutation :: forall {k} (rep :: k).
PrimExpTable
-> SegOpName
-> ArrayName
-> VName
-> [DimAccess rep]
-> Maybe [Int]
gpuPermutation PrimExpTable
primExpTable SegOpName
_segOpName (VName
_arr_name, [BodyType]
nest, [Int]
arr_layout) VName
_idx_name [DimAccess rep]
dimAccesses = do
  -- Find the outermost parallel level. XXX: this is a bit hacky. Why
  -- don't we simply know at this point the nest in which this index
  -- occurs?
  let outermost_par :: Int
outermost_par = [Int] -> Int
forall a (f :: * -> *). (Num a, Ord a, Foldable f) => f a -> a
mininum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (DimAccess rep -> [Int]) -> [DimAccess rep] -> [Int]
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap ((Dependency -> Int) -> [Dependency] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map Dependency -> Int
lvl ([Dependency] -> [Int])
-> (DimAccess rep -> [Dependency]) -> DimAccess rep -> [Int]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DimAccess rep -> [Dependency]
forall {k} {rep :: k}. DimAccess rep -> [Dependency]
parDeps) [DimAccess rep]
dimAccesses
      invariantToPar :: Dependency -> Bool
invariantToPar = (Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
outermost_par) (Int -> Bool) -> (Dependency -> Int) -> Dependency -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Dependency -> Int
lvl

  -- Do nothing if last index is invariant to segop.
  let lastIdxIsInvariant :: Bool
lastIdxIsInvariant = (Dependency -> Bool) -> Map VName Dependency -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Dependency -> Bool
invariantToPar (Map VName Dependency -> Bool) -> Map VName Dependency -> Bool
forall a b. (a -> b) -> a -> b
$ DimAccess rep -> Map VName Dependency
forall {k} (rep :: k). DimAccess rep -> Map VName Dependency
dependencies (DimAccess rep -> Map VName Dependency)
-> DimAccess rep -> Map VName Dependency
forall a b. (a -> b) -> a -> b
$ [DimAccess rep] -> DimAccess rep
forall a. HasCallStack => [a] -> a
last [DimAccess rep]
dimAccesses

  -- Do nothing if any index is constant, because otherwise we can end
  -- up transposing a too-large array.
  let anyIsConstant :: Bool
anyIsConstant = (DimAccess rep -> Bool) -> [DimAccess rep] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (Map VName Dependency -> Bool
forall a. Map VName a -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (Map VName Dependency -> Bool)
-> (DimAccess rep -> Map VName Dependency) -> DimAccess rep -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DimAccess rep -> Map VName Dependency
forall {k} (rep :: k). DimAccess rep -> Map VName Dependency
dependencies) [DimAccess rep]
dimAccesses

  -- Check if any of the dependencies are too complex to reason about
  let dimAccesses' :: [DimAccess rep]
dimAccesses' = (DimAccess rep -> Bool) -> [DimAccess rep] -> [DimAccess rep]
forall a. (a -> Bool) -> [a] -> [a]
filter (Maybe VName -> Bool
forall a. Maybe a -> Bool
isJust (Maybe VName -> Bool)
-> (DimAccess rep -> Maybe VName) -> DimAccess rep -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DimAccess rep -> Maybe VName
forall {k} (rep :: k). DimAccess rep -> Maybe VName
originalVar) [DimAccess rep]
dimAccesses
      deps :: [VName]
deps = (DimAccess rep -> Maybe VName) -> [DimAccess rep] -> [VName]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe DimAccess rep -> Maybe VName
forall {k} (rep :: k). DimAccess rep -> Maybe VName
originalVar [DimAccess rep]
dimAccesses'
      counters :: [Bool]
counters = (DimAccess rep -> [Bool]) -> [DimAccess rep] -> [Bool]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (((VName, Dependency) -> Bool) -> [(VName, Dependency)] -> [Bool]
forall a b. (a -> b) -> [a] -> [b]
map (VarType -> Bool
isCounter (VarType -> Bool)
-> ((VName, Dependency) -> VarType) -> (VName, Dependency) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Dependency -> VarType
varType (Dependency -> VarType)
-> ((VName, Dependency) -> Dependency)
-> (VName, Dependency)
-> VarType
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, Dependency) -> Dependency
forall a b. (a, b) -> b
snd) ([(VName, Dependency)] -> [Bool])
-> (DimAccess rep -> [(VName, Dependency)])
-> DimAccess rep
-> [Bool]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Map VName Dependency -> [(VName, Dependency)]
forall k a. Map k a -> [(k, a)]
M.toList (Map VName Dependency -> [(VName, Dependency)])
-> (DimAccess rep -> Map VName Dependency)
-> DimAccess rep
-> [(VName, Dependency)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DimAccess rep -> Map VName Dependency
forall {k} (rep :: k). DimAccess rep -> Map VName Dependency
dependencies) [DimAccess rep]
dimAccesses'
      primExps :: Maybe [PrimExp VName]
primExps = (VName -> Maybe (PrimExp VName))
-> [VName] -> Maybe [PrimExp VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Maybe (Maybe (PrimExp VName)) -> Maybe (PrimExp VName)
forall (m :: * -> *) a. Monad m => m (m a) -> m a
join (Maybe (Maybe (PrimExp VName)) -> Maybe (PrimExp VName))
-> (VName -> Maybe (Maybe (PrimExp VName)))
-> VName
-> Maybe (PrimExp VName)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> PrimExpTable -> Maybe (Maybe (PrimExp VName))
forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` PrimExpTable
primExpTable)) [VName]
deps
      inscrutable :: Bool
inscrutable = Bool -> ([PrimExp VName] -> Bool) -> Maybe [PrimExp VName] -> Bool
forall b a. b -> (a -> b) -> Maybe a -> b
maybe Bool
True (((PrimExp VName, Bool) -> Bool) -> [(PrimExp VName, Bool)] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any ((PrimExp VName -> Bool -> Bool) -> (PrimExp VName, Bool) -> Bool
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry PrimExp VName -> Bool -> Bool
isInscrutable) ([(PrimExp VName, Bool)] -> Bool)
-> ([PrimExp VName] -> [(PrimExp VName, Bool)])
-> [PrimExp VName]
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([PrimExp VName] -> [Bool] -> [(PrimExp VName, Bool)])
-> [Bool] -> [PrimExp VName] -> [(PrimExp VName, Bool)]
forall a b c. (a -> b -> c) -> b -> a -> c
flip [PrimExp VName] -> [Bool] -> [(PrimExp VName, Bool)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Bool]
counters) Maybe [PrimExp VName]
primExps

  -- Create a candidate permutation
  let perm :: [Int]
perm = ((Int, DimAccess rep) -> Int) -> [(Int, DimAccess rep)] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (Int, DimAccess rep) -> Int
forall a b. (a, b) -> a
fst ([(Int, DimAccess rep)] -> [Int])
-> [(Int, DimAccess rep)] -> [Int]
forall a b. (a -> b) -> a -> b
$ [(Int, DimAccess rep)] -> [(Int, DimAccess rep)]
forall {k} (rep :: k).
[(Int, DimAccess rep)] -> [(Int, DimAccess rep)]
sortGPU ([Int] -> [DimAccess rep] -> [(Int, DimAccess rep)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int]
arr_layout [DimAccess rep]
dimAccesses)

  -- Check if we want to manifest this array with the permutation
  if Bool
lastIdxIsInvariant
    Bool -> Bool -> Bool
|| Bool
anyIsConstant
    Bool -> Bool -> Bool
|| Bool
inscrutable
    Bool -> Bool -> Bool
|| [Int] -> [BodyType] -> Bool
commonPermutationEliminators [Int]
perm [BodyType]
nest
    then Maybe [Int]
forall a. Maybe a
Nothing
    else [Int] -> Maybe [Int]
forall a. a -> Maybe a
Just [Int]
perm
  where
    parDeps :: DimAccess rep -> [Dependency]
parDeps = (Dependency -> Bool) -> [Dependency] -> [Dependency]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VarType -> VarType -> Bool
forall a. Eq a => a -> a -> Bool
== VarType
ThreadID) (VarType -> Bool) -> (Dependency -> VarType) -> Dependency -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Dependency -> VarType
varType) ([Dependency] -> [Dependency])
-> (DimAccess rep -> [Dependency]) -> DimAccess rep -> [Dependency]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Map VName Dependency -> [Dependency]
forall k a. Map k a -> [a]
M.elems (Map VName Dependency -> [Dependency])
-> (DimAccess rep -> Map VName Dependency)
-> DimAccess rep
-> [Dependency]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DimAccess rep -> Map VName Dependency
forall {k} (rep :: k). DimAccess rep -> Map VName Dependency
dependencies

-- | The GPU backend derives its coalescing permutations via
-- 'gpuPermutation', forwarding every argument unchanged.
instance Layout GPU where
  permutationFromDimAccess primExpTable segOp arr idxExpr dimAccesses =
    gpuPermutation primExpTable segOp arr idxExpr dimAccesses

-- | Like 'M.mapMaybeWithKey', but for a map nested three levels deep.
--
-- The function receives all three keys together with the leaf value.
-- Leaves for which it returns 'Nothing' are dropped. Any inner map that
-- ends up empty is dropped too, at both nesting levels, so the result
-- never contains empty inner maps (the outermost map itself may be
-- empty).
tableMapMaybe ::
  (k0 -> k1 -> k2 -> a -> Maybe b) ->
  M.Map k0 (M.Map k1 (M.Map k2 a)) ->
  M.Map k0 (M.Map k1 (M.Map k2 b))
tableMapMaybe f = M.mapMaybeWithKey outer
  where
    -- Process the middle map under outer key @key0@.
    outer key0 = pruneWith (inner key0)

    -- Process the innermost map under keys @key0@ / @key1@.
    inner key0 key1 = pruneWith (f key0 key1)

    -- Key-aware mapMaybe that reports 'Nothing' when nothing survives,
    -- so the caller drops the now-empty entry.
    pruneWith :: (k -> v -> Maybe w) -> M.Map k v -> Maybe (M.Map k w)
    pruneWith g m
      | M.null kept = Nothing
      | otherwise = Just kept
      where
        kept = M.mapMaybeWithKey g m

-- | Build a 'LayoutTable' from an 'IndexTable' using the 'Layout'
-- instance of @rep@.
--
-- For every index expression in the table, its list of 'DimAccess'es is
-- passed to 'permutationFromDimAccess', together with the given
-- 'PrimExpTable' and the enclosing SegOp, array and index-expression
-- names. Index expressions for which no permutation is produced
-- ('Nothing') are removed. Any array or SegOp entry left empty by that
-- removal is removed as well.
layoutTableFromIndexTable ::
  (Layout rep) =>
  PrimExpTable ->
  IndexTable rep ->
  LayoutTable
layoutTableFromIndexTable = tableMapMaybe . permutationFromDimAccess