{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}

-- | High-level representation of SOACs.  When performing
-- SOAC-transformations, operating on normal 'Exp' values is somewhat
-- of a nuisance, as they can represent terms that are not proper
-- SOACs.  In contrast, this module exposes a SOAC representation that
-- does not enable invalid representations (except for type errors).
--
-- Furthermore, while standard normalised Futhark requires that the inputs
-- to a SOAC are variables or constants, the representation in this
-- module also supports various index-space transformations, like
-- @replicate@ or @rearrange@.  This is also very convenient when
-- implementing transformations.
--
-- The names exported by this module conflict with the standard Futhark
-- syntax tree constructors, so you are advised to use a qualified
-- import:
--
-- @
-- import Futhark.Analysis.HORep.SOAC (SOAC)
-- import qualified Futhark.Analysis.HORep.SOAC as SOAC
-- @
module Futhark.Analysis.HORep.SOAC
  ( -- * SOACs
    SOAC (..),
    Futhark.ScremaForm (..),
    inputs,
    setInputs,
    lambda,
    setLambda,
    typeOf,
    width,

    -- ** Converting to and from expressions
    NotSOAC (..),
    fromExp,
    toExp,
    toSOAC,

    -- * SOAC inputs
    Input (..),
    varInput,
    identInput,
    isVarInput,
    isVarishInput,
    addTransform,
    addInitialTransforms,
    inputArray,
    inputRank,
    inputType,
    inputRowType,
    transformRows,
    transposeInput,

    -- ** Input transformations
    ArrayTransforms,
    noTransforms,
    nullTransforms,
    (|>),
    (<|),
    viewf,
    ViewF (..),
    viewl,
    ViewL (..),
    ArrayTransform (..),
    transformFromExp,
    soacToStream,
  )
where

import Data.Foldable as Foldable
import Data.Maybe
import qualified Data.Sequence as Seq
import Futhark.Construct hiding (toExp)
import Futhark.IR hiding
  ( Iota,
    Rearrange,
    Replicate,
    Reshape,
    Var,
    typeOf,
  )
import qualified Futhark.IR as Futhark
import Futhark.IR.SOACS.SOAC
  ( HistOp (..),
    ScremaForm (..),
    StreamForm (..),
    StreamOrd (..),
    scremaType,
  )
import qualified Futhark.IR.SOACS.SOAC as Futhark
import Futhark.Transform.Rename (renameLambda)
import Futhark.Transform.Substitute
import Futhark.Util.Pretty (ppr, text)
import qualified Futhark.Util.Pretty as PP

-- | One elementary index-space transformation of a SOAC input.  To
-- describe several transformations in a row, do not build a plain
-- list of these; use 'ArrayTransforms', which keeps the sequence
-- normalised.
data ArrayTransform
  = -- | Permute the dimensions of an otherwise valid input.
    Rearrange Certificates [Int]
  | -- | Reshape an otherwise valid input.
    Reshape Certificates (ShapeChange SubExp)
  | -- | Reshape only the outer dimension.
    ReshapeOuter Certificates (ShapeChange SubExp)
  | -- | Reshape every dimension except the outer one.
    ReshapeInner Certificates (ShapeChange SubExp)
  | -- | Replicate the rows of the array a number of times.
    Replicate Certificates Shape
  -- Constructor order is significant: it determines the derived 'Ord'.
  deriving (Eq, Ord, Show)

-- Renaming touches the certificates of every transformation and, where
-- present, its shape information; permutations contain no names.
instance Substitute ArrayTransform where
  substituteNames substs tr =
    case tr of
      Rearrange cs perm -> Rearrange (sub cs) perm
      Reshape cs shape -> Reshape (sub cs) (sub shape)
      ReshapeOuter cs shape -> ReshapeOuter (sub cs) (sub shape)
      ReshapeInner cs shape -> ReshapeInner (sub cs) (sub shape)
      Replicate cs shape -> Replicate (sub cs) (sub shape)
    where
      sub :: Substitute a => a -> a
      sub = substituteNames substs

-- | A sequence of array transformations, heavily inspired by
-- "Data.Sequence".  You can decompose it using 'viewf' and 'viewl', and
-- grow it by using '|>' and '<|'.  These correspond closely to the
-- similar operations for sequences, except that appending will try to
-- normalise and simplify the transformation sequence.
--
-- The data type is opaque in order to enforce normalisation
-- invariants.  Basically, when you grow the sequence, the
-- implementation will try to coalesce neighboring transformations, for
-- example by composing permutations and removing identity
-- transformations.
newtype ArrayTransforms = ArrayTransforms (Seq.Seq ArrayTransform)
  deriving (Eq, Ord, Show)

instance Semigroup ArrayTransforms where
  -- Push the right operand's transformations onto the left operand one
  -- at a time, starting from its input end, using '|>'.  Because '|>'
  -- normalises, adjacent transformations are also merged across the
  -- point where the two sequences meet.
  front <> ArrayTransforms back = Foldable.foldl' (|>) front back

instance Monoid ArrayTransforms where
  -- An empty sequence applies nothing, so it is the identity of '<>'.
  mempty = ArrayTransforms Seq.empty

instance Substitute ArrayTransforms where
  -- Rename inside each transformation.  The sequence itself is rebuilt
  -- directly, without going through '|>' again.
  substituteNames substs (ArrayTransforms ts) =
    ArrayTransforms (fmap (substituteNames substs) ts)

-- | The empty transformation list: applying it leaves an input as is.
noTransforms :: ArrayTransforms
noTransforms = ArrayTransforms mempty

-- | Is it an empty transformation list?  'True' exactly when the
-- sequence holds no transformations.
nullTransforms :: ArrayTransforms -> Bool
nullTransforms (ArrayTransforms trs) = Foldable.null trs

-- | Decompose the input-end of the transformation sequence, separating
-- the transformation that is applied first from the rest.
viewf :: ArrayTransforms -> ViewF
viewf (ArrayTransforms s) = case s of
  Seq.Empty -> EmptyF
  firstT Seq.:<| rest -> firstT :< ArrayTransforms rest

-- | A view of the first transformation to be applied, as produced by
-- 'viewf'.
data ViewF
  = -- | The sequence contains no transformations.
    EmptyF
  | -- | The first transformation to apply, followed by the rest.
    ArrayTransform :< ArrayTransforms

-- | Decompose the output-end of the transformation sequence, separating
-- the transformation that is applied last from the rest.
viewl :: ArrayTransforms -> ViewL
viewl (ArrayTransforms s) = case s of
  Seq.Empty -> EmptyL
  rest Seq.:|> lastT -> ArrayTransforms rest :> lastT

-- | A view of the last transformation to be applied, as produced by
-- 'viewl'.
data ViewL
  = -- | The sequence contains no transformations.
    EmptyL
  | -- | All but the last transformation, followed by the last one.
    ArrayTransforms :> ArrayTransform

-- | Add a transform to the end of the transformation list, so that it
-- is applied after all existing ones.  If it can be merged with the
-- current last transformation it is, and a merged result that is an
-- identity is dropped (see 'addTransform'').
(|>) :: ArrayTransforms -> ArrayTransform -> ArrayTransforms
ts |> t = addTransform' takeLast snoc flipPair t ts
  where
    -- Split off the transformation at the output end, if any.
    takeLast :: ArrayTransforms -> Maybe (ArrayTransform, ArrayTransforms)
    takeLast xs = case viewl xs of
      EmptyL -> Nothing
      rest :> lastT -> Just (lastT, rest)

    -- Append at the output end without any normalisation.
    snoc :: ArrayTransform -> ArrayTransforms -> ArrayTransforms
    snoc x (ArrayTransforms s) = ArrayTransforms (s Seq.|> x)

    -- The new transformation comes after the existing last one, so it
    -- must be passed to 'combineTransforms' first.
    flipPair :: (ArrayTransform, ArrayTransform) -> (ArrayTransform, ArrayTransform)
    flipPair (existing, new) = (new, existing)

-- | Add a transform at the beginning of the transformation list, so
-- that it is applied before all existing ones.  If it can be merged
-- with the current first transformation it is, and a merged result that
-- is an identity is dropped (see 'addTransform'').
(<|) :: ArrayTransform -> ArrayTransforms -> ArrayTransforms
t <| ts = addTransform' takeFirst cons id t ts
  where
    -- Split off the transformation at the input end, if any.
    takeFirst :: ArrayTransforms -> Maybe (ArrayTransform, ArrayTransforms)
    takeFirst xs = case viewf xs of
      EmptyF -> Nothing
      firstT :< rest -> Just (firstT, rest)

    -- Prepend at the input end without any normalisation.
    cons :: ArrayTransform -> ArrayTransforms -> ArrayTransforms
    cons x (ArrayTransforms s) = ArrayTransforms (x Seq.<| s)

-- | Shared worker for '|>' and '<|'.
--
-- * @extract@ splits the neighbouring transformation off the end that
--   is being extended.
-- * @add@ attaches a transformation at that end without normalising.
-- * @swap@ turns the pair (neighbour, new) into the argument order
--   passed to 'combineTransforms'.
--
-- When the new transformation merges with its neighbour, the merged
-- result is itself pushed again, so it may merge further.  It is
-- dropped entirely when 'identityTransform' holds for it.
addTransform' ::
  (ArrayTransforms -> Maybe (ArrayTransform, ArrayTransforms)) ->
  (ArrayTransform -> ArrayTransforms -> ArrayTransforms) ->
  ((ArrayTransform, ArrayTransform) -> (ArrayTransform, ArrayTransform)) ->
  ArrayTransform ->
  ArrayTransforms ->
  ArrayTransforms
addTransform' extract add swap t ts =
  case extract ts of
    Nothing -> plain
    Just (neighbour, rest) ->
      case uncurry combineTransforms (swap (neighbour, t)) of
        Nothing -> plain
        Just merged
          | identityTransform merged -> rest
          | otherwise -> addTransform' extract add swap merged rest
  where
    -- Fallback: attach the transformation unchanged.
    plain = t `add` ts

-- | Does this transformation leave its input unchanged?  Only a
-- 'Rearrange' whose permutation maps every position to itself is
-- recognised.  Every other transformation is conservatively reported as
-- non-identity.
identityTransform :: ArrayTransform -> Bool
identityTransform tr = case tr of
  Rearrange _ perm -> all (uncurry (==)) (zip perm [0 ..])
  _ -> False

-- | Try to fuse two adjacent transformations into one.  Callers ('|>'
-- and '<|') pass the transformation applied later as the first
-- argument and the one applied earlier as the second.  Only two
-- permutations can currently be fused: their permutations are composed
-- with 'rearrangeCompose' and their certificates are concatenated.
-- 'Nothing' means the pair must stay separate.
combineTransforms :: ArrayTransform -> ArrayTransform -> Maybe ArrayTransform
combineTransforms (Rearrange laterCs laterPerm) (Rearrange earlierCs earlierPerm) =
  let perm = laterPerm `rearrangeCompose` earlierPerm
   in Just (Rearrange (earlierCs <> laterCs) perm)
combineTransforms _ _ = Nothing

-- | Given an expression, determine whether the expression represents
-- an input transformation of an array variable.  If so, return the
-- variable and the transformation, with the given certificates attached
-- to the transformation.  Only 'Rearrange', 'Reshape', and 'Replicate'
-- (of a variable, not a constant) can be expressed this way.  Every
-- other expression yields 'Nothing'.
transformFromExp :: Certificates -> Exp lore -> Maybe (VName, ArrayTransform)
transformFromExp cs (BasicOp (Futhark.Rearrange perm v)) =
  Just (v, Rearrange cs perm)
transformFromExp cs (BasicOp (Futhark.Reshape shape v)) =
  Just (v, Reshape cs shape)
transformFromExp cs (BasicOp (Futhark.Replicate shape (Futhark.Var v))) =
  Just (v, Replicate cs shape)
transformFromExp _ _ = Nothing

-- | One array input to a SOAC - a SOAC may have multiple inputs, but
-- all are of this form.  Only the array inputs are expressed with
-- this type; other arguments, such as initial accumulator values, are
-- plain expressions.
--
-- The fields are:
--
-- * the 'ArrayTransforms' to apply, in order from left to right (the
--   first element of the sequence is applied first);
-- * the 'VName' of the underlying array variable;
-- * the 'Type' of that variable itself.  'varInput' and 'identInput'
--   set it from the variable, and 'addTransform' leaves it unchanged,
--   so it does not reflect the transformations.
data Input = Input ArrayTransforms VName Type
  deriving (Show, Eq, Ord)

-- Renaming applies to the transformations, the array variable, and its
-- type alike.
instance Substitute Input where
  substituteNames substs (Input ts v t) =
    let rename :: Substitute a => a -> a
        rename = substituteNames substs
     in Input (rename ts) (rename v) (rename t)

-- | Create a plain array variable input with no transformations.  The
-- variable's type is obtained with 'lookupType' in the ambient scope.
varInput :: HasScope t f => VName -> f Input
varInput v = Input noTransforms v <$> lookupType v

-- | Create a plain array variable input with no transformations, from an 'Ident'.
-- The name and type are taken directly from the identifier.
identInput :: Ident -> Input
identInput ident = Input noTransforms (identName ident) (identType ident)

-- | If the given input is a plain variable input, with no transforms,
-- return the variable; otherwise 'Nothing'.
isVarInput :: Input -> Maybe VName
isVarInput (Input ts v _)
  | nullTransforms ts = Just v
  | otherwise = Nothing

-- | If the given input is a plain variable input, with no non-vacuous transforms,
-- return the variable.  A leading 'Reshape' that carries no certificates
-- and consists of a single 'DimCoercion' counts as vacuous and is
-- skipped.  Any other remaining transformation yields 'Nothing'.
isVarishInput :: Input -> Maybe VName
isVarishInput (Input ts v t) =
  case viewf ts of
    EmptyF -> Just v
    Reshape cs [DimCoercion _] :< rest
      | cs == mempty -> isVarishInput (Input rest v t)
    _ -> Nothing

-- | Add a transformation to the end of the transformation list, using
-- '|>' so that the list stays normalised.  The array variable and its
-- recorded type are kept as they are.
addTransform :: ArrayTransform -> Input -> Input
addTransform tr (Input trs a t) = Input extended a t
  where
    extended = trs |> tr

-- | Add several transformations to the start of the transformation
-- list.  The new ones are applied before the existing ones, and they
-- are joined with '<>' so that the seam is normalised.
addInitialTransforms :: ArrayTransforms -> Input -> Input
addInitialTransforms ts (Input ots a t) =
  let combined = ts <> ots
   in Input combined a t

-- | Convert SOAC inputs to the corresponding expressions.
inputsToSubExps ::
  (MonadBinder m) =>
  [Input] ->
  m [VName]
inputsToSubExps :: [Input] -> m [VName]
inputsToSubExps = (Input -> m VName) -> [Input] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Input -> m VName
forall (m :: * -> *). MonadBinder m => Input -> m VName
inputToExp'
  where
    inputToExp' :: Input -> m VName
inputToExp' (Input (ArrayTransforms Seq ArrayTransform
ts) VName
a Type
_) =
      (VName -> ArrayTransform -> m VName)
-> VName -> Seq ArrayTransform -> m VName
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldlM VName -> ArrayTransform -> m VName
forall (m :: * -> *).
MonadBinder m =>
VName -> ArrayTransform -> m VName
transform VName
a Seq ArrayTransform
ts

    transform :: VName -> ArrayTransform -> m VName
transform VName
ia (Replicate Certificates
cs Shape
n) =
      Certificates -> m VName -> m VName
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs (m VName -> m VName) -> m VName -> m VName
forall a b. (a -> b) -> a -> b
$
        String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"repeat" (Exp (Lore m) -> m VName) -> Exp (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Futhark.Replicate Shape
n (VName -> SubExp
Futhark.Var VName
ia)
    transform VName
ia (Rearrange Certificates
cs [Int]
perm) =
      Certificates -> m VName -> m VName
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs (m VName -> m VName) -> m VName -> m VName
forall a b. (a -> b) -> a -> b
$
        String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"rearrange" (Exp (Lore m) -> m VName) -> Exp (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp
Futhark.Rearrange [Int]
perm VName
ia
    transform VName
ia (Reshape Certificates
cs ShapeChange SubExp
shape) =
      Certificates -> m VName -> m VName
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs (m VName -> m VName) -> m VName -> m VName
forall a b. (a -> b) -> a -> b
$
        String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"reshape" (Exp (Lore m) -> m VName) -> Exp (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> VName -> BasicOp
Futhark.Reshape ShapeChange SubExp
shape VName
ia
    transform VName
ia (ReshapeOuter Certificates
cs ShapeChange SubExp
shape) = do
      ShapeChange SubExp
shape' <- ShapeChange SubExp -> Int -> Shape -> ShapeChange SubExp
reshapeOuter ShapeChange SubExp
shape Int
1 (Shape -> ShapeChange SubExp)
-> (Type -> Shape) -> Type -> ShapeChange SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape (Type -> ShapeChange SubExp) -> m Type -> m (ShapeChange SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
ia
      Certificates -> m VName -> m VName
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs (m VName -> m VName) -> m VName -> m VName
forall a b. (a -> b) -> a -> b
$
        String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"reshape_outer" (Exp (Lore m) -> m VName) -> Exp (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> VName -> BasicOp
Futhark.Reshape ShapeChange SubExp
shape' VName
ia
    transform VName
ia (ReshapeInner Certificates
cs ShapeChange SubExp
shape) = do
      ShapeChange SubExp
shape' <- ShapeChange SubExp -> Int -> Shape -> ShapeChange SubExp
reshapeInner ShapeChange SubExp
shape Int
1 (Shape -> ShapeChange SubExp)
-> (Type -> Shape) -> Type -> ShapeChange SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape (Type -> ShapeChange SubExp) -> m Type -> m (ShapeChange SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
ia
      Certificates -> m VName -> m VName
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs (m VName -> m VName) -> m VName -> m VName
forall a b. (a -> b) -> a -> b
$
        String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"reshape_inner" (Exp (Lore m) -> m VName) -> Exp (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> VName -> BasicOp
Futhark.Reshape ShapeChange SubExp
shape' VName
ia

-- | Return the array name of the input.  Pending transformations are
-- ignored: this is the name of the variable they would be applied to.
inputArray :: Input -> VName
inputArray input =
  case input of
    Input _ arrName _ -> arrName

-- | Return the type of an input.
--
-- This is the type of the underlying array after every transformation
-- in its 'ArrayTransforms' has been applied, in list order (first
-- transform first).
inputType :: Input -> Type
inputType (Input (ArrayTransforms ts) _ at) =
  -- Strict left fold (foldl' rather than foldl), so the intermediate
  -- types are not accumulated as a chain of thunks.
  Foldable.foldl' transformType at ts
  where
    transformType :: Type -> ArrayTransform -> Type
    transformType t (Replicate _ shape) =
      arrayOfShape t shape
    transformType t (Rearrange _ perm) =
      rearrangeType perm t
    -- A full reshape replaces the entire shape.
    transformType t (Reshape _ shape) =
      t `setArrayShape` newShape shape
    -- Replace the outermost dimension by the new dimensions, keeping
    -- all the inner ones.
    transformType t (ReshapeOuter _ shape) =
      let Shape oldshape = arrayShape t
       in t `setArrayShape` Shape (newDims shape ++ drop 1 oldshape)
    -- Keep the outermost dimension, replacing all the inner ones by
    -- the new dimensions.
    transformType t (ReshapeInner _ shape) =
      let Shape oldshape = arrayShape t
       in t `setArrayShape` Shape (take 1 oldshape ++ newDims shape)

-- | Return the row type of an input: the 'rowType' of its
-- (transformed) 'inputType'.  Just a convenient alias.
inputRowType :: Input -> Type
inputRowType inp = rowType (inputType inp)

-- | Return the array rank (dimensionality) of an input, taking its
-- transformations into account.  Just a convenient alias for the
-- 'arrayRank' of its 'inputType'.
inputRank :: Input -> Int
inputRank inp = arrayRank $ inputType inp

-- | Apply the transformations to every row of the input.
--
-- Each transformation in the given list is lifted so that it acts on
-- the rows (the elements along the outermost dimension) of the input,
-- and is then appended to the input's own transformations.  Only
-- 'Rearrange', 'Reshape' and 'Replicate' can currently be lifted; any
-- other transformation raises an internal error.
transformRows :: ArrayTransforms -> Input -> Input
transformRows (ArrayTransforms ts) =
  -- Strict left fold over the transforms, threading the input through.
  flip (Foldable.foldl' transformRows') ts
  where
    transformRows' :: Input -> ArrayTransform -> Input
    -- Shift the permutation one dimension inwards, leaving the
    -- outermost dimension in place.
    transformRows' inp (Rearrange cs perm) =
      addTransform (Rearrange cs (0 : map (+ 1) perm)) inp
    -- Reshaping every row is a reshape of the inner dimensions.
    transformRows' inp (Reshape cs shape) =
      addTransform (ReshapeInner cs shape) inp
    -- Replicating every row: replicate the whole array, then permute so
    -- the original outer dimension comes first again.
    -- NOTE(review): the permutations below assume @n@ has exactly one
    -- dimension; a multi-dimensional replicate would give the wrong
    -- permutation length.
    transformRows' inp (Replicate cs n)
      | inputRank inp == 1 =
          Rearrange mempty [1, 0]
            `addTransform` (Replicate cs n `addTransform` inp)
      | otherwise =
          Rearrange mempty (2 : 0 : 1 : [3 .. inputRank inp])
            `addTransform` ( Replicate cs n
                               `addTransform` ( Rearrange mempty (1 : 0 : [2 .. inputRank inp - 1])
                                                  `addTransform` inp
                                              )
                           )
    transformRows' inp nts =
      error $
        "transformRows: Cannot transform this yet:\n"
          ++ show nts
          ++ "\n"
          ++ show inp

-- | Add to the input a 'Rearrange' transform that performs an @(k,n)@
-- transposition.  The new transform will be at the end of the current
-- transformation list.
--
-- The permutation is obtained by applying 'transposeIndex' to the
-- identity permutation of the input's current rank.
transposeInput :: Int -> Int -> Input -> Input
transposeInput k n inp = transposition `addTransform` inp
  where
    identityPerm = [0 .. inputRank inp - 1]
    transposition = Rearrange mempty (transposeIndex k n identityPerm)

-- | A definite representation of a SOAC expression.  Unlike the
-- Futhark-level SOACs, the array arguments are 'Input's, which may
-- carry pending index-space transformations.
data SOAC lore
  = -- | Width, stream form, lambda, initial accumulator values, and
    -- inputs.
    Stream SubExp (StreamForm lore) (Lambda lore) [SubExp] [Input]
  | -- | Width, lambda, inputs, and destinations (each a shape, a
    -- count, and the destination array).
    Scatter SubExp (Lambda lore) [Input] [(Shape, Int, VName)]
  | -- | Width, Screma form, and inputs.
    Screma SubExp (ScremaForm lore) [Input]
  | -- | Width, histogram operators, bucket function, and inputs.
    Hist SubExp [HistOp lore] (Lambda lore) [Input]
  deriving (Eq, Show)

-- Renders an input as nested applications of its transformations to
-- the array name; the first transform in the list is applied directly
-- to the array, and each later one wraps the previous result.
instance PP.Pretty Input where
  ppr (Input (ArrayTransforms ts) arr _) =
    foldl wrapTransform (ppr arr) ts
    where
      -- Render one transform applied to the already-rendered expression.
      wrapTransform inner tr =
        case tr of
          Rearrange cs perm -> call "rearrange" cs (PP.apply (map ppr perm)) inner
          Reshape cs shape -> call "reshape" cs (dims shape) inner
          ReshapeOuter cs shape -> call "reshape_outer" cs (dims shape) inner
          ReshapeInner cs shape -> call "reshape_inner" cs (dims shape) inner
          Replicate cs ne -> call "replicate" cs (ppr ne) inner
      -- A parenthesised, comma-separated list of dimension changes.
      dims shape = PP.apply (map ppr shape)
      -- @name<certs>(arg, inner)@.
      call name cs arg inner = text name <> ppr cs <> PP.apply [arg, inner]

-- 'Screma' and 'Hist' reuse the Futhark-level pretty printers; the
-- remaining SOACs fall back to their 'Show' representation.
instance PrettyLore lore => PP.Pretty (SOAC lore) where
  ppr soac =
    case soac of
      Screma w form arrs -> Futhark.ppScrema w arrs form
      Hist w ops bucket_fun imgs -> Futhark.ppHist w ops bucket_fun imgs
      _ -> text (show soac)

-- | Returns the inputs used in a SOAC.  For a 'Scatter' these are the
-- inputs to its lambda, not its destination arrays.
inputs :: SOAC lore -> [Input]
inputs soac =
  case soac of
    Stream _ _ _ _ inps -> inps
    Scatter _ _ inps _ -> inps
    Screma _ _ inps -> inps
    Hist _ _ _ inps -> inps

-- | Set the inputs to a SOAC.
--
-- For 'Stream' and 'Scatter' the width is recomputed from the new
-- inputs with 'newWidth'; 'Screma' and 'Hist' keep their existing
-- width.
setInputs :: [Input] -> SOAC lore -> SOAC lore
setInputs new soac =
  case soac of
    Stream w form lam nes _ -> Stream (widthFor w) form lam nes new
    Scatter w lam _ dests -> Scatter (widthFor w) lam new dests
    Screma w form _ -> Screma w form new
    Hist w ops lam _ -> Hist w ops lam new
  where
    widthFor = newWidth new

-- | The width to use after replacing a SOAC's inputs with the given
-- list: the outermost size (dimension 0) of the first input's
-- transformed type, or the old width when the list is empty.
newWidth :: [Input] -> SubExp -> SubExp
newWidth inps oldWidth =
  maybe oldWidth (arraySize 0 . inputType) (listToMaybe inps)

-- | The lambda used in a given SOAC.  For a 'Screma' this is the lambda
-- stored in its 'ScremaForm' (after the scans and reductions); for a
-- 'Hist' it is the bucket function.
lambda :: SOAC lore -> Lambda lore
lambda soac =
  case soac of
    Stream _ _ lam _ _ -> lam
    Scatter _ lam _ _ -> lam
    Screma _ (ScremaForm _ _ lam) _ -> lam
    Hist _ _ lam _ -> lam

-- | Set the lambda used in the SOAC.  Every other component, including
-- the width and the inputs, is left unchanged.
setLambda :: Lambda lore -> SOAC lore -> SOAC lore
setLambda lam soac =
  case soac of
    Stream w form _ nes arrs ->
      Stream w form lam nes arrs
    Scatter w _ ivs dests ->
      Scatter w lam ivs dests
    Screma w (ScremaForm scans reds _) arrs ->
      Screma w (ScremaForm scans reds lam) arrs
    Hist w ops _ inps ->
      Hist w ops lam inps

-- | The return type of a SOAC.
typeOf :: SOAC lore -> [Type]
typeOf (Stream w _ lam nes _) =
  -- The first @length nes@ lambda results (the accumulators) keep their
  -- type; for the remaining results the outer dimension is replaced by
  -- the full width @w@.
  accTypes ++ map fullArray chunkTypes
  where
    (accTypes, chunkTypes) = splitAt (length nes) (lambdaReturnType lam)
    fullArray t = arrayOf (stripArray 1 t) (Shape [w]) NoUniqueness
typeOf (Scatter _ lam _ dests) =
  zipWith arrayOfShape valueTypes [shape | (shape, _, _) <- dests]
  where
    -- The lambda's results begin with the index values (@n@ times the
    -- rank of @shape@ for each destination), followed by the values to
    -- write; only the latter determine the result types.
    numIndices = sum [n * length shape | (shape, n, _) <- dests]
    valueTypes = drop numIndices (lambdaReturnType lam)
typeOf (Screma w form _) =
  scremaType w form
typeOf (Hist _ ops _ _) =
  [ t `arrayOfRow` histWidth op
    | op <- ops,
      t <- lambdaReturnType (histOp op)
  ]

-- | The "width" of a SOAC is the expected outer size of its array
-- inputs _after_ input-transforms have been carried out.  It is the
-- first component of every constructor.
width :: SOAC lore -> SubExp
width soac =
  case soac of
    Stream w _ _ _ _ -> w
    Scatter w _ _ _ -> w
    Screma w _ _ -> w
    Hist w _ _ _ -> w

-- | Convert a SOAC to the corresponding expression: the Futhark-level
-- SOAC produced by 'toSOAC', wrapped in an 'Op'.  Any bindings that
-- 'toSOAC' adds for input transformations are emitted in @m@.
toExp ::
  (MonadBinder m, Op (Lore m) ~ Futhark.SOAC (Lore m)) =>
  SOAC (Lore m) ->
  m (Exp (Lore m))
toExp = fmap Op . toSOAC

-- | Convert a SOAC to a Futhark-level SOAC.
--
-- The 'Input's are turned into plain array variables with
-- 'inputsToSubExps', which may add bindings for their transformations.
-- All other components are carried over unchanged.
toSOAC ::
  MonadBinder m =>
  SOAC (Lore m) ->
  m (Futhark.SOAC (Lore m))
toSOAC soac =
  case soac of
    Stream w form lam nes inps -> do
      arrs <- inputsToSubExps inps
      pure $ Futhark.Stream w arrs form nes lam
    Scatter w lam ivs dests -> do
      ivs' <- inputsToSubExps ivs
      pure $ Futhark.Scatter w lam ivs' dests
    Screma w form inps -> do
      arrs <- inputsToSubExps inps
      pure $ Futhark.Screma w arrs form
    Hist w ops lam inps -> do
      arrs <- inputsToSubExps inps
      pure $ Futhark.Hist w ops lam arrs

-- | The reason why some expression cannot be converted to a 'SOAC'
-- value, as reported by 'fromExp'.
data NotSOAC
  = NotSOAC
  -- ^ The expression is not a (tuple-)SOAC at all.
  deriving (Show)

-- | Either convert an expression to the normalised SOAC
-- representation, or a reason why the expression does not have the
-- valid form.
--
-- Every array argument becomes an untransformed 'Input' via 'varInput'
-- (which looks up its type in the scope).  Any expression that is not
-- one of the supported SOAC 'Op's yields @Left NotSOAC@.
fromExp ::
  (Op lore ~ Futhark.SOAC lore, HasScope lore m) =>
  Exp lore ->
  m (Either NotSOAC (SOAC lore))
fromExp e =
  case e of
    Op (Futhark.Stream w arrs form nes lam) ->
      soacFrom (Stream w form lam nes) arrs
    Op (Futhark.Scatter w lam ivs dests) ->
      soacFrom (\inps -> Scatter w lam inps dests) ivs
    Op (Futhark.Screma w arrs form) ->
      soacFrom (Screma w form) arrs
    Op (Futhark.Hist w ops lam arrs) ->
      soacFrom (Hist w ops lam) arrs
    _ -> pure (Left NotSOAC)
  where
    -- Only Applicative operations are used here, matching the
    -- constraints available through 'HasScope'.
    soacFrom mk arrs = Right . mk <$> traverse varInput arrs

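-- | To-Stream translation of SOACs.
--   Returns the resulting Stream SOAC together with the identifiers of
--   any extra accumulator parameters introduced in its body (possibly
--   empty).  A SOAC that cannot be turned into a stream is returned
--   unchanged, paired with an empty list.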
soacToStream ::
  (MonadFreshNames m, Bindable lore, Op lore ~ Futhark.SOAC lore) =>
  SOAC lore ->
  m (SOAC lore, [Ident])
soacToStream :: SOAC lore -> m (SOAC lore, [Ident])
soacToStream SOAC lore
soac = do
  Param Type
chunk_param <- String -> Type -> m (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"chunk" (Type -> m (Param Type)) -> Type -> m (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  let chvar :: SubExp
chvar = VName -> SubExp
Futhark.Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
chunk_param
      (Lambda lore
lam, [Input]
inps) = (SOAC lore -> Lambda lore
forall lore. SOAC lore -> Lambda lore
lambda SOAC lore
soac, SOAC lore -> [Input]
forall lore. SOAC lore -> [Input]
inputs SOAC lore
soac)
      w :: SubExp
w = SOAC lore -> SubExp
forall lore. SOAC lore -> SubExp
width SOAC lore
soac
  Lambda lore
lam' <- Lambda lore -> m (Lambda lore)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda Lambda lore
lam
  let arrrtps :: [Type]
arrrtps = SubExp -> Lambda lore -> [Type]
forall lore. SubExp -> Lambda lore -> [Type]
mapType SubExp
w Lambda lore
lam
      -- the chunked-outersize of the array result and input types
      loutps :: [Type]
loutps = [Type -> SubExp -> Type
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
arrayOfRow Type
t SubExp
chvar | Type
t <- (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType [Type]
arrrtps]
      lintps :: [Type]
lintps = [Type -> SubExp -> Type
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
arrayOfRow Type
t SubExp
chvar | Type
t <- (Input -> Type) -> [Input] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Input -> Type
inputRowType [Input]
inps]

  [Param Type]
strm_inpids <- (Type -> m (Param Type)) -> [Type] -> m [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Type -> m (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"inp") [Type]
lintps
  -- Treat each SOAC case individually:
  case SOAC lore
soac of
    Screma SubExp
_ ScremaForm lore
form [Input]
_
      | Just Lambda lore
_ <- ScremaForm lore -> Maybe (Lambda lore)
forall lore. ScremaForm lore -> Maybe (Lambda lore)
Futhark.isMapSOAC ScremaForm lore
form -> do
        -- Map(f,a) => is translated in strem's body to:
        -- let strm_resids = map(f,a_ch) in strm_resids
        --
        -- array result and input IDs of the stream's lambda
        [Ident]
strm_resids <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"res") [Type]
loutps
        let insoac :: SOAC lore
insoac =
              SubExp -> [VName] -> ScremaForm lore -> SOAC lore
forall lore. SubExp -> [VName] -> ScremaForm lore -> SOAC lore
Futhark.Screma SubExp
chvar ((Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
strm_inpids) (ScremaForm lore -> SOAC lore) -> ScremaForm lore -> SOAC lore
forall a b. (a -> b) -> a -> b
$
                Lambda lore -> ScremaForm lore
forall lore. Lambda lore -> ScremaForm lore
Futhark.mapSOAC Lambda lore
lam'
            insbnd :: Stm lore
insbnd = [Ident] -> [Ident] -> Exp lore -> Stm lore
forall lore.
Bindable lore =>
[Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet [] [Ident]
strm_resids (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ Op lore -> Exp lore
forall lore. Op lore -> ExpT lore
Op Op lore
SOAC lore
insoac
            strmbdy :: Body lore
strmbdy = Stms lore -> [SubExp] -> Body lore
forall lore. Bindable lore => Stms lore -> [SubExp] -> Body lore
mkBody (Stm lore -> Stms lore
forall lore. Stm lore -> Stms lore
oneStm Stm lore
insbnd) ([SubExp] -> Body lore) -> [SubExp] -> Body lore
forall a b. (a -> b) -> a -> b
$ (Ident -> SubExp) -> [Ident] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Futhark.Var (VName -> SubExp) -> (Ident -> VName) -> Ident -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Ident -> VName
identName) [Ident]
strm_resids
            strmpar :: [Param Type]
strmpar = Param Type
chunk_param Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
strm_inpids
            strmlam :: Lambda lore
strmlam = [LParam lore] -> Body lore -> [Type] -> Lambda lore
forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda [Param Type]
[LParam lore]
strmpar Body lore
strmbdy [Type]
loutps
            empty_lam :: Lambda lore
empty_lam = [LParam lore] -> Body lore -> [Type] -> Lambda lore
forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda [] (Stms lore -> [SubExp] -> Body lore
forall lore. Bindable lore => Stms lore -> [SubExp] -> Body lore
mkBody Stms lore
forall a. Monoid a => a
mempty []) []
        -- map(f,a) creates a stream with NO accumulators
        (SOAC lore, [Ident]) -> m (SOAC lore, [Ident])
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
-> StreamForm lore
-> Lambda lore
-> [SubExp]
-> [Input]
-> SOAC lore
forall lore.
SubExp
-> StreamForm lore
-> Lambda lore
-> [SubExp]
-> [Input]
-> SOAC lore
Stream SubExp
w (StreamOrd -> Commutativity -> Lambda lore -> StreamForm lore
forall lore.
StreamOrd -> Commutativity -> Lambda lore -> StreamForm lore
Parallel StreamOrd
Disorder Commutativity
Commutative Lambda lore
empty_lam) Lambda lore
strmlam [] [Input]
inps, [])
      | Just ([Scan lore]
scans, Lambda lore
_) <- ScremaForm lore -> Maybe ([Scan lore], Lambda lore)
forall lore. ScremaForm lore -> Maybe ([Scan lore], Lambda lore)
Futhark.isScanomapSOAC ScremaForm lore
form,
        Futhark.Scan Lambda lore
scan_lam [SubExp]
nes <- [Scan lore] -> Scan lore
forall lore. Bindable lore => [Scan lore] -> Scan lore
Futhark.singleScan [Scan lore]
scans -> do
        -- scanomap(scan_lam,nes,map_lam,a) => is translated in strem's body to:
        -- 1. let (scan0_ids,map_resids)   = scanomap(scan_lam, nes, map_lam, a_ch)
        -- 2. let strm_resids = map (acc `+`,nes, scan0_ids)
        -- 3. let outerszm1id = sizeof(0,strm_resids) - 1
        -- 4. let lasteel_ids = if outerszm1id < 0
        --                      then nes
        --                      else strm_resids[outerszm1id]
        -- 5. let acc'        = acc + lasteel_ids
        --    {acc', strm_resids, map_resids}
        -- the array and accumulator result types
        let scan_arr_ts :: [Type]
scan_arr_ts = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> SubExp -> Type
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
`arrayOfRow` SubExp
chvar) ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda lore -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda lore
scan_lam
            map_arr_ts :: [Type]
map_arr_ts = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) [Type]
loutps
            accrtps :: [Type]
accrtps = Lambda lore -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda lore
scan_lam

        -- array result and input IDs of the stream's lambda
        [Ident]
strm_resids <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"res") [Type]
scan_arr_ts
        [Ident]
scan0_ids <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"resarr0") [Type]
scan_arr_ts
        [Ident]
map_resids <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"map_res") [Type]
map_arr_ts

        [Ident]
lastel_ids <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"lstel") [Type]
accrtps
        [Ident]
lastel_tmp_ids <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"lstel_tmp") [Type]
accrtps
        Ident
empty_arr <- String -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"empty_arr" (Type -> m Ident) -> Type -> m Ident
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Bool
        [Param Type]
inpacc_ids <- (Type -> m (Param Type)) -> [Type] -> m [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Type -> m (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"inpacc") [Type]
accrtps
        Ident
outszm1id <- String -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"szm1" (Type -> m Ident) -> Type -> m Ident
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
        -- 1. let (scan0_ids,map_resids)  = scanomap(scan_lam,nes,map_lam,a_ch)
        let insbnd :: Stm lore
insbnd =
              [Ident] -> [Ident] -> Exp lore -> Stm lore
forall lore.
Bindable lore =>
[Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet [] ([Ident]
scan0_ids [Ident] -> [Ident] -> [Ident]
forall a. [a] -> [a] -> [a]
++ [Ident]
map_resids) (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$
                Op lore -> Exp lore
forall lore. Op lore -> ExpT lore
Op (Op lore -> Exp lore) -> Op lore -> Exp lore
forall a b. (a -> b) -> a -> b
$
                  SubExp -> [VName] -> ScremaForm lore -> SOAC lore
forall lore. SubExp -> [VName] -> ScremaForm lore -> SOAC lore
Futhark.Screma SubExp
chvar ((Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
strm_inpids) (ScremaForm lore -> SOAC lore) -> ScremaForm lore -> SOAC lore
forall a b. (a -> b) -> a -> b
$
                    [Scan lore] -> Lambda lore -> ScremaForm lore
forall lore. [Scan lore] -> Lambda lore -> ScremaForm lore
Futhark.scanomapSOAC [Lambda lore -> [SubExp] -> Scan lore
forall lore. Lambda lore -> [SubExp] -> Scan lore
Futhark.Scan Lambda lore
scan_lam [SubExp]
nes] Lambda lore
lam'
            -- 2. let outerszm1id = chunksize - 1
            outszm1bnd :: Stm lore
outszm1bnd =
              [Ident] -> [Ident] -> Exp lore -> Stm lore
forall lore.
Bindable lore =>
[Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet [] [Ident
outszm1id] (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$
                BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$
                  BinOp -> SubExp -> SubExp -> BasicOp
BinOp
                    (IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowUndef)
                    (VName -> SubExp
Futhark.Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
chunk_param)
                    (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
1 :: Int64))
            -- 3. let lasteel_ids = ...
            empty_arr_bnd :: Stm lore
empty_arr_bnd =
              [Ident] -> [Ident] -> Exp lore -> Stm lore
forall lore.
Bindable lore =>
[Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet [] [Ident
empty_arr] (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$
                BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$
                  CmpOp -> SubExp -> SubExp -> BasicOp
CmpOp
                    (IntType -> CmpOp
CmpSlt IntType
Int64)
                    (VName -> SubExp
Futhark.Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Ident -> VName
identName Ident
outszm1id)
                    (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64))
            leltmpbnds :: [Stm lore]
leltmpbnds =
              (Ident -> Ident -> Stm lore) -> [Ident] -> [Ident] -> [Stm lore]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
                ( \Ident
lid Ident
arrid ->
                    [Ident] -> [Ident] -> Exp lore -> Stm lore
forall lore.
Bindable lore =>
[Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet [] [Ident
lid] (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$
                      BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$
                        VName -> Slice SubExp -> BasicOp
Index (Ident -> VName
identName Ident
arrid) (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
                          Type -> Slice SubExp -> Slice SubExp
fullSlice
                            (Ident -> Type
identType Ident
arrid)
                            [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Futhark.Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Ident -> VName
identName Ident
outszm1id]
                )
                [Ident]
lastel_tmp_ids
                [Ident]
scan0_ids
            lelbnd :: Stm lore
lelbnd =
              [Ident] -> [Ident] -> Exp lore -> Stm lore
forall lore.
Bindable lore =>
[Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet [] [Ident]
lastel_ids (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$
                SubExp
-> Body lore -> Body lore -> IfDec (BranchType lore) -> Exp lore
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If
                  (VName -> SubExp
Futhark.Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Ident -> VName
identName Ident
empty_arr)
                  (Stms lore -> [SubExp] -> Body lore
forall lore. Bindable lore => Stms lore -> [SubExp] -> Body lore
mkBody Stms lore
forall a. Monoid a => a
mempty [SubExp]
nes)
                  ( Stms lore -> [SubExp] -> Body lore
forall lore. Bindable lore => Stms lore -> [SubExp] -> Body lore
mkBody ([Stm lore] -> Stms lore
forall lore. [Stm lore] -> Stms lore
stmsFromList [Stm lore]
leltmpbnds) ([SubExp] -> Body lore) -> [SubExp] -> Body lore
forall a b. (a -> b) -> a -> b
$
                      (Ident -> SubExp) -> [Ident] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Futhark.Var (VName -> SubExp) -> (Ident -> VName) -> Ident -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Ident -> VName
identName) [Ident]
lastel_tmp_ids
                  )
                  (IfDec (BranchType lore) -> Exp lore)
-> IfDec (BranchType lore) -> Exp lore
forall a b. (a -> b) -> a -> b
$ [Type] -> IfDec ExtType
ifCommon ([Type] -> IfDec ExtType) -> [Type] -> IfDec ExtType
forall a b. (a -> b) -> a -> b
$ (Ident -> Type) -> [Ident] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> Type
identType [Ident]
lastel_tmp_ids
        -- 4. let strm_resids = map (acc `+`,nes, scan0_ids)
        Lambda lore
maplam <- [SubExp] -> Lambda lore -> m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, Bindable lore) =>
[SubExp] -> Lambda lore -> m (Lambda lore)
mkMapPlusAccLam ((Param Type -> SubExp) -> [Param Type] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Futhark.Var (VName -> SubExp) -> (Param Type -> VName) -> Param Type -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName) [Param Type]
inpacc_ids) Lambda lore
scan_lam
        let mapbnd :: Stm lore
mapbnd =
              [Ident] -> [Ident] -> Exp lore -> Stm lore
forall lore.
Bindable lore =>
[Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet [] [Ident]
strm_resids (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$
                Op lore -> Exp lore
forall lore. Op lore -> ExpT lore
Op (Op lore -> Exp lore) -> Op lore -> Exp lore
forall a b. (a -> b) -> a -> b
$
                  SubExp -> [VName] -> ScremaForm lore -> SOAC lore
forall lore. SubExp -> [VName] -> ScremaForm lore -> SOAC lore
Futhark.Screma
                    SubExp
chvar
                    ((Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
scan0_ids)
                    (Lambda lore -> ScremaForm lore
forall lore. Lambda lore -> ScremaForm lore
Futhark.mapSOAC Lambda lore
maplam)
        -- 5. let acc'        = acc + lasteel_ids
        Body lore
addlelbdy <-
          Lambda lore -> [SubExp] -> m (Body lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, Bindable lore) =>
Lambda lore -> [SubExp] -> m (Body lore)
mkPlusBnds Lambda lore
scan_lam ([SubExp] -> m (Body lore)) -> [SubExp] -> m (Body lore)
forall a b. (a -> b) -> a -> b
$
            (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Futhark.Var ([VName] -> [SubExp]) -> [VName] -> [SubExp]
forall a b. (a -> b) -> a -> b
$
              (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
inpacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ (Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
lastel_ids
        -- Finally, construct the stream
        let (Stms lore
addlelbnd, [SubExp]
addlelres) = (Body lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms Body lore
addlelbdy, Body lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult Body lore
addlelbdy)
            strmbdy :: Body lore
strmbdy =
              Stms lore -> [SubExp] -> Body lore
forall lore. Bindable lore => Stms lore -> [SubExp] -> Body lore
mkBody ([Stm lore] -> Stms lore
forall lore. [Stm lore] -> Stms lore
stmsFromList [Stm lore
insbnd, Stm lore
outszm1bnd, Stm lore
empty_arr_bnd, Stm lore
lelbnd, Stm lore
mapbnd] Stms lore -> Stms lore -> Stms lore
forall a. Semigroup a => a -> a -> a
<> Stms lore
addlelbnd) ([SubExp] -> Body lore) -> [SubExp] -> Body lore
forall a b. (a -> b) -> a -> b
$
                [SubExp]
addlelres [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ (Ident -> SubExp) -> [Ident] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Futhark.Var (VName -> SubExp) -> (Ident -> VName) -> Ident -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Ident -> VName
identName) ([Ident]
strm_resids [Ident] -> [Ident] -> [Ident]
forall a. [a] -> [a] -> [a]
++ [Ident]
map_resids)
            strmpar :: [Param Type]
strmpar = Param Type
chunk_param Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
inpacc_ids [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
strm_inpids
            strmlam :: Lambda lore
strmlam = [LParam lore] -> Body lore -> [Type] -> Lambda lore
forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda [Param Type]
[LParam lore]
strmpar Body lore
strmbdy ([Type]
accrtps [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
loutps)
        (SOAC lore, [Ident]) -> m (SOAC lore, [Ident])
forall (m :: * -> *) a. Monad m => a -> m a
return
          ( SubExp
-> StreamForm lore
-> Lambda lore
-> [SubExp]
-> [Input]
-> SOAC lore
forall lore.
SubExp
-> StreamForm lore
-> Lambda lore
-> [SubExp]
-> [Input]
-> SOAC lore
Stream SubExp
w StreamForm lore
forall lore. StreamForm lore
Sequential Lambda lore
strmlam [SubExp]
nes [Input]
inps,
            (Param Type -> Ident) -> [Param Type] -> [Ident]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Ident
forall dec. Typed dec => Param dec -> Ident
paramIdent [Param Type]
inpacc_ids
          )
      | Just ([Reduce lore]
reds, Lambda lore
_) <- ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
forall lore. ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
Futhark.isRedomapSOAC ScremaForm lore
form,
        Futhark.Reduce Commutativity
comm Lambda lore
lamin [SubExp]
nes <- [Reduce lore] -> Reduce lore
forall lore. Bindable lore => [Reduce lore] -> Reduce lore
Futhark.singleReduce [Reduce lore]
reds -> do
        -- Redomap(+,lam,nes,a) => is translated in strem's body to:
        -- 1. let (acc0_ids,strm_resids) = redomap(+,lam,nes,a_ch) in
        -- 2. let acc'                   = acc + acc0_ids          in
        --    {acc', strm_resids}

        let accrtps :: [Type]
accrtps = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda lore -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda lore
lam
            -- the chunked-outersize of the array result and input types
            loutps' :: [Type]
loutps' = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) [Type]
loutps
            -- the lambda with proper index
            foldlam :: Lambda lore
foldlam = Lambda lore
lam'
        -- array result and input IDs of the stream's lambda
        [Ident]
strm_resids <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"res") [Type]
loutps'
        [Param Type]
inpacc_ids <- (Type -> m (Param Type)) -> [Type] -> m [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Type -> m (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"inpacc") [Type]
accrtps
        [Ident]
acc0_ids <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"acc0") [Type]
accrtps
        -- 1. let (acc0_ids,strm_resids) = redomap(+,lam,nes,a_ch) in
        let insoac :: SOAC lore
insoac =
              SubExp -> [VName] -> ScremaForm lore -> SOAC lore
forall lore. SubExp -> [VName] -> ScremaForm lore -> SOAC lore
Futhark.Screma
                SubExp
chvar
                ((Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
strm_inpids)
                (ScremaForm lore -> SOAC lore) -> ScremaForm lore -> SOAC lore
forall a b. (a -> b) -> a -> b
$ [Reduce lore] -> Lambda lore -> ScremaForm lore
forall lore. [Reduce lore] -> Lambda lore -> ScremaForm lore
Futhark.redomapSOAC [Commutativity -> Lambda lore -> [SubExp] -> Reduce lore
forall lore.
Commutativity -> Lambda lore -> [SubExp] -> Reduce lore
Futhark.Reduce Commutativity
comm Lambda lore
lamin [SubExp]
nes] Lambda lore
foldlam
            insbnd :: Stm lore
insbnd = [Ident] -> [Ident] -> Exp lore -> Stm lore
forall lore.
Bindable lore =>
[Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet [] ([Ident]
acc0_ids [Ident] -> [Ident] -> [Ident]
forall a. [a] -> [a] -> [a]
++ [Ident]
strm_resids) (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ Op lore -> Exp lore
forall lore. Op lore -> ExpT lore
Op Op lore
SOAC lore
insoac
        -- 2. let acc'     = acc + acc0_ids    in
        Body lore
addaccbdy <-
          Lambda lore -> [SubExp] -> m (Body lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, Bindable lore) =>
Lambda lore -> [SubExp] -> m (Body lore)
mkPlusBnds Lambda lore
lamin ([SubExp] -> m (Body lore)) -> [SubExp] -> m (Body lore)
forall a b. (a -> b) -> a -> b
$
            (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Futhark.Var ([VName] -> [SubExp]) -> [VName] -> [SubExp]
forall a b. (a -> b) -> a -> b
$
              (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
inpacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ (Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
acc0_ids
        -- Construct the stream
        let (Stms lore
addaccbnd, [SubExp]
addaccres) = (Body lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms Body lore
addaccbdy, Body lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult Body lore
addaccbdy)
            strmbdy :: Body lore
strmbdy =
              Stms lore -> [SubExp] -> Body lore
forall lore. Bindable lore => Stms lore -> [SubExp] -> Body lore
mkBody (Stm lore -> Stms lore
forall lore. Stm lore -> Stms lore
oneStm Stm lore
insbnd Stms lore -> Stms lore -> Stms lore
forall a. Semigroup a => a -> a -> a
<> Stms lore
addaccbnd) ([SubExp] -> Body lore) -> [SubExp] -> Body lore
forall a b. (a -> b) -> a -> b
$
                [SubExp]
addaccres [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ (Ident -> SubExp) -> [Ident] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Futhark.Var (VName -> SubExp) -> (Ident -> VName) -> Ident -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Ident -> VName
identName) [Ident]
strm_resids
            strmpar :: [Param Type]
strmpar = Param Type
chunk_param Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
inpacc_ids [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
strm_inpids
            strmlam :: Lambda lore
strmlam = [LParam lore] -> Body lore -> [Type] -> Lambda lore
forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda [Param Type]
[LParam lore]
strmpar Body lore
strmbdy ([Type]
accrtps [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
loutps')
        Lambda lore
lam0 <- Lambda lore -> m (Lambda lore)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda Lambda lore
lamin
        (SOAC lore, [Ident]) -> m (SOAC lore, [Ident])
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
-> StreamForm lore
-> Lambda lore
-> [SubExp]
-> [Input]
-> SOAC lore
forall lore.
SubExp
-> StreamForm lore
-> Lambda lore
-> [SubExp]
-> [Input]
-> SOAC lore
Stream SubExp
w (StreamOrd -> Commutativity -> Lambda lore -> StreamForm lore
forall lore.
StreamOrd -> Commutativity -> Lambda lore -> StreamForm lore
Parallel StreamOrd
InOrder Commutativity
comm Lambda lore
lam0) Lambda lore
strmlam [SubExp]
nes [Input]
inps, [])

    -- Otherwise it cannot become a stream.
    SOAC lore
_ -> (SOAC lore, [Ident]) -> m (SOAC lore, [Ident])
forall (m :: * -> *) a. Monad m => a -> m a
return (SOAC lore
soac, [])
  where
    mkMapPlusAccLam ::
      (MonadFreshNames m, Bindable lore) =>
      [SubExp] ->
      Lambda lore ->
      m (Lambda lore)
    mkMapPlusAccLam :: [SubExp] -> Lambda lore -> m (Lambda lore)
mkMapPlusAccLam [SubExp]
accs Lambda lore
plus = do
      let ([Param Type]
accpars, [Param Type]
rempars) = Int -> [Param Type] -> ([Param Type], [Param Type])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
accs) ([Param Type] -> ([Param Type], [Param Type]))
-> [Param Type] -> ([Param Type], [Param Type])
forall a b. (a -> b) -> a -> b
$ Lambda lore -> [LParam lore]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda lore
plus
          parbnds :: [Stm lore]
parbnds =
            (Param Type -> SubExp -> Stm lore)
-> [Param Type] -> [SubExp] -> [Stm lore]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
              ( \Param Type
par SubExp
se ->
                  [Ident] -> [Ident] -> Exp lore -> Stm lore
forall lore.
Bindable lore =>
[Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet
                    []
                    [Param Type -> Ident
forall dec. Typed dec => Param dec -> Ident
paramIdent Param Type
par]
                    (BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se)
              )
              [Param Type]
accpars
              [SubExp]
accs
          plus_bdy :: BodyT lore
plus_bdy = Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
plus
          newlambdy :: BodyT lore
newlambdy =
            BodyDec lore -> Stms lore -> [SubExp] -> BodyT lore
forall lore. BodyDec lore -> Stms lore -> [SubExp] -> BodyT lore
Body
              (BodyT lore -> BodyDec lore
forall lore. BodyT lore -> BodyDec lore
bodyDec BodyT lore
plus_bdy)
              ([Stm lore] -> Stms lore
forall lore. [Stm lore] -> Stms lore
stmsFromList [Stm lore]
parbnds Stms lore -> Stms lore -> Stms lore
forall a. Semigroup a => a -> a -> a
<> BodyT lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
plus_bdy)
              (BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
plus_bdy)
      Lambda lore -> m (Lambda lore)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda (Lambda lore -> m (Lambda lore)) -> Lambda lore -> m (Lambda lore)
forall a b. (a -> b) -> a -> b
$ [LParam lore] -> BodyT lore -> [Type] -> Lambda lore
forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda [Param Type]
[LParam lore]
rempars BodyT lore
newlambdy ([Type] -> Lambda lore) -> [Type] -> Lambda lore
forall a b. (a -> b) -> a -> b
$ Lambda lore -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda lore
plus

    -- Instantiate the operator lambda @plus@ as a plain body whose
    -- parameters are bound to the subexpressions in @accels@.
    --
    -- The lambda is renamed first (via 'renameLambda'), so each call
    -- produces a body with fresh names. For every (parameter, argument)
    -- pair, a statement @let par = se@ is prepended to the renamed
    -- lambda's own statements. The body's result and decoration are left
    -- untouched. Pairing uses 'zipWith', so surplus parameters or
    -- arguments are silently ignored.
    mkPlusBnds ::
      (MonadFreshNames m, Bindable lore) =>
      Lambda lore ->
      [SubExp] ->
      m (Body lore)
    mkPlusBnds plus accels = do
      fresh <- renameLambda plus
      let bindParam par se =
            mkLet [] [paramIdent par] $ BasicOp $ SubExp se
          argStms = stmsFromList $ zipWith bindParam (lambdaParams fresh) accels
          freshBody = lambdaBody fresh
      pure freshBody {bodyStms = argStms <> bodyStms freshBody}