{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Segmented operations.  These correspond to perfect @map@ nests on
-- top of /something/, except that the @map@s are conceptually only
-- over @iota@s (so there will be explicit indexing inside them).
module Futhark.IR.SegOp
  ( SegOp (..),
    segLevel,
    segBody,
    segSpace,
    typeCheckSegOp,
    SegSpace (..),
    scopeOfSegSpace,
    segSpaceDims,

    -- * Details
    HistOp (..),
    histType,
    splitHistResults,
    SegBinOp (..),
    segBinOpResults,
    segBinOpChunks,
    KernelBody (..),
    aliasAnalyseKernelBody,
    consumedInKernelBody,
    ResultManifest (..),
    KernelResult (..),
    kernelResultCerts,
    kernelResultSubExp,

    -- ** Generic traversal
    SegOpMapper (..),
    identitySegOpMapper,
    mapSegOpM,
    traverseSegOpStms,

    -- * Simplification
    simplifySegOp,
    HasSegOp (..),
    segOpRules,

    -- * Memory
    segOpReturns,
  )
where

import Control.Category
import Control.Monad
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Bifunctor (first)
import Data.Bitraversable
import Data.List
  ( elemIndex,
    foldl',
    groupBy,
    intersperse,
    isPrefixOf,
    partition,
    unzip4,
    zip4,
  )
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR
import Futhark.IR.Aliases
  ( Aliases,
    CanBeAliased (..),
  )
import Futhark.IR.Mem
import Futhark.IR.Prop.Aliases
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty
  ( Doc,
    apply,
    hsep,
    parens,
    ppTuple',
    pretty,
    (<+>),
    (</>),
  )
import Futhark.Util.Pretty qualified as PP
import Prelude hiding (id, (.))

-- | An operator for 'SegHist'.
data HistOp rep = HistOp
  { -- | The histogram's own dimensions.  Its rank is the number of
    -- index values each update supplies (see 'splitHistResults'), and
    -- it forms the outer part of 'histType'.
    histShape :: Shape,
    histRaceFactor :: SubExp,
    -- | The arrays being updated.  'splitHistResults' expects one
    -- value result per destination.
    histDest :: [VName],
    -- | Presumably the neutral element(s) of 'histOp' — confirm.
    histNeutral :: [SubExp],
    -- | In case this operator is semantically a vectorised
    -- operator (corresponding to a perfect map nest in the
    -- SOACS representation), these are the logical
    -- "dimensions".  This is used to generate more efficient
    -- code.
    histOpShape :: Shape,
    -- | The operator itself; its return types determine 'histType'.
    histOp :: Lambda rep
  }
  deriving (Eq, Ord, Show)

-- | The type of a histogram produced by a 'HistOp'.  This can be
-- different from the type of the 'histDest's in case we are
-- dealing with a segmented histogram.
--
-- Every return type of the 'histOp' lambda is turned into an array
-- over the concatenation of 'histShape' and 'histOpShape' (in that
-- order).
histType :: HistOp rep -> [Type]
histType op =
  [t `arrayOfShape` dims | t <- lambdaReturnType (histOp op)]
  where
    dims = histShape op <> histOpShape op

-- | Split reduction results returned by a 'KernelBody' into those
-- that correspond to indexes for the 'HistOp's, and those that
-- correspond to values.
--
-- The input is laid out as all index results first, followed by all
-- value results.  Each operator receives as many indexes as the rank
-- of its 'histShape', and as many values as it has 'histDest's.  The
-- result pairs up each operator's indexes with its values, in the
-- order of the operators.
splitHistResults :: [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ops res =
  zip (chunks idxCounts idxs) (chunks valCounts vals)
  where
    idxCounts = map (shapeRank . histShape) ops
    valCounts = map (length . histDest) ops
    (idxs, vals) = splitAt (sum idxCounts) res

-- | An operator for 'SegScan' and 'SegRed'.
data SegBinOp rep = SegBinOp
  { segBinOpComm :: Commutativity,
    -- | The binary operator itself.
    segBinOpLambda :: Lambda rep,
    -- | The neutral elements.  Their count is the number of results
    -- this operator produces (see 'segBinOpResults').
    segBinOpNeutral :: [SubExp],
    -- | In case this operator is semantically a vectorised
    -- operator (corresponding to a perfect map nest in the
    -- SOACS representation), these are the logical
    -- "dimensions".  This is used to generate more efficient
    -- code.
    segBinOpShape :: Shape
  }
  deriving (Eq, Ord, Show)

-- | How many reduction results are produced by these 'SegBinOp's?
-- This is the total number of neutral elements over all operators.
segBinOpResults :: [SegBinOp rep] -> Int
segBinOpResults ops = sum [length (segBinOpNeutral op) | op <- ops]

-- | Split some list into chunks equal to the number of values
-- returned by each 'SegBinOp', i.e. the number of neutral elements
-- of each operator, in order.
segBinOpChunks :: [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks ops = chunks [length (segBinOpNeutral op) | op <- ops]

-- | The body of a 'SegOp'.  It has the same decoration and statements
-- as an ordinary 'Body', but ends in a list of 'KernelResult's
-- instead of a 'Result'.
data KernelBody rep = KernelBody
  { kernelBodyDec :: BodyDec rep,
    kernelBodyStms :: Stms rep,
    kernelBodyResult :: [KernelResult]
  }

-- The instances below are declared standalone with an explicit
-- 'RepTypes' context, which supplies the instances needed for the
-- representation-dependent fields.
deriving instance (RepTypes rep) => Eq (KernelBody rep)

deriving instance (RepTypes rep) => Ord (KernelBody rep)

deriving instance (RepTypes rep) => Show (KernelBody rep)

-- | Metadata about whether there is a subtle point to this
-- 'KernelResult'.  This is used to protect things like tiling, which
-- might otherwise be removed by the simplifier because they're
-- semantically redundant.  This has no semantic effect and can be
-- ignored at code generation.
--
-- The constructor order below also defines the derived 'Ord'.
data ResultManifest
  = -- | Don't simplify this one!
    ResultNoSimplify
  | -- | The simplifier is free to do whatever it likes.
    ResultMaySimplify
  | -- | The results produced are only used within the
    -- same physical thread later on, and can thus be
    -- kept in registers.
    ResultPrivate
  deriving (Eq, Show, Ord)

-- | A 'KernelBody' does not return an ordinary 'Result'.  Instead, it
-- returns a list of these.
data KernelResult
  = -- | Each "worker" in the kernel returns this.
    -- Whether this is a result-per-thread or a
    -- result-per-block depends on where the 'SegOp' occurs.
    Returns ResultManifest Certs SubExp
  | -- | A write into an existing destination array, which is
    -- consumed by the kernel body (see 'consumedInKernelBody').
    -- NOTE(review): each pair is presumably a slice of the
    -- destination and the value written there — confirm.
    WriteReturns
      Certs
      VName -- Destination array
      [(Slice SubExp, SubExp)]
  | -- | A tile written by this worker, with the total and tile size
    -- for each dimension.  The 'TileReturns' must not expect more
    -- than one result to be written per physical thread.
    TileReturns
      Certs
      [(SubExp, SubExp)] -- Total/tile for each dimension
      VName -- Tile written by this worker.
  | -- | A register tile returned by this thread/block.
    RegTileReturns
      Certs
      -- For each dimension of the result: its size, the block tile
      -- size, and the register tile size.
      [(SubExp, SubExp, SubExp)]
      VName -- Tile returned by this thread/block.
  deriving (Eq, Show, Ord)

-- | Get the certs for this 'KernelResult'.  Every constructor carries
-- its certificates as the first 'Certs' field.
kernelResultCerts :: KernelResult -> Certs
kernelResultCerts kres =
  case kres of
    Returns _ cs _ -> cs
    WriteReturns cs _ _ -> cs
    TileReturns cs _ _ -> cs
    RegTileReturns cs _ _ -> cs

-- | Get the root t'SubExp' corresponding to the value of a
-- 'KernelResult'.  For a 'Returns' this is the returned t'SubExp'
-- itself; for the other constructors it is the named array, wrapped
-- in 'Var'.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp kres =
  case kres of
    Returns _ _ se -> se
    WriteReturns _ arr _ -> Var arr
    TileReturns _ _ v -> Var v
    RegTileReturns _ _ v -> Var v

-- | The free variables of a 'KernelResult' are those of its
-- certificates, its dimension/slice information, and the variable it
-- returns or writes.  The 'ResultManifest' of a 'Returns' is ignored.
instance FreeIn KernelResult where
  freeIn' kres =
    case kres of
      Returns _ cs what ->
        freeIn' cs <> freeIn' what
      WriteReturns cs arr res ->
        freeIn' cs <> freeIn' arr <> freeIn' res
      TileReturns cs dims v ->
        freeIn' cs <> freeIn' dims <> freeIn' v
      RegTileReturns cs dims_n_tiles v ->
        freeIn' cs <> freeIn' dims_n_tiles <> freeIn' v

-- | The free variables of the decoration, statements and results,
-- with the names bound by the statements removed via 'fvBind'.
instance (ASTRep rep) => FreeIn (KernelBody rep) where
  freeIn' (KernelBody dec stms res) =
    fvBind boundHere (freeIn' dec <> freeIn' stms <> freeIn' res)
    where
      boundHere = foldMap boundByStm stms

-- | Apply the substitution to the decoration, statements and results
-- alike.
instance (ASTRep rep) => Substitute (KernelBody rep) where
  substituteNames subst (KernelBody dec stms res) =
    KernelBody (sub dec) (sub stms) (sub res)
    where
      -- Explicit signature: the helper is used at three different
      -- types and would otherwise not be generalised.
      sub :: (Substitute a) => a -> a
      sub = substituteNames subst

-- | Apply the substitution to every component of a 'KernelResult'
-- except the 'ResultManifest', which contains no names.
instance Substitute KernelResult where
  substituteNames subst kres =
    case kres of
      Returns manifest cs se ->
        Returns manifest (sub cs) (sub se)
      WriteReturns cs arr res ->
        WriteReturns (sub cs) (sub arr) (sub res)
      TileReturns cs dims v ->
        TileReturns (sub cs) (sub dims) (sub v)
      RegTileReturns cs dims_n_tiles v ->
        RegTileReturns (sub cs) (sub dims_n_tiles) (sub v)
    where
      -- Explicit signature: the helper is used at several types.
      sub :: (Substitute a) => a -> a
      sub = substituteNames subst

-- | The decoration is renamed first.  The statements are then renamed
-- with 'renamingStms', and the results are renamed within its
-- continuation so they see the statements' renamed names.
instance (ASTRep rep) => Rename (KernelBody rep) where
  rename (KernelBody dec stms res) =
    rename dec >>= \dec' ->
      renamingStms stms $ \stms' ->
        KernelBody dec' stms' <$> rename res

-- | A 'KernelResult' introduces no bindings of its own, so it is
-- renamed via 'substituteRename' (presumably by applying the current
-- renaming as a substitution — confirm).
instance Rename KernelResult where
  rename = substituteRename

-- | Perform alias analysis on a 'KernelBody'.
--
-- The statements are analysed as an ordinary 'Body' with an empty
-- result.  The kernel results are carried over unchanged.
aliasAnalyseKernelBody ::
  (Alias.AliasableRep rep) =>
  AliasTable ->
  KernelBody rep ->
  KernelBody (Aliases rep)
aliasAnalyseKernelBody aliases (KernelBody dec stms res) =
  KernelBody dec' stms' res
  where
    Body dec' stms' _ = Alias.analyseBody aliases (Body dec stms [])

-- | The variables consumed in the kernel body.  These are the
-- variables consumed by its statements plus the destination array of
-- every 'WriteReturns' result.
consumedInKernelBody ::
  (Aliased rep) =>
  KernelBody rep ->
  Names
consumedInKernelBody (KernelBody dec stms res) =
  consumedByStms <> foldMap oneName writtenArrays
  where
    -- Reuse the ordinary body analysis by wrapping the statements in
    -- a 'Body' with no result.
    consumedByStms = consumedInBody (Body dec stms [])
    writtenArrays = [arr | WriteReturns _ arr _ <- res]

-- | Type-check a kernel body against the expected per-result types
-- @ts@: the body decoration, the statements, the number of results,
-- and each individual 'KernelResult'.
checkKernelBody ::
  (TC.Checkable rep) =>
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkKernelBody ts (KernelBody (_, dec) stms kres) = do
  TC.checkBodyDec dec
  -- We consume the kernel results (when applicable) before
  -- type-checking the stms, so we will get an error if a statement
  -- uses an array that is written to in a result.
  mapM_ consumeKernelResult kres
  TC.checkStms stms $ do
    unless (length ts == length kres) $
      TC.bad . TC.TypeError $
        "Kernel return type is "
          <> prettyTuple ts
          <> ", but body returns "
          <> prettyText (length kres)
          <> " values."
    zipWithM_ checkKernelResult kres ts
  where
    -- Only 'WriteReturns' consumes anything: the destination array and
    -- everything it aliases.
    consumeKernelResult (WriteReturns _ arr _) =
      TC.consume =<< TC.lookupAliases arr
    consumeKernelResult _ =
      pure ()

    checkKernelResult (Returns _ cs what) t = do
      TC.checkCerts cs
      TC.require [t] what
    checkKernelResult (WriteReturns cs arr res) t = do
      TC.checkCerts cs
      arr_t <- lookupType arr
      -- For 'WriteReturns' the annotation is the full type of the
      -- destination array, so it must match the array's type exactly.
      unless (arr_t == t) $
        TC.bad . TC.TypeError $
          "WriteReturns result type annotation for "
            <> prettyText arr
            <> " is "
            <> prettyText t
            <> ", but inferred as "
            <> prettyText arr_t
      forM_ res $ \(slice, e) -> do
        TC.checkSlice arr_t slice
        TC.require [t `setArrayShape` sliceShape slice] e
    checkKernelResult (TileReturns cs dims v) t = do
      TC.checkCerts cs
      forM_ dims $ \(dim, tile) -> do
        TC.require [Prim int64] dim
        TC.require [Prim int64] tile
      vt <- lookupType v
      -- The returned tile must have the tile sizes as its shape.
      unless (vt == t `arrayOfShape` Shape (map snd dims)) $
        TC.bad $
          TC.TypeError $
            "Invalid type for TileReturns " <> prettyText v
    checkKernelResult (RegTileReturns cs dims_n_tiles arr) t = do
      TC.checkCerts cs
      mapM_ (TC.require [Prim int64]) dims
      mapM_ (TC.require [Prim int64]) blk_tiles
      mapM_ (TC.require [Prim int64]) reg_tiles

      -- assert that arr is of element type t and shape (blk_tiles ++ reg_tiles)
      arr_t <- lookupType arr
      unless (arr_t == expected) $
        TC.bad . TC.TypeError $
          "Invalid type for RegTileReturns. Expected:\n  "
            <> prettyText expected
            <> ",\ngot:\n  "
            <> prettyText arr_t
      where
        (dims, blk_tiles, reg_tiles) = unzip3 dims_n_tiles
        expected :: Type
        expected = t `arrayOfShape` Shape (blk_tiles <> reg_tiles)

-- | Record metrics for every statement of a kernel body.  The kernel
-- results themselves are not inspected.
kernelBodyMetrics :: (OpMetrics (Op rep)) => KernelBody rep -> MetricsM ()
kernelBodyMetrics kbody =
  forM_ (kernelBodyStms kbody) stmMetrics

-- | A kernel body prints as its statements, one per line, followed by
-- @return {...}@ listing the kernel results.
instance (PrettyRep rep) => Pretty (KernelBody rep) where
  pretty (KernelBody _ stms res) =
    stmsDoc </> "return" <+> resultsDoc
    where
      stmsDoc = PP.stack $ map pretty $ stmsToList stms
      resultsDoc = PP.braces $ PP.commastack $ map pretty res

-- | Render certificates as a singleton list, or as the empty list when
-- there are none, so the result can be prepended to other documents.
certAnnots :: Certs -> [Doc ann]
certAnnots cs = [pretty cs | cs /= mempty]

-- | Each result is printed with its certificates (if any) in front,
-- followed by a constructor-specific rendering.
instance Pretty KernelResult where
  pretty kr =
    case kr of
      Returns manifest cs what ->
        withCerts cs $ returnsWord manifest <+> pretty what
      WriteReturns cs arr res ->
        withCerts cs $ pretty arr </> "with" <+> PP.apply (map ppRes res)
      TileReturns cs dims v ->
        withCerts cs $ "tile" <> apply (map onDim dims) <+> pretty v
      RegTileReturns cs dims_n_tiles v ->
        withCerts cs $
          "blkreg_tile" <> apply (map onRegDim dims_n_tiles) <+> pretty v
    where
      -- Prefix a document with the (possibly empty) certificate list.
      withCerts cs doc = hsep $ certAnnots cs <> [doc]

      -- The keyword used for each kind of 'Returns' manifestation.
      returnsWord ResultNoSimplify = "returns (manifest)"
      returnsWord ResultPrivate = "returns (private)"
      returnsWord ResultMaySimplify = "returns"

      ppRes (slice, e) = pretty slice <+> "=" <+> pretty e

      onDim (dim, tile) = pretty dim <+> "/" <+> pretty tile

      onRegDim (dim, blk_tile, reg_tile) =
        pretty dim
          <+> "/"
          <+> parens (pretty blk_tile <+> "*" <+> pretty reg_tile)

-- | Index space of a 'SegOp'.
data SegSpace = SegSpace
  { -- | Flat physical index corresponding to the
    -- dimensions (at code generation used for a
    -- thread ID or similar).
    segFlat :: VName,
    -- | The logical dimensions: each index variable paired with the
    -- size it ranges over.
    unSegSpace :: [(VName, SubExp)]
  }
  deriving (Show, Eq, Ord)

-- | The sizes spanned by the indexes of the 'SegSpace', in the order
-- the dimensions are listed.
segSpaceDims :: SegSpace -> [SubExp]
segSpaceDims (SegSpace _ dims) = [size | (_, size) <- dims]

-- | A 'Scope' containing all the identifiers brought into scope by
-- this 'SegSpace': the flat index and every dimension index, each
-- bound as an 'Int64' index name.
scopeOfSegSpace :: SegSpace -> Scope rep
scopeOfSegSpace (SegSpace phys space) =
  M.fromList [(v, IndexName Int64) | v <- phys : map fst space]

-- | Check that every dimension size of the space is an @i64@.
checkSegSpace :: (TC.Checkable rep) => SegSpace -> TC.TypeM rep ()
checkSegSpace (SegSpace _ dims) =
  forM_ dims $ \(_, size) -> TC.require [Prim int64] size

-- | A 'SegOp' is semantically a perfectly nested stack of maps, on
-- top of some bottommost computation (scalar computation, reduction,
-- scan, or histogram).  The 'SegSpace' encodes the original map
-- structure.
--
-- All 'SegOp's are parameterised by the representation of their body,
-- as well as a /level/.  The /level/ is a representation-specific bit
-- of information.  For example, in GPU backends, it is used to
-- indicate whether the 'SegOp' is expected to run at the thread-level
-- or the block-level.
--
-- The type list is usually the type of the element returned by a
-- single thread. The result of the SegOp is then an array of that
-- type, with the shape of the 'SegSpace' prepended. One exception is
-- for 'WriteReturns', where the type annotation is the /full/ type of
-- the result.
data SegOp lvl rep
  = SegMap lvl SegSpace [Type] (KernelBody rep)
  | -- | The 'SegSpace' must always have at least two dimensions,
    -- implying that the result of a SegRed is always an array.
    SegRed lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | SegScan lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | SegHist lvl SegSpace [HistOp rep] [Type] (KernelBody rep)
  deriving (Eq, Ord, Show)

-- | The level of a 'SegOp'.  Every constructor carries it as its
-- first field.
segLevel :: SegOp lvl rep -> lvl
segLevel op =
  case op of
    SegMap lvl _ _ _ -> lvl
    SegRed lvl _ _ _ _ -> lvl
    SegScan lvl _ _ _ _ -> lvl
    SegHist lvl _ _ _ _ -> lvl

-- | The space of a 'SegOp'.  Every constructor carries it as its
-- second field.
segSpace :: SegOp lvl rep -> SegSpace
segSpace op =
  case op of
    SegMap _ space _ _ -> space
    SegRed _ space _ _ _ -> space
    SegScan _ space _ _ _ -> space
    SegHist _ space _ _ _ -> space

-- | The body of a 'SegOp'.  Every constructor carries it as its last
-- field.
segBody :: SegOp lvl rep -> KernelBody rep
segBody (SegMap _ _ _ kbody) = kbody
segBody (SegRed _ _ _ _ kbody) = kbody
segBody (SegScan _ _ _ _ kbody) = kbody
segBody (SegHist _ _ _ _ kbody) = kbody

-- | The type of a single 'SegOp' result, given its annotated type @t@
-- and the kernel result that produces it:
--
-- * 'WriteReturns': @t@ is already the full type and is returned as is.
--
-- * 'Returns': the dimensions of the 'SegSpace' are prepended to @t@.
--
-- * 'TileReturns' / 'RegTileReturns': the first component of each
--   dimension tuple is prepended to @t@.
segResultShape :: SegSpace -> Type -> KernelResult -> Type
segResultShape space t kres =
  case kres of
    WriteReturns {} ->
      t
    Returns {} ->
      foldr (\dim acc -> acc `arrayOfRow` dim) t (segSpaceDims space)
    TileReturns _ dims _ ->
      t `arrayOfShape` Shape [dim | (dim, _) <- dims]
    RegTileReturns _ dims_n_tiles _ ->
      t `arrayOfShape` Shape [dim | (dim, _, _) <- dims_n_tiles]

-- | The return type of a 'SegOp'.
--
-- For 'SegRed' and 'SegScan' the operator results come first, followed
-- by the remaining (map-like) results of the body.  A 'SegHist'
-- produces only the histogram arrays.
segOpType :: SegOp lvl rep -> [Type]
segOpType segop =
  case segop of
    SegMap _ space ts kbody ->
      zipWith (segResultShape space) ts (kernelBodyResult kbody)
    -- Reduction results drop the last dimension of the space.
    SegRed _ space reds ts kbody ->
      withBinOps space (init (segSpaceDims space)) reds ts kbody
    -- Scan results keep every dimension of the space.
    SegScan _ space scans ts kbody ->
      withBinOps space (segSpaceDims space) scans ts kbody
    SegHist _ space ops _ _ ->
      [ t `arrayOfShape` (Shape (init (segSpaceDims space)) <> histShape op <> histOpShape op)
      | op <- ops,
        t <- lambdaReturnType (histOp op)
      ]
  where
    -- Operator results (each shaped by @outer@ plus the operator's own
    -- shape), then the body's remaining results shaped by the space.
    withBinOps ::
      SegSpace -> [SubExp] -> [SegBinOp r] -> [Type] -> KernelBody r -> [Type]
    withBinOps space outer ops ts kbody =
      op_ts
        ++ zipWith
          (segResultShape space)
          (drop n ts)
          (drop n (kernelBodyResult kbody))
      where
        op_ts =
          [ t `arrayOfShape` (Shape outer <> segBinOpShape op)
          | op <- ops,
            t <- lambdaReturnType (segBinOpLambda op)
          ]
        n = length op_ts

instance TypedOp (SegOp lvl) where
  -- 'segOpType' already yields concrete 'Type's, so they are simply
  -- lifted to 'ExtType's.
  opType segop = pure (staticShapes (segOpType segop))

instance (ASTConstraints lvl) => AliasedOp (SegOp lvl) where
  -- No result of a 'SegOp' aliases anything: one empty alias set per
  -- result type.
  opAliases segop = [mempty | _ <- segOpType segop]

  -- Whatever the body consumes is consumed by the operation; a
  -- 'SegHist' additionally consumes its destination arrays.
  consumedInOp segop =
    case segop of
      SegMap _ _ _ kbody -> consumedInKernelBody kbody
      SegRed _ _ _ _ kbody -> consumedInKernelBody kbody
      SegScan _ _ _ _ kbody -> consumedInKernelBody kbody
      SegHist _ _ ops _ kbody ->
        namesFromList (concatMap histDest ops) <> consumedInKernelBody kbody

-- | Type check a 'SegOp', given a checker for its level.
--
-- 'SegMap', 'SegRed' and 'SegScan' are delegated to 'checkScanRed'
-- (a 'SegMap' is checked as having no operators).  'SegHist' is
-- checked here: sizes must be @i64@, each operator's return type must
-- match its neutral elements, destination arrays must have the proper
-- shape and are consumed, and the body must return one @i64@ bucket
-- index per histogram dimension of each operation followed by the
-- values to write.
typeCheckSegOp ::
  (TC.Checkable rep) =>
  (lvl -> TC.TypeM rep ()) ->
  SegOp lvl (Aliases rep) ->
  TC.TypeM rep ()
typeCheckSegOp checkLvl (SegMap lvl space ts kbody) = do
  checkLvl lvl
  checkScanRed space [] ts kbody
typeCheckSegOp checkLvl (SegRed lvl space reds ts body) = do
  checkLvl lvl
  checkScanRed
    space
    [(segBinOpLambda op, segBinOpNeutral op, segBinOpShape op) | op <- reds]
    ts
    body
typeCheckSegOp checkLvl (SegScan lvl space scans ts body) = do
  checkLvl lvl
  checkScanRed
    space
    [(segBinOpLambda op, segBinOpNeutral op, segBinOpShape op) | op <- scans]
    ts
    body
typeCheckSegOp checkLvl (SegHist lvl space ops ts kbody) = do
  checkLvl lvl
  checkSegSpace space
  mapM_ TC.checkType ts

  -- The histogram arrays keep all but the last dimension of the space
  -- ('segment_dims').  Report an empty space as a type error instead
  -- of letting 'init' crash the compiler.
  when (null $ segSpaceDims space) $
    TC.bad $
      TC.TypeError "SegHist must have at least one dimension in its space."

  TC.binding (scopeOfSegSpace space) $ do
    nes_ts <- forM ops $ \(HistOp dest_shape rf dests nes shape op) -> do
      mapM_ (TC.require [Prim int64]) dest_shape
      TC.require [Prim int64] rf
      nes' <- mapM TC.checkArg nes
      mapM_ (TC.require [Prim int64]) $ shapeDims shape

      -- Operator type must match the type of neutral elements.
      let stripVecDims = stripArray $ shapeRank shape
      TC.checkLambda op $ map (TC.noArgAliases . first stripVecDims) $ nes' ++ nes'
      let nes_t = map TC.argType nes'
      unless (nes_t == lambdaReturnType op) $
        TC.bad $
          TC.TypeError $
            "SegHist operator has return type "
              <> prettyTuple (lambdaReturnType op)
              <> " but neutral element has type "
              <> prettyTuple nes_t

      -- Arrays must have proper type; they are consumed by the operation.
      let dest_shape' = Shape segment_dims <> dest_shape <> shape
      forM_ (zip nes_t dests) $ \(t, dest) -> do
        TC.requireI [t `arrayOfShape` dest_shape'] dest
        TC.consume =<< TC.lookupAliases dest

      pure $ map (`arrayOfShape` shape) nes_t

    checkKernelBody ts kbody

    -- Return type of bucket function must be an index for each
    -- operation followed by the values to write.
    let bucket_ret_t =
          concatMap ((`replicate` Prim int64) . shapeRank . histShape) ops
            ++ concat nes_ts
    unless (bucket_ret_t == ts) $
      TC.bad $
        TC.TypeError $
          "SegHist body has return type "
            <> prettyTuple ts
            <> " but should have type "
            <> prettyTuple bucket_ret_t
  where
    -- Only forced after the emptiness check above has passed, so 'init'
    -- is never applied to an empty list.
    segment_dims = init $ segSpaceDims space

-- | Shared checker for 'SegMap', 'SegRed' and 'SegScan'.
--
-- Each operator is given as its lambda, its neutral elements and its
-- (vector) shape.  The operator is checked against two copies of the
-- neutral elements and must return exactly their types.  The leading
-- results of the body must match the operators' neutral element types
-- (extended by each operator's shape), in order.  The whole body is
-- then checked against @ts@.
checkScanRed ::
  (TC.Checkable rep) =>
  SegSpace ->
  [(Lambda (Aliases rep), [SubExp], Shape)] ->
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkScanRed space ops ts kbody = do
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    op_ts <- concat <$> mapM checkOp ops
    let found = take (length op_ts) ts
    unless (op_ts == found) . TC.bad . TC.TypeError $
      "Wrong return for body (does not match neutral elements; expected "
        <> prettyText op_ts
        <> "; found "
        <> prettyText found
        <> ")"

    checkKernelBody ts kbody
  where
    -- Check one operator and return the types it contributes to the
    -- body's results.
    checkOp (lam, nes, shape) = do
      mapM_ (TC.require [Prim int64]) (shapeDims shape)
      nes' <- mapM TC.checkArg nes

      -- Operator type must match the type of neutral elements.
      TC.checkLambda lam $ map TC.noArgAliases (nes' ++ nes')
      let nes_t = map TC.argType nes'
      unless (lambdaReturnType lam == nes_t) . TC.bad $
        TC.TypeError "wrong type for operator or neutral elements."

      pure [t `arrayOfShape` shape | t <- nes_t]

-- | Like 'Mapper', but just for 'SegOp's.  Each field is a monadic
-- action applied to one kind of component of a 'SegOp'; lambdas and
-- bodies may change representation from @frep@ to @trep@.
data SegOpMapper lvl frep trep m = SegOpMapper
  { -- | Applied to subexpressions.
    mapOnSegOpSubExp :: SubExp -> m SubExp,
    -- | Applied to lambdas, such as the operators.
    mapOnSegOpLambda :: Lambda frep -> m (Lambda trep),
    -- | Applied to the kernel body.
    mapOnSegOpBody :: KernelBody frep -> m (KernelBody trep),
    -- | Applied to names, such as those bound by the 'SegSpace'.
    mapOnSegOpVName :: VName -> m VName,
    -- | Applied to the level of the operation.
    mapOnSegOpLevel :: lvl -> m lvl
  }

-- | A mapper that simply returns the 'SegOp' verbatim.
--
-- Every field (subexpressions, lambdas, body, names, level) is 'pure'.
-- Positional construction makes adding a field to 'SegOpMapper' a
-- compile error here rather than a missing-field warning.
identitySegOpMapper :: (Monad m) => SegOpMapper lvl rep rep m
identitySegOpMapper = SegOpMapper pure pure pure pure pure

-- | Apply a 'SegOpMapper' to a 'SegSpace'.
--
-- The space's leading 'VName' is mapped first.  Then each
-- (name, size) pair is mapped in order: the name with
-- 'mapOnSegOpVName' and the size with 'mapOnSegOpSubExp'.
mapOnSegSpace ::
  (Monad f) => SegOpMapper lvl frep trep f -> SegSpace -> f SegSpace
mapOnSegSpace tv (SegSpace phys dims) = do
  phys' <- mapOnSegOpVName tv phys
  dims' <- forM dims $ \(v, d) -> do
    v' <- mapOnSegOpVName tv v
    d' <- mapOnSegOpSubExp tv d
    pure (v', d')
  pure $ SegSpace phys' dims'

-- | Apply a 'SegOpMapper' to a 'SegBinOp'.  The commutativity flag is
-- kept as-is.  The operator lambda, the neutral elements and the
-- dimensions of the operator shape are mapped, in that order.
mapSegBinOp ::
  (Monad m) =>
  SegOpMapper lvl frep trep m ->
  SegBinOp frep ->
  m (SegBinOp trep)
mapSegBinOp tv (SegBinOp comm red_op nes shape) = do
  red_op' <- mapOnSegOpLambda tv red_op
  nes' <- traverse (mapOnSegOpSubExp tv) nes
  dims' <- traverse (mapOnSegOpSubExp tv) (shapeDims shape)
  pure $ SegBinOp comm red_op' nes' (Shape dims')

-- | Apply a 'SegOpMapper' to the given 'SegOp'.
--
-- Components are visited in constructor-field order: the level, the
-- segment space, the operator descriptions (if any), the result types,
-- and finally the kernel body.
--
-- Result types are mapped with 'mapOnSegOpType' for every kind of
-- 'SegOp'.  This means accumulator names that occur in 'Acc' result
-- types always reach 'mapOnSegOpVName'.  Previously only 'SegMap' did
-- this; the other constructors used 'mapOnType', which never visits
-- that name.
mapSegOpM ::
  (Monad m) =>
  SegOpMapper lvl frep trep m ->
  SegOp lvl frep ->
  m (SegOp lvl trep)
mapSegOpM tv (SegMap lvl space ts body) =
  SegMap
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegRed lvl space reds ts body) =
  SegRed
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapSegBinOp tv) reds
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegScan lvl space scans ts body) =
  SegScan
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapSegBinOp tv) scans
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegHist lvl space ops ts body) =
  SegHist
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM onHistOp ops
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
  where
    -- Map every component of a histogram operator, in field order:
    -- destination shape, race factor, destination arrays, neutral
    -- elements, operator shape, and operator lambda.
    onHistOp (HistOp w rf arrs nes shape op) =
      HistOp
        <$> mapM (mapOnSegOpSubExp tv) w
        <*> mapOnSegOpSubExp tv rf
        <*> mapM (mapOnSegOpVName tv) arrs
        <*> mapM (mapOnSegOpSubExp tv) nes
        <*> (Shape <$> mapM (mapOnSegOpSubExp tv) (shapeDims shape))
        <*> mapOnSegOpLambda tv op

-- | Apply the 'SubExp' and 'VName' actions of a 'SegOpMapper' to a 'Type'.
--
-- * Primitive and memory types come back unchanged.
-- * Array types have each shape dimension mapped.
-- * Accumulator types have their accumulator name, their index space
--   and the shapes of their element types mapped, in that order.
mapOnSegOpType ::
  (Monad m) =>
  SegOpMapper lvl frep trep m ->
  Type ->
  m Type
mapOnSegOpType tv t =
  case t of
    Prim {} -> pure t
    Mem s -> pure $ Mem s
    Array et shape u ->
      Array et <$> onShape shape <*> pure u
    Acc acc ispace ts u -> do
      acc' <- mapOnSegOpVName tv acc
      ispace' <- onShape ispace
      ts' <- traverse (bitraverse onShape pure) ts
      pure $ Acc acc' ispace' ts' u
  where
    -- Map every dimension of a shape with the mapper's SubExp action.
    onShape = traverse (mapOnSegOpSubExp tv)

-- | Rephrase a 'SegBinOp' into another representation.  Only the operator
-- lambda is rephrased (via 'rephraseLambda').  The commutativity flag,
-- neutral elements and shape are carried over unchanged.
rephraseBinOp ::
  (Monad f) =>
  Rephraser f from rep ->
  SegBinOp from ->
  f (SegBinOp rep)
rephraseBinOp r (SegBinOp comm lam nes shape) = do
  lam' <- rephraseLambda r lam
  pure $ SegBinOp comm lam' nes shape

-- | Rephrase a 'KernelBody' into another representation.  The body
-- decoration is rephrased first, then every statement in order.  The
-- kernel results are representation-independent and are reused as-is.
rephraseKernelBody ::
  (Monad f) =>
  Rephraser f from rep ->
  KernelBody from ->
  f (KernelBody rep)
rephraseKernelBody r (KernelBody dec stms res) = do
  dec' <- rephraseBodyDec r dec
  stms' <- mapM (rephraseStm r) stms
  pure $ KernelBody dec' stms' res

instance RephraseOp (SegOp lvl) where
  rephraseInOp :: forall (m :: * -> *) from to.
Monad m =>
Rephraser m from to -> SegOp lvl from -> m (SegOp lvl to)
rephraseInOp Rephraser m from to
r (SegMap lvl
lvl SegSpace
space [Type]
ts KernelBody from
body) =
    lvl -> SegSpace -> [Type] -> KernelBody to -> SegOp lvl to
forall lvl rep.
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap lvl
lvl SegSpace
space [Type]
ts (KernelBody to -> SegOp lvl to)
-> m (KernelBody to) -> m (SegOp lvl to)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Rephraser m from to -> KernelBody from -> m (KernelBody to)
forall (f :: * -> *) from rep.
Monad f =>
Rephraser f from rep -> KernelBody from -> f (KernelBody rep)
rephraseKernelBody Rephraser m from to
r KernelBody from
body
  rephraseInOp Rephraser m from to
r (SegRed lvl
lvl SegSpace
space [SegBinOp from]
reds [Type]
ts KernelBody from
body) =
    lvl
-> SegSpace
-> [SegBinOp to]
-> [Type]
-> KernelBody to
-> SegOp lvl to
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [Type]
-> KernelBody rep
-> SegOp lvl rep
SegRed lvl
lvl SegSpace
space
      ([SegBinOp to] -> [Type] -> KernelBody to -> SegOp lvl to)
-> m [SegBinOp to] -> m ([Type] -> KernelBody to -> SegOp lvl to)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SegBinOp from -> m (SegBinOp to))
-> [SegBinOp from] -> m [SegBinOp to]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Rephraser m from to -> SegBinOp from -> m (SegBinOp to)
forall (f :: * -> *) from rep.
Monad f =>
Rephraser f from rep -> SegBinOp from -> f (SegBinOp rep)
rephraseBinOp Rephraser m from to
r) [SegBinOp from]
reds
      m ([Type] -> KernelBody to -> SegOp lvl to)
-> m [Type] -> m (KernelBody to -> SegOp lvl to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [Type] -> m [Type]
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [Type]
ts
      m (KernelBody to -> SegOp lvl to)
-> m (KernelBody to) -> m (SegOp lvl to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Rephraser m from to -> KernelBody from -> m (KernelBody to)
forall (f :: * -> *) from rep.
Monad f =>
Rephraser f from rep -> KernelBody from -> f (KernelBody rep)
rephraseKernelBody Rephraser m from to
r KernelBody from
body
  rephraseInOp Rephraser m from to
r (SegScan lvl
lvl SegSpace
space [SegBinOp from]
scans [Type]
ts KernelBody from
body) =
    lvl
-> SegSpace
-> [SegBinOp to]
-> [Type]
-> KernelBody to
-> SegOp lvl to
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [Type]
-> KernelBody rep
-> SegOp lvl rep
SegScan lvl
lvl SegSpace
space
      ([SegBinOp to] -> [Type] -> KernelBody to -> SegOp lvl to)
-> m [SegBinOp to] -> m ([Type] -> KernelBody to -> SegOp lvl to)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SegBinOp from -> m (SegBinOp to))
-> [SegBinOp from] -> m [SegBinOp to]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Rephraser m from to -> SegBinOp from -> m (SegBinOp to)
forall (f :: * -> *) from rep.
Monad f =>
Rephraser f from rep -> SegBinOp from -> f (SegBinOp rep)
rephraseBinOp Rephraser m from to
r) [SegBinOp from]
scans
      m ([Type] -> KernelBody to -> SegOp lvl to)
-> m [Type] -> m (KernelBody to -> SegOp lvl to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [Type] -> m [Type]
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [Type]
ts
      m (KernelBody to -> SegOp lvl to)
-> m (KernelBody to) -> m (SegOp lvl to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Rephraser m from to -> KernelBody from -> m (KernelBody to)
forall (f :: * -> *) from rep.
Monad f =>
Rephraser f from rep -> KernelBody from -> f (KernelBody rep)
rephraseKernelBody Rephraser m from to
r KernelBody from
body
  rephraseInOp Rephraser m from to
r (SegHist lvl
lvl SegSpace
space [HistOp from]
hists [Type]
ts KernelBody from
body) =
    lvl
-> SegSpace
-> [HistOp to]
-> [Type]
-> KernelBody to
-> SegOp lvl to
forall lvl rep.
lvl
-> SegSpace
-> [HistOp rep]
-> [Type]
-> KernelBody rep
-> SegOp lvl rep
SegHist lvl
lvl SegSpace
space
      ([HistOp to] -> [Type] -> KernelBody to -> SegOp lvl to)
-> m [HistOp to] -> m ([Type] -> KernelBody to -> SegOp lvl to)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (HistOp from -> m (HistOp to)) -> [HistOp from] -> m [HistOp to]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM HistOp from -> m (HistOp to)
onOp [HistOp from]
hists
      m ([Type] -> KernelBody to -> SegOp lvl to)
-> m [Type] -> m (KernelBody to -> SegOp lvl to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [Type] -> m [Type]
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [Type]
ts
      m (KernelBody to -> SegOp lvl to)
-> m (KernelBody to) -> m (SegOp lvl to)
forall a b. m (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Rephraser m from to -> KernelBody from -> m (KernelBody to)
forall (f :: * -> *) from rep.
Monad f =>
Rephraser f from rep -> KernelBody from -> f (KernelBody rep)
rephraseKernelBody Rephraser m from to
r KernelBody from
body
    where
      onOp :: HistOp from -> m (HistOp to)
onOp (HistOp ShapeBase SubExp
w SubExp
rf [VName]
arrs [SubExp]
nes ShapeBase SubExp
shape Lambda from
op) =
        ShapeBase SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> ShapeBase SubExp
-> Lambda to
-> HistOp to
forall rep.
ShapeBase SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> ShapeBase SubExp
-> Lambda rep
-> HistOp rep
HistOp ShapeBase SubExp
w SubExp
rf [VName]
arrs [SubExp]
nes ShapeBase SubExp
shape (Lambda to -> HistOp to) -> m (Lambda to) -> m (HistOp to)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Rephraser m from to -> Lambda from -> m (Lambda to)
forall (m :: * -> *) from to.
Monad m =>
Rephraser m from to -> Lambda from -> m (Lambda to)
rephraseLambda Rephraser m from to
r Lambda from
op

-- | A helper for defining 'TraverseOpStms'.  The given traverser is
-- applied to the statements inside every lambda and inside the kernel
-- body of the 'SegOp'.  The scope passed along is extended with the
-- names bound by the operation's 'SegSpace'.
traverseSegOpStms :: (Monad m) => OpStmsTraverser m (SegOp lvl rep) rep
traverseSegOpStms onStms segop =
  mapSegOpM
    ( identitySegOpMapper
        { mapOnSegOpLambda = traverseLambdaStms withSpace,
          mapOnSegOpBody = traverseBody
        }
    )
    segop
  where
    spaceScope = scopeOfSegSpace $ segSpace segop
    -- Inside lambdas, the scope supplied by 'traverseLambdaStms' is
    -- combined with (and placed after) the segment-space scope.
    withSpace inner = onStms $ spaceScope <> inner
    -- The kernel body sees only the segment-space scope.
    traverseBody (KernelBody dec stms res) = do
      stms' <- onStms spaceScope stms
      pure $ KernelBody dec stms' res

-- | Substitution is applied uniformly to every component visited by
-- 'mapSegOpM': subexpressions, lambdas, the kernel body, names and the
-- level.
instance
  (ASTRep rep, Substitute lvl) =>
  Substitute (SegOp lvl rep)
  where
  substituteNames subst = runIdentity . mapSegOpM (SegOpMapper sub sub sub sub sub)
    where
      -- The signature keeps this helper polymorphic so it can be used
      -- at each of the five component types.
      sub :: (Substitute a) => a -> Identity a
      sub = Identity . substituteNames subst
          }

-- | The names bound by the segment space are passed to 'renameBound',
-- which wraps a component-wise renaming of the whole operation.
instance (ASTRep rep, ASTConstraints lvl) => Rename (SegOp lvl rep) where
  rename op =
    renameBound boundBySpace $
      mapSegOpM (SegOpMapper rename rename rename rename rename) op
    where
      boundBySpace = M.keys $ scopeOfSegSpace $ segSpace op

-- | Free variables are accumulated from every component visited by
-- 'mapSegOpM'.  The names bound by the segment space are then removed
-- via 'fvBind'.
instance (ASTRep rep, FreeIn lvl) => FreeIn (SegOp lvl rep) where
  freeIn' e = fvBind spaceNames collected
    where
      spaceNames = namesFromList . M.keys . scopeOfSegSpace $ segSpace e
      collected =
        execState (mapSegOpM (SegOpMapper note note note note note) e) mempty
      -- Append the free variables of a component to the state and hand
      -- the component back unchanged.
      note :: (FreeIn a) => a -> State FV a
      note x = do
        modify (<> freeIn' x)
        pure x
          }

-- | Metrics for each operator lambda are recorded before those of the
-- kernel body, nested under a label naming the kind of 'SegOp'.
-- Reduction and scan operators get an extra "SegBinOp" level; histogram
-- operators do not.
instance (OpMetrics (Op rep)) => OpMetrics (SegOp lvl rep) where
  opMetrics op = case op of
    SegMap _ _ _ body ->
      inside "SegMap" $ kernelBodyMetrics body
    SegRed _ _ reds _ body ->
      inside "SegRed" $ do
        forM_ reds $ inside "SegBinOp" . lambdaMetrics . segBinOpLambda
        kernelBodyMetrics body
    SegScan _ _ scans _ body ->
      inside "SegScan" $ do
        forM_ scans $ inside "SegBinOp" . lambdaMetrics . segBinOpLambda
        kernelBodyMetrics body
    SegHist _ _ hists _ body ->
      inside "SegHist" $ do
        forM_ hists $ lambdaMetrics . histOp
        kernelBodyMetrics body

-- | Each dimension is rendered as @i < d@ using 'apply'.  The result is
-- followed by the physical index in parentheses, as @(~phys)@.
instance Pretty SegSpace where
  pretty (SegSpace phys dims) =
    apply [pretty i <+> "<" <+> pretty d | (i, d) <- dims]
      <+> parens ("~" <> pretty phys)

-- | Renders the parts in this order: the neutral elements in braces, the
-- shape, an optional "commutative " marker, and the operator lambda.
-- The parts are separated by commas.
instance (PrettyRep rep) => Pretty (SegBinOp rep) where
  pretty (SegBinOp comm lam nes shape) =
    neutral
      <> PP.comma
        </> pretty shape
      <> PP.comma
        </> commutativity
      <> pretty lam
    where
      neutral = PP.braces $ PP.commasep $ map pretty nes
      commutativity = case comm of
        Noncommutative -> mempty
        Commutative -> "commutative "

-- | Every segmented operation is printed as follows:
--
-- * its keyword, immediately followed by the level;
-- * the segment space;
-- * for reductions, scans and histograms, the parenthesised list of
--   operators;
-- * a colon and the result types;
-- * the kernel body in braces.
instance (PrettyRep rep, PP.Pretty lvl) => PP.Pretty (SegOp lvl rep) where
  pretty segop = case segop of
    SegMap lvl space ts body ->
      "segmap"
        <> pretty lvl
          </> PP.align (pretty space)
          <+> PP.colon
          <+> resultTypes ts
          <+> kernelBlock body
    SegRed lvl space reds ts body ->
      "segred"
        <> pretty lvl
          </> PP.align (pretty space)
          </> operators (map pretty reds)
          </> PP.colon
          <+> resultTypes ts
          <+> kernelBlock body
    SegScan lvl space scans ts body ->
      "segscan"
        <> pretty lvl
          </> PP.align (pretty space)
          </> operators (map pretty scans)
          </> PP.colon
          <+> resultTypes ts
          <+> kernelBlock body
    SegHist lvl space ops ts body ->
      "seghist"
        <> pretty lvl
          </> PP.align (pretty space)
          </> operators (map ppHistOp ops)
          </> PP.colon
          <+> resultTypes ts
          <+> kernelBlock body
    where
      -- The tuple of result types after the colon.
      resultTypes ts = ppTuple' (map pretty ts)
      -- The kernel body, wrapped in a nested brace block.
      kernelBlock body = PP.nestedBlock "{" "}" (pretty body)
      -- Operator documents, one per line, separated by commas and
      -- surrounded by parentheses.
      operators docs = PP.parens (mconcat (intersperse (PP.comma <> PP.line) docs))
      -- A histogram operator: width, race factor, destinations, neutral
      -- elements, shape and operator lambda.
      ppHistOp (HistOp w rf dests nes shape op) =
        pretty w
          <> PP.comma
            <+> pretty rf
          <> PP.comma
            </> PP.braces (PP.commasep (map pretty dests))
          <> PP.comma
            </> PP.braces (PP.commasep (map pretty nes))
          <> PP.comma
            </> pretty shape
          <> PP.comma
            </> pretty op

-- | Alias analysis is run on the operator lambdas
-- ('Alias.analyseLambda') and on the kernel body
-- ('aliasAnalyseKernelBody').  Subexpressions, names and the level are
-- passed through unchanged.
instance CanBeAliased (SegOp lvl) where
  addOpAliases aliases op =
    runIdentity $
      mapSegOpM
        ( SegOpMapper
            { mapOnSegOpSubExp = Identity,
              mapOnSegOpLambda = Identity . Alias.analyseLambda aliases,
              mapOnSegOpBody = Identity . aliasAnalyseKernelBody aliases,
              mapOnSegOpVName = Identity,
              mapOnSegOpLevel = Identity
            }
        )
        op

-- | Annotate the statements of a kernel body with simplifier
-- information (via 'informStms') and rebuild the body with
-- 'mkWiseKernelBody'.  The body decoration is passed through to
-- 'mkWiseKernelBody' and the kernel results are carried over unchanged.
informKernelBody :: (Informing rep) => KernelBody rep -> KernelBody (Wise rep)
informKernelBody (KernelBody body_dec body_stms body_res) =
  mkWiseKernelBody body_dec informed_stms body_res
  where
    informed_stms = informStms body_stms

instance CanBeWise (SegOp lvl) where
  -- Convert a SegOp to the 'Wise' representation by traversing it with
  -- a mapper that leaves subexpressions, names and the level untouched,
  -- and informs every lambda and kernel body it encounters.
  addOpWisdom op = runIdentity $ mapSegOpM wisdomMapper op
    where
      wisdomMapper =
        SegOpMapper
          pure -- subexpressions: unchanged
          (Identity . informLambda) -- lambdas: informed
          (Identity . informKernelBody) -- kernel bodies: informed
          pure -- names: unchanged
          pure -- level: unchanged

instance (ASTRep rep) => ST.IndexOp (SegOp lvl rep) where
  -- Only 'SegMap' is handled.  Result number @k@ of the kernel body must
  -- be a 'Returns' marked 'ResultMaySimplify' whose value is a variable.
  -- The first indexes in @is@ are bound to the segment's thread indices;
  -- the kernel body's statements are then folded over to build a table
  -- mapping variables to symbolic 'ST.Indexed' values, and the result
  -- variable is looked up in that table.
  indexOp vtable k (SegMap _ space _ kbody) is = do
    Returns ResultMaySimplify _ res_se <- maybeNth k $ kernelBodyResult kbody
    -- At least one index is needed per segment dimension.
    guard $ num_dims <= length is
    let gtid_table = M.fromList $ zip gtids $ map (ST.Indexed mempty . untyped) is
        body_table = foldl' extendTable gtid_table $ kernelBodyStms kbody
    case res_se of
      Var res_v -> M.lookup res_v body_table
      _ -> Nothing
    where
      gtids = map fst $ unSegSpace space
      num_dims = length gtids

      -- Indexes in excess of those consumed by the segment dimensions.
      excess_is = drop num_dims is

      -- Record a symbolic value for a single-name statement, if one can
      -- be derived: either the whole expression is a primitive expression
      -- over already-known variables, or it indexes an array that is
      -- available outside the kernel (the excess indexes are then applied
      -- to the remaining dimensions of the slice).  Other statements leave
      -- the table unchanged.
      extendTable table stm
        | [v] <- patNames $ stmPat stm,
          Just (pe, cs) <- runWriterT $ primExpFromExp (lookupPrimExp table) $ stmExp stm =
            M.insert v (ST.Indexed (stmCerts stm <> cs) pe) table
        | [v] <- patNames $ stmPat stm,
          BasicOp (Index arr slice) <- stmExp stm,
          length (sliceDims slice) == length excess_is,
          arr `ST.available` vtable,
          Just (slice', cs) <- primExpSlice table slice =
            let idx_is = fixSlice (fmap isInt64 slice') excess_is
             in M.insert v (ST.IndexedArray (stmCerts stm <> cs) arr idx_is) table
        | otherwise = table

      -- Convert every component of a slice, collecting certificates.
      primExpSlice table =
        runWriterT . traverse (primExpFromSubExpM $ lookupPrimExp table)

      -- A variable becomes a primitive expression if it is in the table
      -- (its certificates are recorded) or if the symbol table knows it to
      -- be of primitive type (it becomes a leaf); otherwise we fail.
      lookupPrimExp table v
        | Just (ST.Indexed cs e) <- M.lookup v table = e <$ tell cs
        | Just (Prim pt) <- ST.lookupType v vtable = pure $ LeafExp v pt
        | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing

instance (ASTConstraints lvl) => IsOp (SegOp lvl) where
  -- SegOps are never considered cheap.
  cheapOp = const False

  -- SegOps are always considered safe.
  safeOp = const True

  -- Every result of the operation depends on all names free in the
  -- whole operation; one entry per result type.
  opDependencies op =
    let deps = freeIn op
     in map (const deps) (segOpType op)

--- Simplification

instance Engine.Simplifiable SegSpace where
  -- Simplify the size of each segment dimension, in order.  The physical
  -- index name and the per-dimension index names are left unchanged.
  simplify (SegSpace phys dims) =
    SegSpace phys <$> forM dims simplifyDim
    where
      simplifyDim (gtid, d) = (,) gtid <$> Engine.simplify d

instance Engine.Simplifiable KernelResult where
  -- Simplify every component of a kernel result: the certificates first,
  -- then the remaining fields in constructor order.  The constructor and
  -- any 'ResultManifest' are preserved.
  simplify kres =
    case kres of
      Returns manifest cs what -> do
        cs' <- Engine.simplify cs
        what' <- Engine.simplify what
        pure $ Returns manifest cs' what'
      WriteReturns cs arr updates -> do
        cs' <- Engine.simplify cs
        arr' <- Engine.simplify arr
        updates' <- Engine.simplify updates
        pure $ WriteReturns cs' arr' updates'
      TileReturns cs dims what -> do
        cs' <- Engine.simplify cs
        dims' <- Engine.simplify dims
        what' <- Engine.simplify what
        pure $ TileReturns cs' dims' what'
      RegTileReturns cs dims_n_tiles what -> do
        cs' <- Engine.simplify cs
        dims_n_tiles' <- Engine.simplify dims_n_tiles
        what' <- Engine.simplify what
        pure $ RegTileReturns cs' dims_n_tiles' what'

-- | Construct a 'KernelBody' in the 'Wise' representation.  The body
-- decoration is computed by 'mkWiseBody', applied to an ordinary body
-- with the same statements whose result is the 'kernelResultSubExp' of
-- each kernel result.  The statements and kernel results themselves are
-- kept as given.
mkWiseKernelBody ::
  (Informing rep) =>
  BodyDec rep ->
  Stms (Wise rep) ->
  [KernelResult] ->
  KernelBody (Wise rep)
mkWiseKernelBody dec stms res =
  KernelBody wise_dec stms res
  where
    -- Only the decoration of this auxiliary body is used.
    Body wise_dec _ _ =
      mkWiseBody dec stms . subExpsRes $ map kernelResultSubExp res

-- | Construct a 'KernelBody' in a 'MonadBuilder'.  The body decoration
-- is obtained from 'mkBodyM' on an ordinary body with the same
-- statements whose result is the 'kernelResultSubExp' of each kernel
-- result.  The statements and kernel results are kept as given.
mkKernelBodyM ::
  (MonadBuilder m) =>
  Stms (Rep m) ->
  [KernelResult] ->
  m (KernelBody (Rep m))
mkKernelBodyM stms kres =
  mkBodyM stms (subExpsRes $ map kernelResultSubExp kres)
    >>= \(Body dec' _ _) -> pure $ KernelBody dec' stms kres

simplifyKernelBody ::
  (Engine.SimplifiableRep rep, BodyDec rep ~ ()) =>
  SegSpace ->
  KernelBody (Wise rep) ->
  Engine.SimpleM rep (KernelBody (Wise rep), Stms (Wise rep))
simplifyKernelBody :: forall rep.
(SimplifiableRep rep, BodyDec rep ~ ()) =>
SegSpace
-> KernelBody (Wise rep)
-> SimpleM rep (KernelBody (Wise rep), Stms (Wise rep))
simplifyKernelBody SegSpace
space (KernelBody BodyDec (Wise rep)
_ Stms (Wise rep)
stms [KernelResult]
res) = do
  BlockPred (Wise rep)
par_blocker <- (Env rep -> BlockPred (Wise rep))
-> SimpleM rep (BlockPred (Wise rep))
forall {k} (rep :: k) a. (Env rep -> a) -> SimpleM rep a
Engine.asksEngineEnv ((Env rep -> BlockPred (Wise rep))
 -> SimpleM rep (BlockPred (Wise rep)))
-> (Env rep -> BlockPred (Wise rep))
-> SimpleM rep (BlockPred (Wise rep))
forall a b. (a -> b) -> a -> b
$ HoistBlockers rep -> BlockPred (Wise rep)
forall {k} (rep :: k). HoistBlockers rep -> BlockPred (Wise rep)
Engine.blockHoistPar (HoistBlockers rep -> BlockPred (Wise rep))
-> (Env rep -> HoistBlockers rep)
-> Env rep
-> BlockPred (Wise rep)
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Env rep -> HoistBlockers rep
forall {k} (rep :: k). Env rep -> HoistBlockers rep
Engine.envHoistBlockers

  let blocker :: BlockPred (Wise rep)
blocker =
        Names -> BlockPred (Wise rep)
forall rep. ASTRep rep => Names -> BlockPred rep
Engine.hasFree Names
bound_here
          BlockPred (Wise rep)
-> BlockPred (Wise rep) -> BlockPred (Wise rep)
forall rep. BlockPred rep -> BlockPred rep -> BlockPred rep
`Engine.orIf` BlockPred (Wise rep)
forall rep. BlockPred rep
Engine.isOp
          BlockPred (Wise rep)
-> BlockPred (Wise rep) -> BlockPred (Wise rep)
forall rep. BlockPred rep -> BlockPred rep -> BlockPred rep
`Engine.orIf` BlockPred (Wise rep)
par_blocker
          BlockPred (Wise rep)
-> BlockPred (Wise rep) -> BlockPred (Wise rep)
forall rep. BlockPred rep -> BlockPred rep -> BlockPred rep
`Engine.orIf` BlockPred (Wise rep)
forall rep. BlockPred rep
Engine.isConsumed
          BlockPred (Wise rep)
-> BlockPred (Wise rep) -> BlockPred (Wise rep)
forall rep. BlockPred rep -> BlockPred rep -> BlockPred rep
`Engine.orIf` BlockPred (Wise rep)
forall rep. Aliased rep => BlockPred rep
Engine.isConsuming
          BlockPred (Wise rep)
-> BlockPred (Wise rep) -> BlockPred (Wise rep)
forall rep. BlockPred rep -> BlockPred rep -> BlockPred rep
`Engine.orIf` BlockPred (Wise rep)
forall rep. SimplifiableRep rep => BlockPred (Wise rep)
Engine.isDeviceMigrated

  -- Ensure we do not try to use anything that is consumed in the result.
  ([KernelResult]
body_res, Stms (Wise rep)
body_stms, Stms (Wise rep)
hoisted) <-
    (SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall {k} (rep :: k) a.
(SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep a -> SimpleM rep a
Engine.localVtable ((SymbolTable (Wise rep) -> [VName] -> SymbolTable (Wise rep))
-> [VName] -> SymbolTable (Wise rep) -> SymbolTable (Wise rep)
forall a b c. (a -> b -> c) -> b -> a -> c
flip ((SymbolTable (Wise rep) -> VName -> SymbolTable (Wise rep))
-> SymbolTable (Wise rep) -> [VName] -> SymbolTable (Wise rep)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' ((VName -> SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SymbolTable (Wise rep) -> VName -> SymbolTable (Wise rep)
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> SymbolTable (Wise rep) -> SymbolTable (Wise rep)
forall rep. VName -> SymbolTable rep -> SymbolTable rep
ST.consume)) ((KernelResult -> [VName]) -> [KernelResult] -> [VName]
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap KernelResult -> [VName]
consumedInResult [KernelResult]
res))
      (SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
 -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> (SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
    -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. (SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall {k} (rep :: k) a.
(SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep a -> SimpleM rep a
Engine.localVtable (SymbolTable (Wise rep)
-> SymbolTable (Wise rep) -> SymbolTable (Wise rep)
forall a. Semigroup a => a -> a -> a
<> SymbolTable (Wise rep)
scope_vtable)
      (SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
 -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> (SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
    -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. (SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall {k} (rep :: k) a.
(SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep a -> SimpleM rep a
Engine.localVtable (\SymbolTable (Wise rep)
vtable -> SymbolTable (Wise rep)
vtable {ST.simplifyMemory = True})
      (SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
 -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> (SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
    -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall {k} (rep :: k) a. SimpleM rep a -> SimpleM rep a
Engine.enterLoop
      (SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
 -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall a b. (a -> b) -> a -> b
$ BlockPred (Wise rep)
-> Stms (Wise rep)
-> SimpleM rep ([KernelResult], UsageTable)
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall rep a.
SimplifiableRep rep =>
BlockPred (Wise rep)
-> Stms (Wise rep)
-> SimpleM rep (a, UsageTable)
-> SimpleM rep (a, Stms (Wise rep), Stms (Wise rep))
Engine.blockIf BlockPred (Wise rep)
blocker Stms (Wise rep)
stms
      (SimpleM rep ([KernelResult], UsageTable)
 -> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep)))
-> SimpleM rep ([KernelResult], UsageTable)
-> SimpleM rep ([KernelResult], Stms (Wise rep), Stms (Wise rep))
forall a b. (a -> b) -> a -> b
$ do
        [KernelResult]
res' <-
          (SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep [KernelResult] -> SimpleM rep [KernelResult]
forall {k} (rep :: k) a.
(SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep a -> SimpleM rep a
Engine.localVtable (Names -> SymbolTable (Wise rep) -> SymbolTable (Wise rep)
forall rep. Names -> SymbolTable rep -> SymbolTable rep
ST.hideCertified (Names -> SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> Names -> SymbolTable (Wise rep) -> SymbolTable (Wise rep)
forall a b. (a -> b) -> a -> b
$ [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ Map VName (NameInfo (Wise rep)) -> [VName]
forall k a. Map k a -> [k]
M.keys (Map VName (NameInfo (Wise rep)) -> [VName])
-> Map VName (NameInfo (Wise rep)) -> [VName]
forall a b. (a -> b) -> a -> b
$ Stms (Wise rep) -> Map VName (NameInfo (Wise rep))
forall rep a. Scoped rep a => a -> Scope rep
scopeOf Stms (Wise rep)
stms) (SimpleM rep [KernelResult] -> SimpleM rep [KernelResult])
-> SimpleM rep [KernelResult] -> SimpleM rep [KernelResult]
forall a b. (a -> b) -> a -> b
$
            (KernelResult -> SimpleM rep KernelResult)
-> [KernelResult] -> SimpleM rep [KernelResult]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM KernelResult -> SimpleM rep KernelResult
forall rep.
SimplifiableRep rep =>
KernelResult -> SimpleM rep KernelResult
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
Engine.simplify [KernelResult]
res
        ([KernelResult], UsageTable)
-> SimpleM rep ([KernelResult], UsageTable)
forall a. a -> SimpleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([KernelResult]
res', Names -> UsageTable
UT.usages (Names -> UsageTable) -> Names -> UsageTable
forall a b. (a -> b) -> a -> b
$ [KernelResult] -> Names
forall a. FreeIn a => a -> Names
freeIn [KernelResult]
res')

  (KernelBody (Wise rep), Stms (Wise rep))
-> SimpleM rep (KernelBody (Wise rep), Stms (Wise rep))
forall a. a -> SimpleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (BodyDec rep
-> Stms (Wise rep) -> [KernelResult] -> KernelBody (Wise rep)
forall rep.
Informing rep =>
BodyDec rep
-> Stms (Wise rep) -> [KernelResult] -> KernelBody (Wise rep)
mkWiseKernelBody () Stms (Wise rep)
body_stms [KernelResult]
body_res, Stms (Wise rep)
hoisted)
  where
    scope_vtable :: SymbolTable (Wise rep)
scope_vtable = SegSpace -> SymbolTable (Wise rep)
forall rep. ASTRep rep => SegSpace -> SymbolTable rep
segSpaceSymbolTable SegSpace
space
    bound_here :: Names
bound_here = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ Map VName (NameInfo Any) -> [VName]
forall k a. Map k a -> [k]
M.keys (Map VName (NameInfo Any) -> [VName])
-> Map VName (NameInfo Any) -> [VName]
forall a b. (a -> b) -> a -> b
$ SegSpace -> Map VName (NameInfo Any)
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space

    consumedInResult :: KernelResult -> [VName]
consumedInResult (WriteReturns Certs
_ VName
arr [(Slice SubExp, SubExp)]
_) =
      [VName
arr]
    consumedInResult KernelResult
_ =
      []

-- | Simplify a lambda via 'Engine.simplifyLambda', forwarding the
-- @bound@ names to it, and run the whole thing under
-- 'Engine.blockMigrated'.  Returns the simplified lambda together with
-- the statements hoisted out of it.
simplifyLambda ::
  (Engine.SimplifiableRep rep) =>
  Names ->
  Lambda (Wise rep) ->
  Engine.SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambda bound lam =
  Engine.blockMigrated $ Engine.simplifyLambda bound lam

-- | Build a symbol table for the names bound by a 'SegSpace'.  The flat
-- index name is entered as an 'Int64' index, and each @(gtid, dim)@
-- pair is then added, in order, with 'ST.insertLoopVar' using @dim@ as
-- the size.
segSpaceSymbolTable :: (ASTRep rep) => SegSpace -> ST.SymbolTable rep
segSpaceSymbolTable (SegSpace flat gtids_and_dims) =
  foldl' addGtid flat_vtable gtids_and_dims
  where
    flat_vtable = ST.fromScope (M.singleton flat (IndexName Int64))
    addGtid vtable (gtid, dim) = ST.insertLoopVar gtid Int64 dim vtable

-- | Simplify a single reduction/scan operator.  The operator lambda is
-- simplified with memory simplification enabled in the symbol table,
-- treating @phys_id@ as its only extra bound name.  Then the operator
-- shape and the neutral elements are simplified.  Returns the new
-- operator and the statements hoisted out of the lambda.
simplifySegBinOp ::
  (Engine.SimplifiableRep rep) =>
  VName ->
  SegBinOp (Wise rep) ->
  Engine.SimpleM rep (SegBinOp (Wise rep), Stms (Wise rep))
simplifySegBinOp phys_id (SegBinOp comm lam nes shape) = do
  let withMemSimplification vt = vt {ST.simplifyMemory = True}
  (new_lam, lam_hoisted) <-
    Engine.localVtable withMemSimplification $
      simplifyLambda (oneName phys_id) lam
  new_shape <- Engine.simplify shape
  new_nes <- traverse Engine.simplify nes
  pure (SegBinOp comm new_lam new_nes new_shape, lam_hoisted)

-- | Simplify the given 'SegOp'.
--
-- For every constructor, the level, space and result types are
-- simplified first.  Any operators (reductions, scans or histogram
-- operators) come next, and the kernel body comes last.  The second
-- component of the result is the statements hoisted out of the
-- operators, followed by those hoisted out of the body.
simplifySegOp ::
  ( Engine.SimplifiableRep rep,
    BodyDec rep ~ (),
    Engine.Simplifiable lvl
  ) =>
  SegOp lvl (Wise rep) ->
  Engine.SimpleM rep (SegOp lvl (Wise rep), Stms (Wise rep))
simplifySegOp (SegMap lvl space ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure (SegMap lvl' space' ts' kbody', body_hoisted)
simplifySegOp (SegRed lvl space reds ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (reds', reds_hoisted) <- simplifySegBinOpsInSpace space reds
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure
    ( SegRed lvl' space' reds' ts' kbody',
      reds_hoisted <> body_hoisted
    )
simplifySegOp (SegScan lvl space scans ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (scans', scans_hoisted) <- simplifySegBinOpsInSpace space scans
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure
    ( SegScan lvl' space' scans' ts' kbody',
      scans_hoisted <> body_hoisted
    )
simplifySegOp (SegHist lvl space ops ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)

  -- The destination arrays of all histogram operators ('histDest') are
  -- marked with 'ST.consume' while the operators and the body are
  -- simplified.
  Engine.localVtable (flip (foldr ST.consume) $ concatMap histDest ops) $ do
    (ops', ops_hoisted) <- mapAndUnzipM simplifyHistOp ops
    (kbody', body_hoisted) <- simplifyKernelBody space kbody
    pure
      ( SegHist lvl' space' ops' ts' kbody',
        mconcat ops_hoisted <> body_hoisted
      )
  where
    space_vtable = ST.fromScope $ scopeOfSegSpace space

    -- Simplify every component of one histogram operator.  The lambda
    -- is simplified with the segment-space names in scope and memory
    -- simplification enabled, and the flat index is its extra bound
    -- name.
    simplifyHistOp (HistOp w rf arrs nes dims lam) = do
      w' <- Engine.simplify w
      rf' <- Engine.simplify rf
      arrs' <- Engine.simplify arrs
      nes' <- Engine.simplify nes
      dims' <- Engine.simplify dims
      (lam', op_hoisted) <-
        Engine.localVtable (<> space_vtable) $
          Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True}) $
            simplifyLambda (oneName (segFlat space)) lam
      pure (HistOp w' rf' arrs' nes' dims' lam', op_hoisted)

-- | Simplify the operators of a 'SegRed' or 'SegScan'.  Each operator is
-- simplified with 'simplifySegBinOp', using the space's flat index.  The
-- names bound by the segment space are added to the symbol table while
-- this happens.  The hoisted statements of all operators are
-- concatenated in operator order.
simplifySegBinOpsInSpace ::
  (Engine.SimplifiableRep rep) =>
  SegSpace ->
  [SegBinOp (Wise rep)] ->
  Engine.SimpleM rep ([SegBinOp (Wise rep)], Stms (Wise rep))
simplifySegBinOpsInSpace space ops = do
  (ops', ops_hoisted) <-
    Engine.localVtable (<> ST.fromScope (scopeOfSegSpace space)) $
      mapAndUnzipM (simplifySegBinOp (segFlat space)) ops
  pure (ops', mconcat ops_hoisted)

-- | Does this rep contain 'SegOp's in its t'Op's?  A rep must be an
-- instance of this class for the simplification rules to work.
class HasSegOp rep where
  -- | The level type carried by this rep's 'SegOp's (the @lvl@
  -- parameter of 'SegOp').
  type SegOpLevel rep

  -- | Try to view an operation as a 'SegOp'.  The simplification rules
  -- return 'Skip' when this yields 'Nothing'.
  asSegOp :: Op rep -> Maybe (SegOp (SegOpLevel rep) rep)

  -- | Embed a 'SegOp' back into an operation of this rep.
  segOp :: SegOp (SegOpLevel rep) rep -> Op rep

-- | Simplification rules for simplifying 'SegOp's: one top-down rule
-- and one bottom-up rule, both applied to operations that 'asSegOp'
-- recognises.
segOpRules ::
  (HasSegOp rep, BuilderOps rep, Buildable rep, Aliased rep) =>
  RuleBook rep
segOpRules = ruleBook topDownRules bottomUpRules
  where
    topDownRules = [RuleOp segOpRuleTopDown]
    bottomUpRules = [RuleOp segOpRuleBottomUp]

-- | Top-down rule: run 'topDownSegOp' on operations that are 'SegOp's,
-- and 'Skip' everything else.
segOpRuleTopDown ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  TopDownRuleOp rep
segOpRuleTopDown vtable pat dec =
  maybe Skip (topDownSegOp vtable pat dec) . asSegOp

-- | Bottom-up rule: run 'bottomUpSegOp' on operations that are
-- 'SegOp's, and 'Skip' everything else.
segOpRuleBottomUp ::
  (HasSegOp rep, BuilderOps rep, Aliased rep) =>
  BottomUpRuleOp rep
segOpRuleBottomUp vtable pat dec op =
  case asSegOp op of
    Just seg_op -> bottomUpSegOp vtable pat dec seg_op
    Nothing -> Skip

topDownSegOp ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  ST.SymbolTable rep ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
-- If a SegOp produces something invariant to the SegOp, turn it
-- into a replicate.
topDownSegOp :: forall rep.
(HasSegOp rep, BuilderOps rep, Buildable rep) =>
SymbolTable rep
-> Pat (LetDec rep)
-> StmAux (ExpDec rep)
-> SegOp (SegOpLevel rep) rep
-> Rule rep
topDownSegOp SymbolTable rep
vtable (Pat [PatElem (LetDec rep)]
kpes) StmAux (ExpDec rep)
dec (SegMap SegOpLevel rep
lvl SegSpace
space [Type]
ts (KernelBody BodyDec rep
_ Stms rep
kstms [KernelResult]
kres)) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
  ([Type]
ts', [PatElem (LetDec rep)]
kpes', [KernelResult]
kres') <-
    [(Type, PatElem (LetDec rep), KernelResult)]
-> ([Type], [PatElem (LetDec rep)], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Type, PatElem (LetDec rep), KernelResult)]
 -> ([Type], [PatElem (LetDec rep)], [KernelResult]))
-> RuleM rep [(Type, PatElem (LetDec rep), KernelResult)]
-> RuleM rep ([Type], [PatElem (LetDec rep)], [KernelResult])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Type, PatElem (LetDec rep), KernelResult) -> RuleM rep Bool)
-> [(Type, PatElem (LetDec rep), KernelResult)]
-> RuleM rep [(Type, PatElem (LetDec rep), KernelResult)]
forall (m :: * -> *) a.
Applicative m =>
(a -> m Bool) -> [a] -> m [a]
filterM (Type, PatElem (LetDec rep), KernelResult) -> RuleM rep Bool
checkForInvarianceResult ([Type]
-> [PatElem (LetDec rep)]
-> [KernelResult]
-> [(Type, PatElem (LetDec rep), KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Type]
ts [PatElem (LetDec rep)]
kpes [KernelResult]
kres)

  -- Check if we did anything at all.
  Bool -> RuleM rep () -> RuleM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([KernelResult]
kres [KernelResult] -> [KernelResult] -> Bool
forall a. Eq a => a -> a -> Bool
== [KernelResult]
kres') RuleM rep ()
forall rep a. RuleM rep a
cannotSimplify

  KernelBody rep
kbody <- Stms (Rep (RuleM rep))
-> [KernelResult] -> RuleM rep (KernelBody (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
Stms (Rep m) -> [KernelResult] -> m (KernelBody (Rep m))
mkKernelBodyM Stms rep
Stms (Rep (RuleM rep))
kstms [KernelResult]
kres'
  Stm (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (RuleM rep)) -> RuleM rep ())
-> Stm (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM rep)))
-> StmAux (ExpDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep))
-> Stm (Rep (RuleM rep))
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
kpes') StmAux (ExpDec rep)
StmAux (ExpDec (Rep (RuleM rep)))
dec (Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall rep. Op rep -> Exp rep
Op (Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)))
-> Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel rep) rep -> Op rep
forall rep. HasSegOp rep => SegOp (SegOpLevel rep) rep -> Op rep
segOp (SegOp (SegOpLevel rep) rep -> Op rep)
-> SegOp (SegOpLevel rep) rep -> Op rep
forall a b. (a -> b) -> a -> b
$ SegOpLevel rep
-> SegSpace
-> [Type]
-> KernelBody rep
-> SegOp (SegOpLevel rep) rep
forall lvl rep.
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap SegOpLevel rep
lvl SegSpace
space [Type]
ts' KernelBody rep
kbody
  where
    isInvariant :: SubExp -> Bool
isInvariant Constant {} = Bool
True
    isInvariant (Var VName
v) = Maybe (Entry rep) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Entry rep) -> Bool) -> Maybe (Entry rep) -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> SymbolTable rep -> Maybe (Entry rep)
forall rep. VName -> SymbolTable rep -> Maybe (Entry rep)
ST.lookup VName
v SymbolTable rep
vtable

    checkForInvarianceResult :: (Type, PatElem (LetDec rep), KernelResult) -> RuleM rep Bool
checkForInvarianceResult (Type
_, PatElem (LetDec rep)
pe, Returns ResultManifest
rm Certs
cs SubExp
se)
      | Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty,
        ResultManifest
rm ResultManifest -> ResultManifest -> Bool
forall a. Eq a => a -> a -> Bool
== ResultManifest
ResultMaySimplify,
        SubExp -> Bool
isInvariant SubExp
se = do
          [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
            BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$
              ShapeBase SubExp -> SubExp -> BasicOp
Replicate ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> ShapeBase SubExp) -> [SubExp] -> ShapeBase SubExp
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space) SubExp
se
          Bool -> RuleM rep Bool
forall a. a -> RuleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Bool
False
    checkForInvarianceResult (Type, PatElem (LetDec rep), KernelResult)
_ =
      Bool -> RuleM rep Bool
forall a. a -> RuleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Bool
True

-- If a SegRed contains two reduction operations that have the same
-- vector shape, merge them together.  This saves on communication
-- overhead, but can in principle lead to more shared memory usage.
topDownSegOp SymbolTable rep
_ (Pat [PatElem (LetDec rep)]
pes) StmAux (ExpDec rep)
_ (SegRed SegOpLevel rep
lvl SegSpace
space [SegBinOp rep]
ops [Type]
ts KernelBody rep
kbody)
  | [SegBinOp rep] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegBinOp rep]
ops Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1,
    [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
op_groupings <-
      ((SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
 -> (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
 -> Bool)
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
forall a. (a -> a -> Bool) -> [a] -> [[a]]
groupBy (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
-> (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
-> Bool
forall {rep} {b} {rep} {b}.
(SegBinOp rep, b) -> (SegBinOp rep, b) -> Bool
sameShape ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
 -> [[(SegBinOp rep,
       [(PatElem (LetDec rep), Type, KernelResult)])]])
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
forall a b. (a -> b) -> a -> b
$
        [SegBinOp rep]
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp rep]
ops ([[(PatElem (LetDec rep), Type, KernelResult)]]
 -> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])])
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
forall a b. (a -> b) -> a -> b
$
          [Int]
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp rep -> Int) -> [SegBinOp rep] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp rep -> [SubExp]) -> SegBinOp rep -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral) [SegBinOp rep]
ops) ([(PatElem (LetDec rep), Type, KernelResult)]
 -> [[(PatElem (LetDec rep), Type, KernelResult)]])
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
forall a b. (a -> b) -> a -> b
$
            [PatElem (LetDec rep)]
-> [Type]
-> [KernelResult]
-> [(PatElem (LetDec rep), Type, KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (LetDec rep)]
red_pes [Type]
red_ts [KernelResult]
red_res,
    ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
 -> Bool)
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
-> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any ((Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1) (Int -> Bool)
-> ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
    -> Int)
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length) [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
op_groupings = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
      let ([SegBinOp rep]
ops', [[(PatElem (LetDec rep), Type, KernelResult)]]
aux) = [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> ([SegBinOp rep], [[(PatElem (LetDec rep), Type, KernelResult)]])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
 -> ([SegBinOp rep],
     [[(PatElem (LetDec rep), Type, KernelResult)]]))
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> ([SegBinOp rep], [[(PatElem (LetDec rep), Type, KernelResult)]])
forall a b. (a -> b) -> a -> b
$ ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
 -> Maybe
      (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)]))
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Maybe
     (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
forall {rep} {a}.
Buildable rep =>
[(SegBinOp rep, [a])] -> Maybe (SegBinOp rep, [a])
combineOps [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
op_groupings
          ([PatElem (LetDec rep)]
red_pes', [Type]
red_ts', [KernelResult]
red_res') = [(PatElem (LetDec rep), Type, KernelResult)]
-> ([PatElem (LetDec rep)], [Type], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(PatElem (LetDec rep), Type, KernelResult)]
 -> ([PatElem (LetDec rep)], [Type], [KernelResult]))
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> ([PatElem (LetDec rep)], [Type], [KernelResult])
forall a b. (a -> b) -> a -> b
$ [[(PatElem (LetDec rep), Type, KernelResult)]]
-> [(PatElem (LetDec rep), Type, KernelResult)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[(PatElem (LetDec rep), Type, KernelResult)]]
aux
          pes' :: [PatElem (LetDec rep)]
pes' = [PatElem (LetDec rep)]
red_pes' [PatElem (LetDec rep)]
-> [PatElem (LetDec rep)] -> [PatElem (LetDec rep)]
forall a. [a] -> [a] -> [a]
++ [PatElem (LetDec rep)]
map_pes
          ts' :: [Type]
ts' = [Type]
red_ts' [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
map_ts
          kbody' :: KernelBody rep
kbody' = KernelBody rep
kbody {kernelBodyResult = red_res' ++ map_res}
      Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
pes') (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall rep. Op rep -> Exp rep
Op (Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)))
-> Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel rep) rep -> Op rep
forall rep. HasSegOp rep => SegOp (SegOpLevel rep) rep -> Op rep
segOp (SegOp (SegOpLevel rep) rep -> Op rep)
-> SegOp (SegOpLevel rep) rep -> Op rep
forall a b. (a -> b) -> a -> b
$ SegOpLevel rep
-> SegSpace
-> [SegBinOp rep]
-> [Type]
-> KernelBody rep
-> SegOp (SegOpLevel rep) rep
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [Type]
-> KernelBody rep
-> SegOp lvl rep
SegRed SegOpLevel rep
lvl SegSpace
space [SegBinOp rep]
ops' [Type]
ts' KernelBody rep
kbody'
  where
    ([PatElem (LetDec rep)]
red_pes, [PatElem (LetDec rep)]
map_pes) = Int
-> [PatElem (LetDec rep)]
-> ([PatElem (LetDec rep)], [PatElem (LetDec rep)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp rep] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp rep]
ops) [PatElem (LetDec rep)]
pes
    ([Type]
red_ts, [Type]
map_ts) = Int -> [Type] -> ([Type], [Type])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp rep] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp rep]
ops) [Type]
ts
    ([KernelResult]
red_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp rep] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp rep]
ops) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody rep -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody rep
kbody

    sameShape :: (SegBinOp rep, b) -> (SegBinOp rep, b) -> Bool
sameShape (SegBinOp rep
op1, b
_) (SegBinOp rep
op2, b
_) =
      SegBinOp rep -> ShapeBase SubExp
forall rep. SegBinOp rep -> ShapeBase SubExp
segBinOpShape SegBinOp rep
op1 ShapeBase SubExp -> ShapeBase SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SegBinOp rep -> ShapeBase SubExp
forall rep. SegBinOp rep -> ShapeBase SubExp
segBinOpShape SegBinOp rep
op2
        Bool -> Bool -> Bool
&& ShapeBase SubExp -> Int
forall a. ArrayShape a => a -> Int
shapeRank (SegBinOp rep -> ShapeBase SubExp
forall rep. SegBinOp rep -> ShapeBase SubExp
segBinOpShape SegBinOp rep
op1) Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
0

    combineOps :: [(SegBinOp rep, [a])] -> Maybe (SegBinOp rep, [a])
combineOps [] = Maybe (SegBinOp rep, [a])
forall a. Maybe a
Nothing
    combineOps ((SegBinOp rep, [a])
x : [(SegBinOp rep, [a])]
xs) = (SegBinOp rep, [a]) -> Maybe (SegBinOp rep, [a])
forall a. a -> Maybe a
Just ((SegBinOp rep, [a]) -> Maybe (SegBinOp rep, [a]))
-> (SegBinOp rep, [a]) -> Maybe (SegBinOp rep, [a])
forall a b. (a -> b) -> a -> b
$ ((SegBinOp rep, [a]) -> (SegBinOp rep, [a]) -> (SegBinOp rep, [a]))
-> (SegBinOp rep, [a])
-> [(SegBinOp rep, [a])]
-> (SegBinOp rep, [a])
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' (SegBinOp rep, [a]) -> (SegBinOp rep, [a]) -> (SegBinOp rep, [a])
forall {rep} {a}.
Buildable rep =>
(SegBinOp rep, [a]) -> (SegBinOp rep, [a]) -> (SegBinOp rep, [a])
combine (SegBinOp rep, [a])
x [(SegBinOp rep, [a])]
xs

    combine :: (SegBinOp rep, [a]) -> (SegBinOp rep, [a]) -> (SegBinOp rep, [a])
combine (SegBinOp rep
op1, [a]
op1_aux) (SegBinOp rep
op2, [a]
op2_aux) =
      let lam1 :: Lambda rep
lam1 = SegBinOp rep -> Lambda rep
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp rep
op1
          lam2 :: Lambda rep
lam2 = SegBinOp rep -> Lambda rep
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp rep
op2
          ([Param (LParamInfo rep)]
op1_xparams, [Param (LParamInfo rep)]
op1_yparams) =
            Int
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op1)) ([Param (LParamInfo rep)]
 -> ([Param (LParamInfo rep)], [Param (LParamInfo rep)]))
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [Param (LParamInfo rep)]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam1
          ([Param (LParamInfo rep)]
op2_xparams, [Param (LParamInfo rep)]
op2_yparams) =
            Int
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op2)) ([Param (LParamInfo rep)]
 -> ([Param (LParamInfo rep)], [Param (LParamInfo rep)]))
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [Param (LParamInfo rep)]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam2
          lam :: Lambda rep
lam =
            Lambda
              { lambdaParams :: [Param (LParamInfo rep)]
lambdaParams =
                  [Param (LParamInfo rep)]
op1_xparams
                    [Param (LParamInfo rep)]
-> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
op2_xparams
                    [Param (LParamInfo rep)]
-> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
op1_yparams
                    [Param (LParamInfo rep)]
-> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
op2_yparams,
                lambdaReturnType :: [Type]
lambdaReturnType = Lambda rep -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda rep
lam1 [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ Lambda rep -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda rep
lam2,
                lambdaBody :: Body rep
lambdaBody =
                  Stms rep -> Result -> Body rep
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (Body rep -> Stms rep
forall rep. Body rep -> Stms rep
bodyStms (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam1) Stms rep -> Stms rep -> Stms rep
forall a. Semigroup a => a -> a -> a
<> Body rep -> Stms rep
forall rep. Body rep -> Stms rep
bodyStms (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam2)) (Result -> Body rep) -> Result -> Body rep
forall a b. (a -> b) -> a -> b
$
                    Body rep -> Result
forall rep. Body rep -> Result
bodyResult (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam1) Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> Body rep -> Result
forall rep. Body rep -> Result
bodyResult (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam2)
              }
       in ( SegBinOp
              { segBinOpComm :: Commutativity
segBinOpComm = SegBinOp rep -> Commutativity
forall rep. SegBinOp rep -> Commutativity
segBinOpComm SegBinOp rep
op1 Commutativity -> Commutativity -> Commutativity
forall a. Semigroup a => a -> a -> a
<> SegBinOp rep -> Commutativity
forall rep. SegBinOp rep -> Commutativity
segBinOpComm SegBinOp rep
op2,
                segBinOpLambda :: Lambda rep
segBinOpLambda = Lambda rep
lam,
                segBinOpNeutral :: [SubExp]
segBinOpNeutral = SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op1 [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op2,
                segBinOpShape :: ShapeBase SubExp
segBinOpShape = SegBinOp rep -> ShapeBase SubExp
forall rep. SegBinOp rep -> ShapeBase SubExp
segBinOpShape SegBinOp rep
op1 -- Same as shape of op2 due to the grouping.
              },
            [a]
op1_aux [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
op2_aux
          )
topDownSegOp SymbolTable rep
_ Pat (LetDec rep)
_ StmAux (ExpDec rep)
_ SegOp (SegOpLevel rep) rep
_ = Rule rep
forall rep. Rule rep
Skip

-- | Take a 'SegOp' apart into the pieces that every kind of segmented
-- operation has in common.  This lets callers rewrite the result types
-- and kernel body without caring whether the operation is a map, scan,
-- reduction or histogram.
--
-- The returned tuple contains, in order:
--
-- * the result types of the operation;
--
-- * its kernel body;
--
-- * the number of leading results that are not plain map results.  This
--   is zero for 'SegMap', 'segBinOpResults' of the operators for
--   'SegScan' and 'SegRed', and the total number of histogram
--   destinations for 'SegHist';
--
-- * a function that rebuilds an operation of the same kind, keeping the
--   original level, space and operators, from new result types and a new
--   kernel body.
segOpGuts ::
  SegOp (SegOpLevel rep) rep ->
  ( [Type],
    KernelBody rep,
    Int,
    [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
  )
segOpGuts segop =
  case segop of
    SegMap level segspace res_ts kbody ->
      (res_ts, kbody, 0, SegMap level segspace)
    SegScan level segspace binops res_ts kbody ->
      (res_ts, kbody, segBinOpResults binops, SegScan level segspace binops)
    SegRed level segspace binops res_ts kbody ->
      (res_ts, kbody, segBinOpResults binops, SegRed level segspace binops)
    SegHist level segspace hist_ops res_ts kbody ->
      -- Each histogram operator contributes one non-map result per
      -- destination array it updates.
      let num_hist_results =
            foldl' (\acc hop -> acc + length (histDest hop)) 0 hist_ops
       in (res_ts, kbody, num_hist_results, SegHist level segspace hist_ops)

bottomUpSegOp ::
  (Aliased rep, HasSegOp rep, BuilderOps rep) =>
  (ST.SymbolTable rep, UT.UsageTable) ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
-- Some SegOp results can be moved outside the SegOp, which can
-- simplify further analysis.
bottomUpSegOp :: forall rep.
(Aliased rep, HasSegOp rep, BuilderOps rep) =>
(SymbolTable rep, UsageTable)
-> Pat (LetDec rep)
-> StmAux (ExpDec rep)
-> SegOp (SegOpLevel rep) rep
-> Rule rep
bottomUpSegOp (SymbolTable rep
_vtable, UsageTable
used) (Pat [PatElem (LetDec rep)]
kpes) StmAux (ExpDec rep)
dec SegOp (SegOpLevel rep) rep
segop
  -- Remove dead results. This is a bit tricky to do with scan/red
  -- results, so we only deal with map results for now.
  | ([Int]
_, [PatElem (LetDec rep)]
kpes', [Type]
kts', [KernelResult]
kres') <- [(Int, PatElem (LetDec rep), Type, KernelResult)]
-> ([Int], [PatElem (LetDec rep)], [Type], [KernelResult])
forall a b c d. [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 ([(Int, PatElem (LetDec rep), Type, KernelResult)]
 -> ([Int], [PatElem (LetDec rep)], [Type], [KernelResult]))
-> [(Int, PatElem (LetDec rep), Type, KernelResult)]
-> ([Int], [PatElem (LetDec rep)], [Type], [KernelResult])
forall a b. (a -> b) -> a -> b
$ ((Int, PatElem (LetDec rep), Type, KernelResult) -> Bool)
-> [(Int, PatElem (LetDec rep), Type, KernelResult)]
-> [(Int, PatElem (LetDec rep), Type, KernelResult)]
forall a. (a -> Bool) -> [a] -> [a]
filter (Int, PatElem (LetDec rep), Type, KernelResult) -> Bool
keep ([(Int, PatElem (LetDec rep), Type, KernelResult)]
 -> [(Int, PatElem (LetDec rep), Type, KernelResult)])
-> [(Int, PatElem (LetDec rep), Type, KernelResult)]
-> [(Int, PatElem (LetDec rep), Type, KernelResult)]
forall a b. (a -> b) -> a -> b
$ [Int]
-> [PatElem (LetDec rep)]
-> [Type]
-> [KernelResult]
-> [(Int, PatElem (LetDec rep), Type, KernelResult)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [Int
0 ..] [PatElem (LetDec rep)]
kpes [Type]
kts [KernelResult]
kres,
    [PatElem (LetDec rep)]
kpes' [PatElem (LetDec rep)] -> [PatElem (LetDec rep)] -> Bool
forall a. Eq a => a -> a -> Bool
/= [PatElem (LetDec rep)]
kpes = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
      KernelBody rep
kbody' <- Scope rep
-> RuleM rep (KernelBody rep) -> RuleM rep (KernelBody rep)
forall a. Scope rep -> RuleM rep a -> RuleM rep a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope rep
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (RuleM rep (KernelBody rep) -> RuleM rep (KernelBody rep))
-> RuleM rep (KernelBody rep) -> RuleM rep (KernelBody rep)
forall a b. (a -> b) -> a -> b
$ Stms (Rep (RuleM rep))
-> [KernelResult] -> RuleM rep (KernelBody (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
Stms (Rep m) -> [KernelResult] -> m (KernelBody (Rep m))
mkKernelBodyM Stms rep
Stms (Rep (RuleM rep))
kstms [KernelResult]
kres'
      Stm (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (RuleM rep)) -> RuleM rep ())
-> Stm (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM rep)))
-> StmAux (ExpDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep))
-> Stm (Rep (RuleM rep))
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
kpes') StmAux (ExpDec rep)
StmAux (ExpDec (Rep (RuleM rep)))
dec (Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall rep. Op rep -> Exp rep
Op (Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)))
-> Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel rep) rep -> OpC rep rep
forall rep. HasSegOp rep => SegOp (SegOpLevel rep) rep -> Op rep
segOp (SegOp (SegOpLevel rep) rep -> OpC rep rep)
-> SegOp (SegOpLevel rep) rep -> OpC rep rep
forall a b. (a -> b) -> a -> b
$ [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
mk_segop [Type]
kts' KernelBody rep
kbody'
  where
    space :: SegSpace
space = SegOp (SegOpLevel rep) rep -> SegSpace
forall lvl rep. SegOp lvl rep -> SegSpace
segSpace SegOp (SegOpLevel rep) rep
segop
    ([Type]
kts, KernelBody BodyDec rep
_ Stms rep
kstms [KernelResult]
kres, Int
num_nonmap_results, [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
mk_segop) =
      SegOp (SegOpLevel rep) rep
-> ([Type], KernelBody rep, Int,
    [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep)
forall rep.
SegOp (SegOpLevel rep) rep
-> ([Type], KernelBody rep, Int,
    [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep)
segOpGuts SegOp (SegOpLevel rep) rep
segop

    keep :: (Int, PatElem (LetDec rep), Type, KernelResult) -> Bool
keep (Int
i, PatElem (LetDec rep)
pe, Type
_, KernelResult
_) =
      Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
num_nonmap_results Bool -> Bool -> Bool
|| PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe VName -> UsageTable -> Bool
`UT.used` UsageTable
used
bottomUpSegOp (SymbolTable rep
vtable, UsageTable
_used) (Pat [PatElem (LetDec rep)]
kpes) StmAux (ExpDec rep)
dec SegOp (SegOpLevel rep) rep
segop = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
  -- Iterate through the bindings.  For each, we check whether it is
  -- in kres and can be moved outside.  If so, we remove it from kres
  -- and kpes and make it a binding outside.  We have to be careful
  -- not to remove anything that is passed on to a scan/map/histogram
  -- operation.  Fortunately, these are always first in the result
  -- list.
  ([PatElem (LetDec rep)]
kpes', [Type]
kts', [KernelResult]
kres', Stms rep
kstms') <-
    Scope rep
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
forall a. Scope rep -> RuleM rep a -> RuleM rep a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope rep
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (RuleM
   rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
 -> RuleM
      rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep))
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
forall a b. (a -> b) -> a -> b
$
      (([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
 -> Stm rep
 -> RuleM
      rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep))
-> ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> Stms rep
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> Stm rep
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
distribute ([PatElem (LetDec rep)]
kpes, [Type]
kts, [KernelResult]
kres, Stms rep
forall a. Monoid a => a
mempty) Stms rep
kstms

  Bool -> RuleM rep () -> RuleM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([PatElem (LetDec rep)]
kpes' [PatElem (LetDec rep)] -> [PatElem (LetDec rep)] -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElem (LetDec rep)]
kpes) RuleM rep ()
forall rep a. RuleM rep a
cannotSimplify

  KernelBody rep
kbody' <-
    Scope rep
-> RuleM rep (KernelBody rep) -> RuleM rep (KernelBody rep)
forall a. Scope rep -> RuleM rep a -> RuleM rep a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope rep
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (RuleM rep (KernelBody rep) -> RuleM rep (KernelBody rep))
-> RuleM rep (KernelBody rep) -> RuleM rep (KernelBody rep)
forall a b. (a -> b) -> a -> b
$ Stms (Rep (RuleM rep))
-> [KernelResult] -> RuleM rep (KernelBody (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
Stms (Rep m) -> [KernelResult] -> m (KernelBody (Rep m))
mkKernelBodyM Stms rep
Stms (Rep (RuleM rep))
kstms' [KernelResult]
kres'

  Stm (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (RuleM rep)) -> RuleM rep ())
-> Stm (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM rep)))
-> StmAux (ExpDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep))
-> Stm (Rep (RuleM rep))
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
kpes') StmAux (ExpDec rep)
StmAux (ExpDec (Rep (RuleM rep)))
dec (Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall rep. Op rep -> Exp rep
Op (Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)))
-> Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel rep) rep -> OpC rep rep
forall rep. HasSegOp rep => SegOp (SegOpLevel rep) rep -> Op rep
segOp (SegOp (SegOpLevel rep) rep -> OpC rep rep)
-> SegOp (SegOpLevel rep) rep -> OpC rep rep
forall a b. (a -> b) -> a -> b
$ [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
mk_segop [Type]
kts' KernelBody rep
kbody'
  where
    ([Type]
kts, KernelBody BodyDec rep
_ Stms rep
kstms [KernelResult]
kres, Int
num_nonmap_results, [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
mk_segop) =
      SegOp (SegOpLevel rep) rep
-> ([Type], KernelBody rep, Int,
    [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep)
forall rep.
SegOp (SegOpLevel rep) rep
-> ([Type], KernelBody rep, Int,
    [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep)
segOpGuts SegOp (SegOpLevel rep) rep
segop
    free_in_kstms :: Names
free_in_kstms = (Stm rep -> Names) -> Stms rep -> Names
forall m a. Monoid m => (a -> m) -> Seq a -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap Stm rep -> Names
forall a. FreeIn a => a -> Names
freeIn Stms rep
kstms
    space :: SegSpace
space = SegOp (SegOpLevel rep) rep -> SegSpace
forall lvl rep. SegOp lvl rep -> SegSpace
segSpace SegOp (SegOpLevel rep) rep
segop

    sliceWithGtidsFixed :: Stm rep -> Maybe (Slice SubExp, VName)
sliceWithGtidsFixed Stm rep
stm
      | Let Pat (LetDec rep)
_ StmAux (ExpDec rep)
aux (BasicOp (Index VName
arr Slice SubExp
slice)) <- Stm rep
stm,
        [DimIndex SubExp]
space_slice <- ((VName, SubExp) -> DimIndex SubExp)
-> [(VName, SubExp)] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> ((VName, SubExp) -> SubExp)
-> (VName, SubExp)
-> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. VName -> SubExp
Var (VName -> SubExp)
-> ((VName, SubExp) -> VName) -> (VName, SubExp) -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst) ([(VName, SubExp)] -> [DimIndex SubExp])
-> [(VName, SubExp)] -> [DimIndex SubExp]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space,
        [DimIndex SubExp]
space_slice [DimIndex SubExp] -> [DimIndex SubExp] -> Bool
forall a. Eq a => [a] -> [a] -> Bool
`isPrefixOf` Slice SubExp -> [DimIndex SubExp]
forall d. Slice d -> [DimIndex d]
unSlice Slice SubExp
slice,
        Slice SubExp
remaining_slice <- [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ Int -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. Int -> [a] -> [a]
drop ([DimIndex SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex SubExp]
space_slice) (Slice SubExp -> [DimIndex SubExp]
forall d. Slice d -> [DimIndex d]
unSlice Slice SubExp
slice),
        (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Maybe (Entry rep) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Entry rep) -> Bool)
-> (VName -> Maybe (Entry rep)) -> VName -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. (VName -> SymbolTable rep -> Maybe (Entry rep))
-> SymbolTable rep -> VName -> Maybe (Entry rep)
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> SymbolTable rep -> Maybe (Entry rep)
forall rep. VName -> SymbolTable rep -> Maybe (Entry rep)
ST.lookup SymbolTable rep
vtable) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$
          Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
            VName -> Names
forall a. FreeIn a => a -> Names
freeIn VName
arr Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Slice SubExp -> Names
forall a. FreeIn a => a -> Names
freeIn Slice SubExp
remaining_slice Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Certs -> Names
forall a. FreeIn a => a -> Names
freeIn (StmAux (ExpDec rep) -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec rep)
aux) =
          (Slice SubExp, VName) -> Maybe (Slice SubExp, VName)
forall a. a -> Maybe a
Just (Slice SubExp
remaining_slice, VName
arr)
      | Bool
otherwise =
          Maybe (Slice SubExp, VName)
forall a. Maybe a
Nothing

    distribute :: ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> Stm rep
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
distribute ([PatElem (LetDec rep)]
kpes', [Type]
kts', [KernelResult]
kres', Stms rep
kstms') Stm rep
stm
      | Let (Pat [PatElem (LetDec rep)
pe]) StmAux (ExpDec rep)
_ Exp rep
_ <- Stm rep
stm,
        Just (Slice [DimIndex SubExp]
remaining_slice, VName
arr) <- Stm rep -> Maybe (Slice SubExp, VName)
sliceWithGtidsFixed Stm rep
stm,
        Just (PatElem (LetDec rep)
kpe, [PatElem (LetDec rep)]
kpes'', [Type]
kts'', [KernelResult]
kres'') <- [PatElem (LetDec rep)]
-> [Type]
-> [KernelResult]
-> PatElem (LetDec rep)
-> Maybe
     (PatElem (LetDec rep), [PatElem (LetDec rep)], [Type],
      [KernelResult])
isResult [PatElem (LetDec rep)]
kpes' [Type]
kts' [KernelResult]
kres' PatElem (LetDec rep)
pe = do
          let outer_slice :: [DimIndex SubExp]
outer_slice =
                (SubExp -> DimIndex SubExp) -> [SubExp] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map
                  ( \SubExp
d ->
                      SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64)) SubExp
d (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
1 :: Int64))
                  )
                  ([SubExp] -> [DimIndex SubExp]) -> [SubExp] -> [DimIndex SubExp]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
              index :: PatElem (LetDec rep) -> RuleM rep ()
index PatElem (LetDec rep)
kpe' =
                [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
kpe'] (Exp rep -> RuleM rep ())
-> (Slice SubExp -> Exp rep) -> Slice SubExp -> RuleM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep)
-> (Slice SubExp -> BasicOp) -> Slice SubExp -> Exp rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> RuleM rep ()) -> Slice SubExp -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
                  [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$
                    [DimIndex SubExp]
outer_slice [DimIndex SubExp] -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. Semigroup a => a -> a -> a
<> [DimIndex SubExp]
remaining_slice
          VName
precopy <- String -> RuleM rep VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> RuleM rep VName) -> String -> RuleM rep VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString (PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
kpe) String -> ShowS
forall a. Semigroup a => a -> a -> a
<> String
"_precopy"
          PatElem (LetDec rep) -> RuleM rep ()
index PatElem (LetDec rep)
kpe {patElemName = precopy}
          [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
kpe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate ShapeBase SubExp
forall a. Monoid a => a
mempty (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
precopy
          ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
forall a. a -> RuleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure
            ( [PatElem (LetDec rep)]
kpes'',
              [Type]
kts'',
              [KernelResult]
kres'',
              if PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe VName -> Names -> Bool
`nameIn` Names
free_in_kstms
                then Stms rep
kstms' Stms rep -> Stms rep -> Stms rep
forall a. Semigroup a => a -> a -> a
<> Stm rep -> Stms rep
forall rep. Stm rep -> Stms rep
oneStm Stm rep
stm
                else Stms rep
kstms'
            )
    distribute ([PatElem (LetDec rep)]
kpes', [Type]
kts', [KernelResult]
kres', Stms rep
kstms') Stm rep
stm =
      ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> RuleM
     rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
forall a. a -> RuleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([PatElem (LetDec rep)]
kpes', [Type]
kts', [KernelResult]
kres', Stms rep
kstms' Stms rep -> Stms rep -> Stms rep
forall a. Semigroup a => a -> a -> a
<> Stm rep -> Stms rep
forall rep. Stm rep -> Stms rep
oneStm Stm rep
stm)

    isResult :: [PatElem (LetDec rep)]
-> [Type]
-> [KernelResult]
-> PatElem (LetDec rep)
-> Maybe
     (PatElem (LetDec rep), [PatElem (LetDec rep)], [Type],
      [KernelResult])
isResult [PatElem (LetDec rep)]
kpes' [Type]
kts' [KernelResult]
kres' PatElem (LetDec rep)
pe =
      case ((PatElem (LetDec rep), Type, KernelResult) -> Bool)
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> ([(PatElem (LetDec rep), Type, KernelResult)],
    [(PatElem (LetDec rep), Type, KernelResult)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (PatElem (LetDec rep), Type, KernelResult) -> Bool
matches ([(PatElem (LetDec rep), Type, KernelResult)]
 -> ([(PatElem (LetDec rep), Type, KernelResult)],
     [(PatElem (LetDec rep), Type, KernelResult)]))
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> ([(PatElem (LetDec rep), Type, KernelResult)],
    [(PatElem (LetDec rep), Type, KernelResult)])
forall a b. (a -> b) -> a -> b
$ [PatElem (LetDec rep)]
-> [Type]
-> [KernelResult]
-> [(PatElem (LetDec rep), Type, KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (LetDec rep)]
kpes' [Type]
kts' [KernelResult]
kres' of
        ([(PatElem (LetDec rep)
kpe, Type
_, KernelResult
_)], [(PatElem (LetDec rep), Type, KernelResult)]
kpes_and_kres)
          | Just Int
i <- PatElem (LetDec rep) -> [PatElem (LetDec rep)] -> Maybe Int
forall a. Eq a => a -> [a] -> Maybe Int
elemIndex PatElem (LetDec rep)
kpe [PatElem (LetDec rep)]
kpes,
            Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
num_nonmap_results,
            ([PatElem (LetDec rep)]
kpes'', [Type]
kts'', [KernelResult]
kres'') <- [(PatElem (LetDec rep), Type, KernelResult)]
-> ([PatElem (LetDec rep)], [Type], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(PatElem (LetDec rep), Type, KernelResult)]
kpes_and_kres ->
              (PatElem (LetDec rep), [PatElem (LetDec rep)], [Type],
 [KernelResult])
-> Maybe
     (PatElem (LetDec rep), [PatElem (LetDec rep)], [Type],
      [KernelResult])
forall a. a -> Maybe a
Just (PatElem (LetDec rep)
kpe, [PatElem (LetDec rep)]
kpes'', [Type]
kts'', [KernelResult]
kres'')
        ([(PatElem (LetDec rep), Type, KernelResult)],
 [(PatElem (LetDec rep), Type, KernelResult)])
_ -> Maybe
  (PatElem (LetDec rep), [PatElem (LetDec rep)], [Type],
   [KernelResult])
forall a. Maybe a
Nothing
      where
        matches :: (PatElem (LetDec rep), Type, KernelResult) -> Bool
matches (PatElem (LetDec rep)
_, Type
_, Returns ResultManifest
_ Certs
_ (Var VName
v)) = VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe
        matches (PatElem (LetDec rep), Type, KernelResult)
_ = Bool
False

--- Memory

-- | Refine a list of 'ExpReturns' using the results of a kernel body.
--
-- Results and returns are paired positionally.  A 'WriteReturns' result
-- is given the returns of the array it writes to (computed with
-- 'varReturns').  Every other result keeps the return it was paired with.
-- As with 'zipWithM', the output is only as long as the shorter of the
-- two lists.
kernelBodyReturns ::
  (Mem rep inner, HasScope rep m, Monad m) =>
  KernelBody somerep ->
  [ExpReturns] ->
  m [ExpReturns]
kernelBodyReturns kbody rets =
  zipWithM adjust (kernelBodyResult kbody) rets
  where
    -- The return of a WriteReturns result comes from its destination array.
    adjust res ret =
      case res of
        WriteReturns _ dest _ -> varReturns dest
        _ -> pure ret

-- | The memory-representation counterpart of 'segOpType': compute the
-- 'ExpReturns' of a 'SegOp'.
--
-- For 'SegMap', 'SegRed' and 'SegScan', the op's extended types (from
-- 'opType') are converted with 'extReturns' and then adjusted per kernel
-- result by 'kernelBodyReturns'.  For 'SegHist', the returns are those of
-- the destination arrays ('histDest') of every 'HistOp', concatenated in
-- order.
segOpReturns ::
  (Mem rep inner, Monad m, HasScope rep m) =>
  SegOp lvl rep ->
  m [ExpReturns]
segOpReturns op =
  case op of
    SegMap _ _ _ kbody -> fromKernelBody kbody
    SegRed _ _ _ _ kbody -> fromKernelBody kbody
    SegScan _ _ _ _ kbody -> fromKernelBody kbody
    SegHist _ _ hist_ops _ _ ->
      concat <$> traverse (traverse varReturns . histDest) hist_ops
  where
    -- Shared by every op whose returns are derived from its own type
    -- together with the results of its kernel body.
    fromKernelBody kbody = do
      ext_ts <- opType op
      kernelBodyReturns kbody $ extReturns ext_ts