{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Fusion.LoopKernel
( FusedKer (..),
newKernel,
inputs,
setInputs,
arrInputs,
transformOutput,
attemptFusion,
SOAC,
MapNest,
)
where
import Control.Applicative
import Control.Arrow (first)
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.List (find, tails, (\\))
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.Analysis.HORep.MapNest as MapNest
import qualified Futhark.Analysis.HORep.SOAC as SOAC
import Futhark.Construct
import Futhark.IR.SOACS hiding (SOAC (..))
import qualified Futhark.IR.SOACS as Futhark
import Futhark.Optimise.Fusion.Composing
import Futhark.Pass.ExtractKernels.ISRWIM (rwimPossible)
import Futhark.Transform.Rename (renameLambda)
import Futhark.Transform.Substitute
import Futhark.Util (splitAt3)
-- | A computation that attempts one fusion step.
--
-- It reads the 'Scope' of the kernel being fused into, threads a
-- 'VNameSource' for generating fresh names, and may fail: the state
-- layer sits over 'Maybe'.  Because of that, fusion rules can be
-- combined with '<|>' so that the first rule that succeeds wins.
newtype TryFusion a
  = TryFusion
      ( ReaderT
          (Scope SOACS)
          (StateT VNameSource Maybe)
          a
      )
  deriving
    ( Functor,
      Applicative,
      Alternative,
      Monad,
      MonadFail,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )
-- | Run a 'TryFusion' computation in the given scope.
--
-- On success the result is returned in 'Just' and the name source is
-- advanced by whatever fresh names the computation generated.  On
-- failure 'Nothing' is returned and the name source is left exactly as
-- it was before the attempt, so a failed attempt consumes no names.
tryFusion ::
  MonadFreshNames m =>
  TryFusion a ->
  Scope SOACS ->
  m (Maybe a)
tryFusion (TryFusion m) types = modifyNameSource $ \src ->
  case runStateT (runReaderT m types) src of
    Just (x, src') -> (Just x, src')
    -- Discard any names generated by the failed attempt.
    Nothing -> (Nothing, src)
-- | Lift a 'Maybe' into 'TryFusion'.
--
-- 'Nothing' becomes a failed fusion attempt (via 'fail'), and
-- @'Just' x@ succeeds with @x@.
liftMaybe :: Maybe a -> TryFusion a
liftMaybe = maybe (fail "Nothing") pure
-- | SOACs in the analysis representation, specialised to 'SOACS'.
type SOAC = SOAC.SOAC SOACS
-- | Map nests, specialised to 'SOACS'.
type MapNest = MapNest.MapNest SOACS
-- | Bind each of the given names to the corresponding identifier after
-- applying the array transformations in order.
--
-- Each transformation step in @ts@ is emitted as one let-binding per
-- identifier, with that transformation's certificates attached.  The
-- freshly bound identifiers feed into the next step.  After the last
-- step, each name in @names@ is bound to its final identifier by a plain
-- variable copy.  Names and identifiers are paired positionally (via
-- 'zip').
transformOutput ::
  SOAC.ArrayTransforms ->
  [VName] ->
  [Ident] ->
  Binder SOACS ()
transformOutput ts names = descend ts
  where
    descend :: SOAC.ArrayTransforms -> [Ident] -> Binder SOACS ()
    descend ts' validents =
      case SOAC.viewf ts' of
        SOAC.EmptyF ->
          forM_ (zip names validents) $ \(k, valident) ->
            letBindNames [k] $ BasicOp $ SubExp $ Var $ identName valident
        t SOAC.:< ts'' -> do
          let (es, css) = unzip $ map (applyTransform t) validents
              mkPat (Ident nm tp) = Pattern [] [PatElem nm tp]
          opts <- concat <$> mapM primOpType es
          -- Fresh intermediate names reuse the base name of the target.
          newIds <- forM (zip names opts) $ \(k, opt) ->
            newIdent (baseString k) opt
          forM_ (zip3 css newIds es) $ \(cs, ids, e) ->
            certifying cs $ letBind (mkPat ids) $ BasicOp e
          descend ts'' newIds
-- | Express a single array transformation of an identifier as a
-- 'BasicOp', paired with the transformation's certificates.
--
-- A 'SOAC.Rearrange' whose permutation mentions fewer dimensions than
-- the array has is extended with the identity on the remaining (inner)
-- dimensions.  'SOAC.ReshapeOuter' and 'SOAC.ReshapeInner' become a
-- 'Reshape' whose shape change is computed by 'reshapeOuter' /
-- 'reshapeInner' (with argument 1) from the identifier's shape.
applyTransform :: SOAC.ArrayTransform -> Ident -> (BasicOp, Certificates)
applyTransform (SOAC.Rearrange cs perm) v =
  (Rearrange perm' $ identName v, cs)
  where
    -- Dimensions not covered by 'perm' keep their position.  This is
    -- the same list as @drop (length perm) [0 .. rank - 1]@.
    perm' = perm ++ [length perm .. arrayRank (identType v) - 1]
applyTransform (SOAC.Reshape cs shape) v =
  (Reshape shape $ identName v, cs)
applyTransform (SOAC.ReshapeOuter cs shape) v =
  let shapes = reshapeOuter shape 1 $ arrayShape $ identType v
   in (Reshape shapes $ identName v, cs)
applyTransform (SOAC.ReshapeInner cs shape) v =
  let shapes = reshapeInner shape 1 $ arrayShape $ identType v
   in (Reshape shapes $ identName v, cs)
applyTransform (SOAC.Replicate cs n) v =
  (Replicate n $ Var $ identName v, cs)
-- | Split off the front array transformation of an input, as seen by
-- 'SOAC.viewf'.
--
-- Returns that transformation together with the input stripped of it,
-- or 'Nothing' if the input has no transformations.
inputToOutput :: SOAC.Input -> Maybe (SOAC.ArrayTransform, SOAC.Input)
inputToOutput (SOAC.Input ts ia iat) =
  case SOAC.viewf ts of
    t SOAC.:< ts' -> Just (t, SOAC.Input ts' ia iat)
    SOAC.EmptyF -> Nothing
-- | A kernel under construction by fusion.
data FusedKer = FusedKer
  { -- | The (possibly already fused) SOAC expression of the kernel.
    fsoac :: SOAC,
    -- | Initialised from the consumed names given to 'newKernel' and
    -- extended with the producer's consumed names on each successful
    -- fusion.  NOTE(review): presumably used to respect in-place
    -- update restrictions — confirm.
    inplace :: Names,
    -- | Output variables of producers that have been fused into this
    -- kernel; empty for a freshly created kernel.
    fusedVars :: [VName],
    -- | Names consumed by the SOACs that contributed to this kernel.
    fusedConsumed :: Names,
    -- | The scope in which fusion attempts on this kernel are run.
    kernelScope :: Scope SOACS,
    -- | Array transformations on the kernel's output; initially
    -- 'SOAC.noTransforms'.
    outputTransform :: SOAC.ArrayTransforms,
    -- | Names bound to the kernel's results.
    outNames :: [VName],
    -- | Auxiliary statement information supplied to 'newKernel'.
    kerAux :: StmAux ()
  }
  deriving (Show)
-- | Create a kernel from a single, not yet fused SOAC.
--
-- Both 'inplace' and 'fusedConsumed' start out as the given consumed
-- names.  No variables have been fused yet, and there is no output
-- transformation.
newKernel :: StmAux () -> SOAC -> Names -> [VName] -> Scope SOACS -> FusedKer
newKernel aux soac consumed out_nms scope =
  FusedKer
    { fsoac = soac,
      inplace = consumed,
      fusedVars = [],
      fusedConsumed = consumed,
      outputTransform = SOAC.noTransforms,
      outNames = out_nms,
      kernelScope = scope,
      kerAux = aux
    }
-- | The set of array names underlying the kernel's SOAC inputs.
arrInputs :: FusedKer -> S.Set VName
arrInputs = S.fromList . map SOAC.inputArray . inputs
-- | The inputs of the kernel's SOAC.
inputs :: FusedKer -> [SOAC.Input]
inputs = SOAC.inputs . fsoac
-- | Replace the inputs of the kernel's SOAC, keeping everything else.
setInputs :: [SOAC.Input] -> FusedKer -> FusedKer
setInputs inps ker = ker {fsoac = inps `SOAC.setInputs` fsoac ker}
-- | Fusion rule: first optimise the producer SOAC with 'optimizeSOAC'.
--
-- The output transformations this produces are prepended to every
-- kernel input that reads one of the producer's outputs.  The types of
-- those inputs are then refreshed from the optimised producer, and the
-- fusion rules are applied again with the optimised producer.  Fails if
-- 'optimizeSOAC' fails.
tryOptimizeSOAC ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
tryOptimizeSOAC unfus_nms outVars soac consumed ker = do
  (soac', ots) <- optimizeSOAC Nothing soac mempty
  let ker' = setInputs (map (addInitialTransformIfRelevant ots) (inputs ker)) ker
      outIdents = zipWith Ident outVars $ SOAC.typeOf soac'
      ker'' = fixInputTypes outIdents ker'
  applyFusionRules unfus_nms outVars soac' consumed ker''
  where
    -- Only inputs fed by the producer are affected by its transforms.
    addInitialTransformIfRelevant ots inp
      | SOAC.inputArray inp `elem` outVars =
        SOAC.addInitialTransforms ots inp
      | otherwise =
        inp
-- | Fusion rule: optimise the consumer kernel with 'optimizeKernel',
-- relative to the producer's output variables.  Then apply the fusion
-- rules again to the optimised kernel.  Fails if 'optimizeKernel' fails.
tryOptimizeKernel ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
tryOptimizeKernel unfus_nms outVars soac consumed ker = do
  ker' <- optimizeKernel (Just outVars) ker
  applyFusionRules unfus_nms outVars soac consumed ker'
-- | Fusion rule: expose transformations on the kernel's inputs with
-- 'exposeInputs'.
--
-- If no transformations remain, the producer is fused directly with
-- 'fuseSOACwithKer'.  Otherwise they are pulled through the producer
-- with 'pullOutputTransforms'.  If that absorbs all of them, the
-- kernel's input types are refreshed and the fusion rules are applied
-- again; if not, the rule fails.
tryExposeInputs ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
tryExposeInputs unfus_nms outVars soac consumed ker = do
  (ker', ots) <- exposeInputs outVars ker
  if SOAC.nullTransforms ots
    then fuseSOACwithKer unfus_nms outVars soac consumed ker'
    else do
      (soac', ots') <- pullOutputTransforms soac ots
      let outIdents = zipWith Ident outVars $ SOAC.typeOf soac'
          ker'' = fixInputTypes outIdents ker'
      if SOAC.nullTransforms ots'
        then applyFusionRules unfus_nms outVars soac' consumed ker''
        else fail "tryExposeInputs could not pull SOAC transforms"
-- | Update the types of the kernel's SOAC inputs.
--
-- Any input whose array name matches one of the given identifiers gets
-- that identifier's type; its transformations are kept as they are.
-- Other inputs are left unchanged.
fixInputTypes :: [Ident] -> FusedKer -> FusedKer
fixInputTypes outIdents ker =
  ker {fsoac = fixInputTypes' $ fsoac ker}
  where
    fixInputTypes' soac =
      map fixInputType (SOAC.inputs soac) `SOAC.setInputs` soac
    fixInputType (SOAC.Input ts v _)
      | Just v' <- find ((== v) . identName) outIdents =
        SOAC.Input ts v $ identType v'
    fixInputType inp = inp
-- | Try the fusion rules in order and return the result of the first
-- that succeeds.
--
-- The order is: 'tryOptimizeSOAC', 'tryOptimizeKernel',
-- 'fuseSOACwithKer', 'tryExposeInputs'.  Fails if every rule fails.
applyFusionRules ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
applyFusionRules unfus_nms outVars soac consumed ker =
  rule tryOptimizeSOAC
    <|> rule tryOptimizeKernel
    <|> rule fuseSOACwithKer
    <|> rule tryExposeInputs
  where
    -- Every rule receives the same arguments.
    rule r = r unfus_nms outVars soac consumed ker
-- | Attempt to fuse a producer SOAC into a kernel.
--
-- The fusion rules ('applyFusionRules') are run in the kernel's scope.
-- On success, unused lambda parameters (and their inputs) are removed
-- if the result is a 'SOAC.Screma'.  Returns 'Nothing' if no rule
-- applies; in that case no fresh names are consumed (see 'tryFusion').
attemptFusion ::
  MonadFreshNames m =>
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  m (Maybe FusedKer)
attemptFusion unfus_nms outVars soac consumed ker =
  fmap removeUnusedParamsFromKer
    <$> tryFusion
      (applyFusionRules unfus_nms outVars soac consumed ker)
      (kernelScope ker)
-- | For a 'SOAC.Screma' kernel, drop lambda parameters that are not free
-- in the lambda body, together with their inputs (see
-- 'removeUnusedParams').  Kernels built from other SOACs are returned
-- unchanged.
removeUnusedParamsFromKer :: FusedKer -> FusedKer
removeUnusedParamsFromKer ker =
  case soac of
    SOAC.Screma {} -> ker {fsoac = soac'}
    _ -> ker
  where
    soac = fsoac ker
    (l', inps') = removeUnusedParams (SOAC.lambda soac) (SOAC.inputs soac)
    soac' = l' `SOAC.setLambda` (inps' `SOAC.setInputs` soac)
-- | Drop lambda parameters that are not free in the lambda body, along
-- with the input paired with each.
--
-- Parameters and inputs are paired positionally (via 'zip').  If no
-- parameter is used, the first pair is kept anyway, so the result is
-- never empty unless the lambda had no parameters to begin with.
-- NOTE(review): presumably because the SOAC needs at least one input —
-- confirm.
removeUnusedParams :: Lambda -> [SOAC.Input] -> (Lambda, [SOAC.Input])
removeUnusedParams l inps =
  (l {lambdaParams = ps'}, inps')
  where
    pInps = zip (lambdaParams l) inps
    kept = filter (used . fst) pInps
    (ps', inps') = unzip $ if null kept then take 1 pInps else kept
    used p = paramName p `nameIn` freeVars
    freeVars = freeIn $ lambdaBody l
-- | True if at least one of the producer's output variables is among the
-- kernel inputs that 'SOAC.isVarishInput' maps to a variable name.
mapFusionOK :: [VName] -> FusedKer -> Bool
mapFusionOK outVars ker = any (`elem` inpIds) outVars
  where
    inpIds = mapMaybe SOAC.isVarishInput (inputs ker)
-- | True if every one of the producer's output variables is among the
-- kernel inputs that 'SOAC.isVarishInput' maps to a variable name.
-- This is stricter than 'mapFusionOK', which needs only one.
mapWriteFusionOK :: [VName] -> FusedKer -> Bool
mapWriteFusionOK outVars ker = all (`elem` inpIds) outVars
  where
    inpIds = mapMaybe SOAC.isVarishInput (inputs ker)
-- | Attempt to fuse the producer SOAC @soac_p@ into the kernel @ker@, whose
-- SOAC (@fsoac ker@) is the consumer.
--
-- * @unfus_set@: producer results that must stay visible after fusion
--   (presumably because they are used outside the consumer; confirm at the
--   call site). A non-empty set, together with equal widths, enables
--   horizontal fusion.
-- * @outVars@: the producer's result names, in the order of
--   @SOAC.typeOf soac_p@.
-- * @soac_p_consumed@: names consumed by the producer. On success they are
--   added to the kernel's 'inplace' and 'fusedConsumed' sets.
--
-- Fails (via 'fail' in 'TryFusion') when no fusion rule applies. Several
-- Stream-related rules first convert one side into a 'SOAC.Stream' and then
-- recurse.
fuseSOACwithKer ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
fuseSOACwithKer unfus_set outVars soac_p soac_p_consumed ker = do
  -- We are fusing soac_p into soac_c, i.e., the output of soac_p goes
  -- into soac_c.
  let soac_c = fsoac ker
      inp_p_arr = SOAC.inputs soac_p
      horizFuse =
        unfus_set /= mempty
          && SOAC.width soac_p == SOAC.width soac_c
      inp_c_arr = SOAC.inputs soac_c
      lam_p = SOAC.lambda soac_p
      lam_c = SOAC.lambda soac_c
      w = SOAC.width soac_p
      -- Producer results that remain live after fusion.
      returned_outvars = filter (`nameIn` unfus_set) outVars

      -- Build the fused kernel from the resulting SOAC and its output names.
      success :: [VName] -> SOAC -> TryFusion FusedKer
      success res_outnms res_soac = do
        let fusedVars_new = fusedVars ker ++ outVars
        -- Give the names bound in the fused lambda fresh identities.
        uniq_lam <- renameLambda $ SOAC.lambda res_soac
        return $
          ker
            { fsoac = uniq_lam `SOAC.setLambda` res_soac,
              fusedVars = fusedVars_new,
              inplace = inplace ker <> soac_p_consumed,
              fusedConsumed = fusedConsumed ker <> soac_p_consumed,
              outNames = res_outnms
            }

  -- Pair each producer result with a fresh element-level identifier.
  outPairs <- forM (zip outVars $ map rowType $ SOAC.typeOf soac_p) $ \(outVar, t) -> do
    outVar' <- newVName $ baseString outVar ++ "_elem"
    return (outVar, Ident outVar' t)

  let -- Fuse the producer lambda into the consumer lambda as maps. Any
      -- producer results listed in unfus_set become extra results.
      mapLikeFusionCheck =
        let (res_lam, new_inp) = fuseMaps unfus_set lam_p inp_p_arr outPairs lam_c inp_c_arr
            (extra_nms, extra_rtps) =
              unzip $
                filter ((`nameIn` unfus_set) . fst) $
                  zip outVars $ map (stripArray 1) $ SOAC.typeOf soac_p
            res_lam' = res_lam {lambdaReturnType = lambdaReturnType res_lam ++ extra_rtps}
         in (extra_nms, res_lam', new_inp)

      -- Can a producer with this Screma form be fused vertically into a
      -- Scatter or Hist consumer? It must be a pure map, none of its results
      -- may be in unfus_set, and 'mapWriteFusionOK' must accept them.
      mapIntoOK :: ScremaForm SOACS -> Bool
      mapIntoOK form =
        isJust (isMapSOAC form)
          && not (any (`nameIn` unfus_set) outVars)
          && mapWriteFusionOK outVars ker

      -- Fuse a map producer into a Scatter/Hist consumer. The consumer's
      -- constructor (with its other fields) is supplied as @rebuild@.
      fuseMapInto :: (Lambda SOACS -> [SOAC.Input] -> SOAC) -> TryFusion FusedKer
      fuseMapInto rebuild =
        let (extra_nms, res_lam', new_inp) = mapLikeFusionCheck
         in success (outNames ker ++ extra_nms) $ rebuild res_lam' new_inp

  when (horizFuse && not (SOAC.nullTransforms $ outputTransform ker)) $
    fail "Horizontal fusion is invalid in the presence of output transforms."

  case (soac_c, soac_p) of
    _
      | SOAC.width soac_p /= SOAC.width soac_c ->
        fail "SOAC widths must match."
    ( SOAC.Screma _ (ScremaForm scans_c reds_c _) _,
      SOAC.Screma _ (ScremaForm scans_p reds_p _) _
      )
        | mapFusionOK (drop (Futhark.scanResults scans_p + Futhark.redResults reds_p) outVars) ker
            || horizFuse -> do
          let red_nes_p = concatMap redNeutral reds_p
              red_nes_c = concatMap redNeutral reds_c
              scan_nes_p = concatMap scanNeutral scans_p
              scan_nes_c = concatMap scanNeutral scans_c
              (res_lam', new_inp) =
                fuseRedomap
                  unfus_set
                  outVars
                  lam_p
                  scan_nes_p
                  red_nes_p
                  inp_p_arr
                  outPairs
                  lam_c
                  scan_nes_c
                  red_nes_c
                  inp_c_arr
              (soac_p_scanout, soac_p_redout, _soac_p_mapout) =
                splitAt3 (length scan_nes_p) (length red_nes_p) outVars
              (soac_c_scanout, soac_c_redout, soac_c_mapout) =
                splitAt3 (length scan_nes_c) (length red_nes_c) $ outNames ker
              unfus_arrs = returned_outvars \\ (soac_p_scanout ++ soac_p_redout)
              -- Result order: scans, then reductions, then maps, then any
              -- producer results that stay live.
              res_outnms =
                soac_p_scanout
                  ++ soac_c_scanout
                  ++ soac_p_redout
                  ++ soac_c_redout
                  ++ soac_c_mapout
                  ++ unfus_arrs
          success res_outnms $
            SOAC.Screma
              w
              (ScremaForm (scans_p ++ scans_c) (reds_p ++ reds_c) res_lam')
              new_inp

    -- Vertical fusion of a map producer into a Scatter consumer.
    ( SOAC.Scatter _len _lam _ivs dests,
      SOAC.Screma _ form _
      )
        | mapIntoOK form ->
          fuseMapInto $ \lam_res inp_res -> SOAC.Scatter w lam_res inp_res dests

    -- Vertical fusion of a map producer into a Hist consumer.
    ( SOAC.Hist _ ops _ _,
      SOAC.Screma _ form _
      )
        | mapIntoOK form ->
          fuseMapInto $ SOAC.Hist w ops

    -- Horizontal fusion of two Hists. The lambda results are laid out as
    -- all bucket indices (consumer first), then the consumer's remaining
    -- results, then the producer's remaining results.
    ( SOAC.Hist _ ops_c _ _,
      SOAC.Hist _ ops_p _ _
      )
        | horizFuse -> do
          let p_num_buckets = length ops_p
              c_num_buckets = length ops_c
              (body_p, body_c) = (lambdaBody lam_p, lambdaBody lam_c)
              body' =
                Body
                  { bodyDec = bodyDec body_p,
                    bodyStms = bodyStms body_p <> bodyStms body_c,
                    bodyResult =
                      take c_num_buckets (bodyResult body_c)
                        ++ take p_num_buckets (bodyResult body_p)
                        ++ drop c_num_buckets (bodyResult body_c)
                        ++ drop p_num_buckets (bodyResult body_p)
                  }
              lam' =
                Lambda
                  { lambdaParams = lambdaParams lam_c ++ lambdaParams lam_p,
                    lambdaBody = body',
                    lambdaReturnType =
                      replicate (c_num_buckets + p_num_buckets) (Prim int64)
                        ++ drop c_num_buckets (lambdaReturnType lam_c)
                        ++ drop p_num_buckets (lambdaReturnType lam_p)
                  }
          success (outNames ker ++ returned_outvars) $
            SOAC.Hist w (ops_c <> ops_p) lam' (inp_c_arr <> inp_p_arr)

    -- Horizontal fusion of two Scatters. Indices come first (consumer, then
    -- producer), followed by values in the same order.
    ( SOAC.Scatter _len_c _lam_c ivs_c as_c,
      SOAC.Scatter _len_p _lam_p ivs_p as_p
      )
        | horizFuse -> do
          let zipW :: [(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
              zipW as_xs xs as_ys ys = xs_indices ++ ys_indices ++ xs_vals ++ ys_vals
                where
                  (xs_indices, xs_vals) = splitScatterResults as_xs xs
                  (ys_indices, ys_vals) = splitScatterResults as_ys ys
          let (body_p, body_c) = (lambdaBody lam_p, lambdaBody lam_c)
          let body' =
                Body
                  { bodyDec = bodyDec body_p,
                    bodyStms = bodyStms body_p <> bodyStms body_c,
                    bodyResult = zipW as_c (bodyResult body_c) as_p (bodyResult body_p)
                  }
          let lam' =
                Lambda
                  { lambdaParams = lambdaParams lam_c ++ lambdaParams lam_p,
                    lambdaBody = body',
                    lambdaReturnType = zipW as_c (lambdaReturnType lam_c) as_p (lambdaReturnType lam_p)
                  }
          success (outNames ker ++ returned_outvars) $
            SOAC.Scatter w lam' (ivs_c ++ ivs_p) (as_c ++ as_p)
    (SOAC.Scatter {}, _) ->
      fail "Cannot fuse a write with anything else than a write or a map"
    (_, SOAC.Scatter {}) ->
      fail "Cannot fuse a write with anything else than a write or a map"

    -- Two sequential streams.
    (SOAC.Stream _ Sequential _ _ _, SOAC.Stream _ Sequential _ nes _)
      | mapFusionOK (drop (length nes) outVars) ker || horizFuse -> do
        (res_nms, res_stream) <-
          fuseStreamHelper (outNames ker) unfus_set outVars outPairs soac_c soac_p
        success res_nms res_stream
    (SOAC.Stream _ Sequential _ _ _, SOAC.Stream _ Sequential _ _ _) ->
      fail "Fusion conditions not met for two SEQ streams!"
    (SOAC.Stream _ Sequential _ _ _, SOAC.Stream {}) ->
      fail "Cannot fuse a parallel with a sequential Stream!"
    (SOAC.Stream {}, SOAC.Stream _ Sequential _ _ _) ->
      fail "Cannot fuse a parallel with a sequential Stream!"

    -- Two parallel streams.
    (SOAC.Stream {}, SOAC.Stream _ _ _ nes _)
      | mapFusionOK (drop (length nes) outVars) ker || horizFuse -> do
        (res_nms, res_stream) <-
          fuseStreamHelper (outNames ker) unfus_set outVars outPairs soac_c soac_p
        success res_nms res_stream
    (SOAC.Stream {}, SOAC.Stream {}) ->
      fail "Fusion conditions not met for two PAR streams!"

    -- Consumer is a Stream and producer is not: turn the producer into a
    -- Stream (sequential if the consumer is sequential) and retry.
    (SOAC.Stream _ form2 _ _ _, _) -> do
      (soac_p', newacc_ids) <- SOAC.soacToStream soac_p
      soac_p'' <- case form2 of
        Sequential {} -> toSeqStream soac_p'
        _ -> return soac_p'
      if soac_p' == soac_p
        then fail "SOAC could not be turned into stream."
        else
          fuseSOACwithKer
            unfus_set
            (map identName newacc_ids ++ outVars)
            soac_p''
            soac_p_consumed
            ker

    -- A scan producer that no earlier rule fused: turn it into a Stream and
    -- retry.
    (_, SOAC.Screma _ form _)
      | Just _ <- Futhark.isScanSOAC form -> do
        (soac_p', newacc_ids) <- SOAC.soacToStream soac_p
        if soac_p' /= soac_p
          then
            fuseSOACwithKer
              unfus_set
              (map identName newacc_ids ++ outVars)
              soac_p'
              soac_p_consumed
              ker
          else fail "SOAC could not be turned into stream."

    -- Producer is a Stream and consumer is not: turn the consumer into a
    -- Stream (sequential if the producer is sequential) and retry.
    (_, SOAC.Stream _ form_p _ _ _) -> do
      (soac_c', newacc_ids) <- SOAC.soacToStream soac_c
      when (soac_c' == soac_c) $ fail "SOAC could not be turned into stream."
      soac_c'' <- case form_p of
        Sequential -> toSeqStream soac_c'
        _ -> return soac_c'
      fuseSOACwithKer unfus_set outVars soac_p soac_p_consumed $
        ker {fsoac = soac_c'', outNames = map identName newacc_ids ++ outNames ker}

    -- Default: no rule applies.
    _ -> fail "Cannot fuse"
-- | The chunk-processing order of a stream. A parallel stream carries its
-- order explicitly; a sequential stream is always 'InOrder'.
getStreamOrder :: StreamForm lore -> StreamOrd
getStreamOrder (Parallel o _ _) = o
getStreamOrder Sequential = InOrder
fuseStreamHelper ::
[VName] ->
Names ->
[VName] ->
[(VName, Ident)] ->
SOAC ->
SOAC ->
TryFusion ([VName], SOAC)
fuseStreamHelper :: [VName]
-> Names
-> [VName]
-> [(VName, Ident)]
-> SOAC
-> SOAC
-> TryFusion ([VName], SOAC)
fuseStreamHelper
[VName]
out_kernms
Names
unfus_set
[VName]
outVars
[(VName, Ident)]
outPairs
(SOAC.Stream SubExp
w2 StreamForm SOACS
form2 Lambda SOACS
lam2 [SubExp]
nes2 [Input]
inp2_arr)
(SOAC.Stream SubExp
_ StreamForm SOACS
form1 Lambda SOACS
lam1 [SubExp]
nes1 [Input]
inp1_arr) =
if StreamForm SOACS -> StreamOrd
forall lore. StreamForm lore -> StreamOrd
getStreamOrder StreamForm SOACS
form2 StreamOrd -> StreamOrd -> Bool
forall a. Eq a => a -> a -> Bool
/= StreamForm SOACS -> StreamOrd
forall lore. StreamForm lore -> StreamOrd
getStreamOrder StreamForm SOACS
form1
then String -> TryFusion ([VName], SOAC)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"fusion conditions not met!"
else do
let chunk1 :: Param Type
chunk1 = [Param Type] -> Param Type
forall a. [a] -> a
head ([Param Type] -> Param Type) -> [Param Type] -> Param Type
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
lam1
chunk2 :: Param Type
chunk2 = [Param Type] -> Param Type
forall a. [a] -> a
head ([Param Type] -> Param Type) -> [Param Type] -> Param Type
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
lam2
hmnms :: Map VName VName
hmnms = [(VName, VName)] -> Map VName VName
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList [(Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
chunk2, Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
chunk1)]
lam20 :: Lambda SOACS
lam20 = Map VName VName -> Lambda SOACS -> Lambda SOACS
forall a. Substitute a => Map VName VName -> a -> a
substituteNames Map VName VName
hmnms Lambda SOACS
lam2
lam1' :: Lambda SOACS
lam1' = Lambda SOACS
lam1 {lambdaParams :: [LParam SOACS]
lambdaParams = [Param Type] -> [Param Type]
forall a. [a] -> [a]
tail ([Param Type] -> [Param Type]) -> [Param Type] -> [Param Type]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
lam1}
lam2' :: Lambda SOACS
lam2' = Lambda SOACS
lam20 {lambdaParams :: [LParam SOACS]
lambdaParams = [Param Type] -> [Param Type]
forall a. [a] -> [a]
tail ([Param Type] -> [Param Type]) -> [Param Type] -> [Param Type]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
lam20}
(Lambda SOACS
res_lam', [Input]
new_inp) =
Names
-> [VName]
-> Lambda SOACS
-> [SubExp]
-> [SubExp]
-> [Input]
-> [(VName, Ident)]
-> Lambda SOACS
-> [SubExp]
-> [SubExp]
-> [Input]
-> (Lambda SOACS, [Input])
forall lore.
Bindable lore =>
Names
-> [VName]
-> Lambda lore
-> [SubExp]
-> [SubExp]
-> [Input]
-> [(VName, Ident)]
-> Lambda lore
-> [SubExp]
-> [SubExp]
-> [Input]
-> (Lambda lore, [Input])
fuseRedomap
Names
unfus_set
[VName]
outVars
Lambda SOACS
lam1'
[]
[SubExp]
nes1
[Input]
inp1_arr
[(VName, Ident)]
outPairs
Lambda SOACS
lam2'
[]
[SubExp]
nes2
[Input]
inp2_arr
res_lam'' :: Lambda SOACS
res_lam'' = Lambda SOACS
res_lam' {lambdaParams :: [LParam SOACS]
lambdaParams = Param Type
chunk1 Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
res_lam'}
unfus_accs :: [VName]
unfus_accs = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
take ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes1) [VName]
outVars
unfus_arrs :: [VName]
unfus_arrs = (VName -> Bool) -> [VName] -> [VName]
forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> Names -> Bool
`nameIn` Names
unfus_set) [VName]
outVars
StreamForm SOACS
res_form <- StreamForm SOACS
-> StreamForm SOACS -> TryFusion (StreamForm SOACS)
forall (m :: * -> *) lore.
MonadFail m =>
StreamForm lore -> StreamForm lore -> m (StreamForm lore)
mergeForms StreamForm SOACS
form2 StreamForm SOACS
form1
([VName], SOAC) -> TryFusion ([VName], SOAC)
forall (m :: * -> *) a. Monad m => a -> m a
return
( [VName]
unfus_accs [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
out_kernms [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
unfus_arrs,
SubExp
-> StreamForm SOACS -> Lambda SOACS -> [SubExp] -> [Input] -> SOAC
forall lore.
SubExp
-> StreamForm lore
-> Lambda lore
-> [SubExp]
-> [Input]
-> SOAC lore
SOAC.Stream SubExp
w2 StreamForm SOACS
res_form Lambda SOACS
res_lam'' ([SubExp]
nes1 [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [SubExp]
nes2) [Input]
new_inp
)
where
mergeForms :: StreamForm lore -> StreamForm lore -> m (StreamForm lore)
mergeForms StreamForm lore
Sequential StreamForm lore
Sequential = StreamForm lore -> m (StreamForm lore)
forall (m :: * -> *) a. Monad m => a -> m a
return StreamForm lore
forall lore. StreamForm lore
Sequential
mergeForms (Parallel StreamOrd
_ Commutativity
comm2 Lambda lore
lam2r) (Parallel StreamOrd
o1 Commutativity
comm1 Lambda lore
lam1r) =
StreamForm lore -> m (StreamForm lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (StreamForm lore -> m (StreamForm lore))
-> StreamForm lore -> m (StreamForm lore)
forall a b. (a -> b) -> a -> b
$ StreamOrd -> Commutativity -> Lambda lore -> StreamForm lore
forall lore.
StreamOrd -> Commutativity -> Lambda lore -> StreamForm lore
Parallel StreamOrd
o1 (Commutativity
comm1 Commutativity -> Commutativity -> Commutativity
forall a. Semigroup a => a -> a -> a
<> Commutativity
comm2) (Lambda lore -> Lambda lore -> Lambda lore
forall lore. Lambda lore -> Lambda lore -> Lambda lore
mergeReduceOps Lambda lore
lam1r Lambda lore
lam2r)
mergeForms StreamForm lore
_ StreamForm lore
_ = String -> m (StreamForm lore)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Fusing sequential to parallel stream disallowed!"
fuseStreamHelper [VName]
_ Names
_ [VName]
_ [(VName, Ident)]
_ SOAC
_ SOAC
_ = String -> TryFusion ([VName], SOAC)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot Fuse Streams!"
toSeqStream :: SOAC -> TryFusion SOAC
toSeqStream :: SOAC -> TryFusion SOAC
toSeqStream s :: SOAC
s@(SOAC.Stream SubExp
_ StreamForm SOACS
Sequential Lambda SOACS
_ [SubExp]
_ [Input]
_) = SOAC -> TryFusion SOAC
forall (m :: * -> *) a. Monad m => a -> m a
return SOAC
s
toSeqStream (SOAC.Stream SubExp
w Parallel {} Lambda SOACS
l [SubExp]
acc [Input]
inps) =
SOAC -> TryFusion SOAC
forall (m :: * -> *) a. Monad m => a -> m a
return (SOAC -> TryFusion SOAC) -> SOAC -> TryFusion SOAC
forall a b. (a -> b) -> a -> b
$ SubExp
-> StreamForm SOACS -> Lambda SOACS -> [SubExp] -> [Input] -> SOAC
forall lore.
SubExp
-> StreamForm lore
-> Lambda lore
-> [SubExp]
-> [Input]
-> SOAC lore
SOAC.Stream SubExp
w StreamForm SOACS
forall lore. StreamForm lore
Sequential Lambda SOACS
l [SubExp]
acc [Input]
inps
toSeqStream SOAC
_ = String -> TryFusion SOAC
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"toSeqStream expects a stream, but given a SOAC."
optimizeKernel :: Maybe [VName] -> FusedKer -> TryFusion FusedKer
optimizeKernel :: Maybe [VName] -> FusedKer -> TryFusion FusedKer
optimizeKernel Maybe [VName]
inp FusedKer
ker = do
(SOAC
soac, ArrayTransforms
resTrans) <- Maybe [VName]
-> SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)
optimizeSOAC Maybe [VName]
inp (FusedKer -> SOAC
fsoac FusedKer
ker) ArrayTransforms
startTrans
FusedKer -> TryFusion FusedKer
forall (m :: * -> *) a. Monad m => a -> m a
return (FusedKer -> TryFusion FusedKer) -> FusedKer -> TryFusion FusedKer
forall a b. (a -> b) -> a -> b
$
FusedKer
ker
{ fsoac :: SOAC
fsoac = SOAC
soac,
outputTransform :: ArrayTransforms
outputTransform = ArrayTransforms
resTrans
}
where
startTrans :: ArrayTransforms
startTrans = FusedKer -> ArrayTransforms
outputTransform FusedKer
ker
optimizeSOAC ::
Maybe [VName] ->
SOAC ->
SOAC.ArrayTransforms ->
TryFusion (SOAC, SOAC.ArrayTransforms)
optimizeSOAC :: Maybe [VName]
-> SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)
optimizeSOAC Maybe [VName]
inp SOAC
soac ArrayTransforms
os = do
(Bool, SOAC, ArrayTransforms)
res <- ((Bool, SOAC, ArrayTransforms)
-> (Maybe [VName]
-> SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms))
-> TryFusion (Bool, SOAC, ArrayTransforms))
-> (Bool, SOAC, ArrayTransforms)
-> [Maybe [VName]
-> SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)]
-> TryFusion (Bool, SOAC, ArrayTransforms)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Bool, SOAC, ArrayTransforms)
-> (Maybe [VName]
-> SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms))
-> TryFusion (Bool, SOAC, ArrayTransforms)
comb (Bool
False, SOAC
soac, ArrayTransforms
os) [Maybe [VName]
-> SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)]
optimizations
case (Bool, SOAC, ArrayTransforms)
res of
(Bool
False, SOAC
_, ArrayTransforms
_) -> String -> TryFusion (SOAC, ArrayTransforms)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"No optimisation applied"
(Bool
True, SOAC
soac', ArrayTransforms
os') -> (SOAC, ArrayTransforms) -> TryFusion (SOAC, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return (SOAC
soac', ArrayTransforms
os')
where
comb :: (Bool, SOAC, ArrayTransforms)
-> (Maybe [VName]
-> SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms))
-> TryFusion (Bool, SOAC, ArrayTransforms)
comb (Bool
changed, SOAC
soac', ArrayTransforms
os') Maybe [VName]
-> SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)
f =
do
(SOAC
soac'', ArrayTransforms
os'') <- Maybe [VName]
-> SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)
f Maybe [VName]
inp SOAC
soac' ArrayTransforms
os
(Bool, SOAC, ArrayTransforms)
-> TryFusion (Bool, SOAC, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return (Bool
True, SOAC
soac'', ArrayTransforms
os'')
TryFusion (Bool, SOAC, ArrayTransforms)
-> TryFusion (Bool, SOAC, ArrayTransforms)
-> TryFusion (Bool, SOAC, ArrayTransforms)
forall (f :: * -> *) a. Alternative f => f a -> f a -> f a
<|> (Bool, SOAC, ArrayTransforms)
-> TryFusion (Bool, SOAC, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return (Bool
changed, SOAC
soac', ArrayTransforms
os')
type Optimization =
Maybe [VName] ->
SOAC ->
SOAC.ArrayTransforms ->
TryFusion (SOAC, SOAC.ArrayTransforms)
optimizations :: [Optimization]
optimizations :: [Maybe [VName]
-> SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)]
optimizations = [Maybe [VName]
-> SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)
iswim]
iswim ::
Maybe [VName] ->
SOAC ->
SOAC.ArrayTransforms ->
TryFusion (SOAC, SOAC.ArrayTransforms)
iswim :: Maybe [VName]
-> SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)
iswim Maybe [VName]
_ (SOAC.Screma SubExp
w ScremaForm SOACS
form [Input]
arrs) ArrayTransforms
ots
| Just [Futhark.Scan Lambda SOACS
scan_fun [SubExp]
nes] <- ScremaForm SOACS -> Maybe [Scan SOACS]
forall lore. ScremaForm lore -> Maybe [Scan lore]
Futhark.isScanSOAC ScremaForm SOACS
form,
Just (Pattern
map_pat, Certificates
map_cs, SubExp
map_w, Lambda SOACS
map_fun) <- Lambda SOACS -> Maybe (Pattern, Certificates, SubExp, Lambda SOACS)
rwimPossible Lambda SOACS
scan_fun,
Just [VName]
nes_names <- (SubExp -> Maybe VName) -> [SubExp] -> Maybe [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> Maybe VName
subExpVar [SubExp]
nes = do
let nes_idents :: [Ident]
nes_idents = (VName -> Type -> Ident) -> [VName] -> [Type] -> [Ident]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith VName -> Type -> Ident
Ident [VName]
nes_names ([Type] -> [Ident]) -> [Type] -> [Ident]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda SOACS
scan_fun
map_nes :: [Input]
map_nes = (Ident -> Input) -> [Ident] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> Input
SOAC.identInput [Ident]
nes_idents
map_arrs' :: [Input]
map_arrs' = [Input]
map_nes [Input] -> [Input] -> [Input]
forall a. [a] -> [a] -> [a]
++ (Input -> Input) -> [Input] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> Int -> Input -> Input
SOAC.transposeInput Int
0 Int
1) [Input]
arrs
([Param Type]
scan_acc_params, [Param Type]
scan_elem_params) =
Int -> [Param Type] -> ([Param Type], [Param Type])
forall a. Int -> [a] -> ([a], [a])
splitAt ([Input] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Input]
arrs) ([Param Type] -> ([Param Type], [Param Type]))
-> [Param Type] -> ([Param Type], [Param Type])
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
scan_fun
map_params :: [Param Type]
map_params =
(Param Type -> Param Type) -> [Param Type] -> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Param Type
LParam SOACS -> LParam SOACS
removeParamOuterDim [Param Type]
scan_acc_params
[Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ (Param Type -> Param Type) -> [Param Type] -> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> LParam SOACS -> LParam SOACS
setParamOuterDimTo SubExp
w) [Param Type]
scan_elem_params
map_rettype :: [Type]
map_rettype = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> SubExp -> Type
forall d u.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
`setOuterSize` SubExp
w) ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda SOACS
scan_fun
scan_params :: [LParam SOACS]
scan_params = Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
map_fun
scan_body :: BodyT SOACS
scan_body = Lambda SOACS -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda SOACS
map_fun
scan_rettype :: [Type]
scan_rettype = Lambda SOACS -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda SOACS
map_fun
scan_fun' :: Lambda SOACS
scan_fun' = [LParam SOACS] -> BodyT SOACS -> [Type] -> Lambda SOACS
forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda [LParam SOACS]
scan_params BodyT SOACS
scan_body [Type]
scan_rettype
nes' :: [SubExp]
nes' = (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> [SubExp]) -> [VName] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
take ([Input] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Input]
map_nes) ([VName] -> [VName]) -> [VName] -> [VName]
forall a b. (a -> b) -> a -> b
$ (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
map_params
arrs' :: [VName]
arrs' = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop ([Input] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Input]
map_nes) ([VName] -> [VName]) -> [VName] -> [VName]
forall a b. (a -> b) -> a -> b
$ (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
map_params
ScremaForm SOACS
scan_form <- [Scan SOACS] -> TryFusion (ScremaForm SOACS)
forall lore (m :: * -> *).
(Bindable lore, MonadFreshNames m) =>
[Scan lore] -> m (ScremaForm lore)
scanSOAC [Lambda SOACS -> [SubExp] -> Scan SOACS
forall lore. Lambda lore -> [SubExp] -> Scan lore
Futhark.Scan Lambda SOACS
scan_fun' [SubExp]
nes']
let map_body :: BodyT SOACS
map_body =
Stms SOACS -> [SubExp] -> BodyT SOACS
forall lore. Bindable lore => Stms lore -> [SubExp] -> Body lore
mkBody
( Stm SOACS -> Stms SOACS
forall lore. Stm lore -> Stms lore
oneStm (Stm SOACS -> Stms SOACS) -> Stm SOACS -> Stms SOACS
forall a b. (a -> b) -> a -> b
$
Pattern -> StmAux (ExpDec SOACS) -> ExpT SOACS -> Stm SOACS
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let (SubExp -> Pattern -> Pattern
setPatternOuterDimTo SubExp
w Pattern
map_pat) (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (ExpT SOACS -> Stm SOACS) -> ExpT SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$
Op SOACS -> ExpT SOACS
forall lore. Op lore -> ExpT lore
Op (Op SOACS -> ExpT SOACS) -> Op SOACS -> ExpT SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall lore. SubExp -> [VName] -> ScremaForm lore -> SOAC lore
Futhark.Screma SubExp
w [VName]
arrs' ScremaForm SOACS
scan_form
)
([SubExp] -> BodyT SOACS) -> [SubExp] -> BodyT SOACS
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> [SubExp]) -> [VName] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT Type
Pattern
map_pat
map_fun' :: Lambda SOACS
map_fun' = [LParam SOACS] -> BodyT SOACS -> [Type] -> Lambda SOACS
forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda [Param Type]
[LParam SOACS]
map_params BodyT SOACS
map_body [Type]
map_rettype
perm :: [Int]
perm = case Lambda SOACS -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda SOACS
map_fun of
[] -> []
Type
t : [Type]
_ -> Int
1 Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: Int
0 Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: [Int
2 .. Type -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank Type
t]
(SOAC, ArrayTransforms) -> TryFusion (SOAC, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return
( SubExp -> ScremaForm SOACS -> [Input] -> SOAC
forall lore. SubExp -> ScremaForm lore -> [Input] -> SOAC lore
SOAC.Screma SubExp
map_w ([Scan SOACS] -> [Reduce SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall lore.
[Scan lore] -> [Reduce lore] -> Lambda lore -> ScremaForm lore
ScremaForm [] [] Lambda SOACS
map_fun') [Input]
map_arrs',
ArrayTransforms
ots ArrayTransforms -> ArrayTransform -> ArrayTransforms
SOAC.|> Certificates -> [Int] -> ArrayTransform
SOAC.Rearrange Certificates
map_cs [Int]
perm
)
iswim Maybe [VName]
_ SOAC
_ ArrayTransforms
_ =
String -> TryFusion (SOAC, ArrayTransforms)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"ISWIM does not apply."
removeParamOuterDim :: LParam -> LParam
removeParamOuterDim :: LParam SOACS -> LParam SOACS
removeParamOuterDim LParam SOACS
param =
let t :: Type
t = Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType (Type -> Type) -> Type -> Type
forall a b. (a -> b) -> a -> b
$ Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
LParam SOACS
param
in Param Type
LParam SOACS
param {paramDec :: Type
paramDec = Type
t}
setParamOuterDimTo :: SubExp -> LParam -> LParam
setParamOuterDimTo :: SubExp -> LParam SOACS -> LParam SOACS
setParamOuterDimTo SubExp
w LParam SOACS
param =
let t :: Type
t = Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
LParam SOACS
param Type -> SubExp -> Type
forall d u.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
`setOuterSize` SubExp
w
in Param Type
LParam SOACS
param {paramDec :: Type
paramDec = Type
t}
setPatternOuterDimTo :: SubExp -> Pattern -> Pattern
setPatternOuterDimTo :: SubExp -> Pattern -> Pattern
setPatternOuterDimTo SubExp
w = (Type -> Type) -> PatternT Type -> PatternT Type
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Type -> SubExp -> Type
forall d u.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
`setOuterSize` SubExp
w)
commonTransforms ::
[VName] ->
[SOAC.Input] ->
(SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms :: [VName] -> [Input] -> (ArrayTransforms, [Input])
commonTransforms [VName]
interesting [Input]
inps = [(Bool, Input)] -> (ArrayTransforms, [Input])
commonTransforms' [(Bool, Input)]
inps'
where
inps' :: [(Bool, Input)]
inps' =
[ (Input -> VName
SOAC.inputArray Input
inp VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
interesting, Input
inp)
| Input
inp <- [Input]
inps
]
commonTransforms' :: [(Bool, SOAC.Input)] -> (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms' :: [(Bool, Input)] -> (ArrayTransforms, [Input])
commonTransforms' [(Bool, Input)]
inps =
case ((Maybe ArrayTransform, [(Bool, Input)])
-> (Bool, Input) -> Maybe (Maybe ArrayTransform, [(Bool, Input)]))
-> (Maybe ArrayTransform, [(Bool, Input)])
-> [(Bool, Input)]
-> Maybe (Maybe ArrayTransform, [(Bool, Input)])
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Maybe ArrayTransform, [(Bool, Input)])
-> (Bool, Input) -> Maybe (Maybe ArrayTransform, [(Bool, Input)])
inspect (Maybe ArrayTransform
forall a. Maybe a
Nothing, []) [(Bool, Input)]
inps of
Just (Just ArrayTransform
mot, [(Bool, Input)]
inps') -> (ArrayTransforms -> ArrayTransforms)
-> (ArrayTransforms, [Input]) -> (ArrayTransforms, [Input])
forall (a :: * -> * -> *) b c d.
Arrow a =>
a b c -> a (b, d) (c, d)
first (ArrayTransform
mot ArrayTransform -> ArrayTransforms -> ArrayTransforms
SOAC.<|) ((ArrayTransforms, [Input]) -> (ArrayTransforms, [Input]))
-> (ArrayTransforms, [Input]) -> (ArrayTransforms, [Input])
forall a b. (a -> b) -> a -> b
$ [(Bool, Input)] -> (ArrayTransforms, [Input])
commonTransforms' ([(Bool, Input)] -> (ArrayTransforms, [Input]))
-> [(Bool, Input)] -> (ArrayTransforms, [Input])
forall a b. (a -> b) -> a -> b
$ [(Bool, Input)] -> [(Bool, Input)]
forall a. [a] -> [a]
reverse [(Bool, Input)]
inps'
Maybe (Maybe ArrayTransform, [(Bool, Input)])
_ -> (ArrayTransforms
SOAC.noTransforms, ((Bool, Input) -> Input) -> [(Bool, Input)] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map (Bool, Input) -> Input
forall a b. (a, b) -> b
snd [(Bool, Input)]
inps)
where
inspect :: (Maybe ArrayTransform, [(Bool, Input)])
-> (Bool, Input) -> Maybe (Maybe ArrayTransform, [(Bool, Input)])
inspect (Maybe ArrayTransform
mot, [(Bool, Input)]
prev) (Bool
True, Input
inp) =
case (Maybe ArrayTransform
mot, Input -> Maybe (ArrayTransform, Input)
inputToOutput Input
inp) of
(Maybe ArrayTransform
Nothing, Just (ArrayTransform
ot, Input
inp')) -> (Maybe ArrayTransform, [(Bool, Input)])
-> Maybe (Maybe ArrayTransform, [(Bool, Input)])
forall a. a -> Maybe a
Just (ArrayTransform -> Maybe ArrayTransform
forall a. a -> Maybe a
Just ArrayTransform
ot, (Bool
True, Input
inp') (Bool, Input) -> [(Bool, Input)] -> [(Bool, Input)]
forall a. a -> [a] -> [a]
: [(Bool, Input)]
prev)
(Just ArrayTransform
ot1, Just (ArrayTransform
ot2, Input
inp'))
| ArrayTransform
ot1 ArrayTransform -> ArrayTransform -> Bool
forall a. Eq a => a -> a -> Bool
== ArrayTransform
ot2 -> (Maybe ArrayTransform, [(Bool, Input)])
-> Maybe (Maybe ArrayTransform, [(Bool, Input)])
forall a. a -> Maybe a
Just (ArrayTransform -> Maybe ArrayTransform
forall a. a -> Maybe a
Just ArrayTransform
ot2, (Bool
True, Input
inp') (Bool, Input) -> [(Bool, Input)] -> [(Bool, Input)]
forall a. a -> [a] -> [a]
: [(Bool, Input)]
prev)
(Maybe ArrayTransform, Maybe (ArrayTransform, Input))
_ -> Maybe (Maybe ArrayTransform, [(Bool, Input)])
forall a. Maybe a
Nothing
inspect (Maybe ArrayTransform
mot, [(Bool, Input)]
prev) (Bool, Input)
inp = (Maybe ArrayTransform, [(Bool, Input)])
-> Maybe (Maybe ArrayTransform, [(Bool, Input)])
forall a. a -> Maybe a
Just (Maybe ArrayTransform
mot, (Bool, Input)
inp (Bool, Input) -> [(Bool, Input)] -> [(Bool, Input)]
forall a. a -> [a] -> [a]
: [(Bool, Input)]
prev)
mapDepth :: MapNest -> Int
mapDepth :: MapNest -> Int
mapDepth (MapNest.MapNest SubExp
_ Lambda SOACS
lam [Nesting SOACS]
levels [Input]
_) =
Int -> Int -> Int
forall a. Ord a => a -> a -> a
min Int
resDims ([Nesting SOACS] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Nesting SOACS]
levels) Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1
where
resDims :: Int
resDims = [Type] -> Int
forall shape u. ArrayShape shape => [TypeBase shape u] -> Int
minDim ([Type] -> Int) -> [Type] -> Int
forall a b. (a -> b) -> a -> b
$ case [Nesting SOACS]
levels of
[] -> Lambda SOACS -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda SOACS
lam
Nesting SOACS
nest : [Nesting SOACS]
_ -> Nesting SOACS -> [Type]
forall lore. Nesting lore -> [Type]
MapNest.nestingReturnType Nesting SOACS
nest
minDim :: [TypeBase shape u] -> Int
minDim [] = Int
0
minDim (TypeBase shape u
t : [TypeBase shape u]
ts) = (Int -> Int -> Int) -> Int -> [Int] -> Int
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl Int -> Int -> Int
forall a. Ord a => a -> a -> a
min (TypeBase shape u -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank TypeBase shape u
t) ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (TypeBase shape u -> Int) -> [TypeBase shape u] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map TypeBase shape u -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank [TypeBase shape u]
ts
pullRearrange ::
SOAC ->
SOAC.ArrayTransforms ->
TryFusion (SOAC, SOAC.ArrayTransforms)
pullRearrange :: SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)
pullRearrange SOAC
soac ArrayTransforms
ots = do
MapNest
nest <- Maybe MapNest -> TryFusion MapNest
forall a. Maybe a -> TryFusion a
liftMaybe (Maybe MapNest -> TryFusion MapNest)
-> TryFusion (Maybe MapNest) -> TryFusion MapNest
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SOAC -> TryFusion (Maybe MapNest)
forall lore (m :: * -> *).
(Bindable lore, MonadFreshNames m, LocalScope lore m,
Op lore ~ SOAC lore) =>
SOAC lore -> m (Maybe (MapNest lore))
MapNest.fromSOAC SOAC
soac
SOAC.Rearrange Certificates
cs [Int]
perm SOAC.:< ArrayTransforms
ots' <- ViewF -> TryFusion ViewF
forall (m :: * -> *) a. Monad m => a -> m a
return (ViewF -> TryFusion ViewF) -> ViewF -> TryFusion ViewF
forall a b. (a -> b) -> a -> b
$ ArrayTransforms -> ViewF
SOAC.viewf ArrayTransforms
ots
if [Int] -> Int
rearrangeReach [Int]
perm Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= MapNest -> Int
mapDepth MapNest
nest
then do
let
perm' :: Input -> [Int]
perm' Input
inp = Int -> [Int] -> [Int]
forall a. Int -> [a] -> [a]
take Int
r [Int]
perm [Int] -> [Int] -> [Int]
forall a. [a] -> [a] -> [a]
++ [[Int] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Int]
perm .. Int
r Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1]
where
r :: Int
r = Input -> Int
SOAC.inputRank Input
inp
addPerm :: Input -> Input
addPerm Input
inp = ArrayTransform -> Input -> Input
SOAC.addTransform (Certificates -> [Int] -> ArrayTransform
SOAC.Rearrange Certificates
cs ([Int] -> ArrayTransform) -> [Int] -> ArrayTransform
forall a b. (a -> b) -> a -> b
$ Input -> [Int]
perm' Input
inp) Input
inp
inputs' :: [Input]
inputs' = (Input -> Input) -> [Input] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map Input -> Input
addPerm ([Input] -> [Input]) -> [Input] -> [Input]
forall a b. (a -> b) -> a -> b
$ MapNest -> [Input]
forall lore. MapNest lore -> [Input]
MapNest.inputs MapNest
nest
SOAC
soac' <-
MapNest -> TryFusion SOAC
forall (m :: * -> *) lore.
(MonadFreshNames m, HasScope lore m, Bindable lore, BinderOps lore,
Op lore ~ SOAC lore) =>
MapNest lore -> m (SOAC lore)
MapNest.toSOAC (MapNest -> TryFusion SOAC) -> MapNest -> TryFusion SOAC
forall a b. (a -> b) -> a -> b
$
[Input]
inputs' [Input] -> MapNest -> MapNest
forall lore. [Input] -> MapNest lore -> MapNest lore
`MapNest.setInputs` MapNest -> [Int] -> MapNest
rearrangeReturnTypes MapNest
nest [Int]
perm
(SOAC, ArrayTransforms) -> TryFusion (SOAC, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return (SOAC
soac', ArrayTransforms
ots')
else String -> TryFusion (SOAC, ArrayTransforms)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot pull transpose"
-- | Try to push a 'SOAC.Rearrange' found on one of the inputs named by
-- @inpIds@ out of the SOAC and onto its output transforms.
--
-- The SOAC must be expressible as a 'MapNest'.  'fixupInputs' picks the
-- permutation and adds a matching 'SOAC.Rearrange' to every input of the
-- nest.  The nest's return types are then permuted with
-- 'rearrangeReturnTypes', and the inverse permutation is prepended to the
-- output transforms @ots@ to rearrange the permuted results back.
--
-- Fails (in 'TryFusion') if the SOAC is not a map nest, if 'fixupInputs'
-- finds no usable permutation, or if the permutation reaches deeper than
-- the map nest ('mapDepth').
pushRearrange ::
  [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pushRearrange inpIds soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  (perm, inputs') <- liftMaybe $ fixupInputs inpIds $ MapNest.inputs nest
  -- The permutation may only touch dimensions that the map nest covers.
  unless (rearrangeReach perm <= mapDepth nest) $
    fail "Cannot push transpose"
  soac' <-
    MapNest.toSOAC $
      inputs' `MapNest.setInputs` rearrangeReturnTypes nest perm
  let invertRearrange = SOAC.Rearrange mempty $ rearrangeInverse perm
  return (soac', invertRearrange SOAC.<| ots)
-- | Permute the return types recorded in the nestings of a 'MapNest' by
-- @perm@.  The innermost lambda (@body@), the width and the inputs are
-- kept unchanged.
--
-- Each result type of the nest ('MapNest.typeOf') is rearranged with
-- @perm@ truncated to that type's rank.  The @i@th element of the nestings
-- list (0-based) then receives those types with @i + 1@ outer dimensions
-- removed via 'rowType'.
rearrangeReturnTypes :: MapNest -> [Int] -> MapNest
rearrangeReturnTypes nest@(MapNest.MapNest w body nestings inps) perm =
  MapNest.MapNest w body nestings' inps
  where
    -- 'iterate' yields an infinite list; 'zipWith' stops at the last nesting.
    nestings' =
      zipWith setReturnType nestings $ drop 1 $ iterate (map rowType) ts

    -- The permutation may be longer than the rank of a particular result
    -- type, so cut it off at that rank.
    rearrangeType' t = rearrangeType (take (arrayRank t) perm) t

    ts = map rearrangeType' $ MapNest.typeOf nest

    setReturnType nesting t' =
      nesting {MapNest.nestingReturnType = t'}
-- | Find a permutation to push through a SOAC and apply it to all inputs.
--
-- The permutation comes from the first input that meets two conditions:
-- its array is named in @inpIds@, and its last array transform (as seen by
-- 'SOAC.viewl') is a 'SOAC.Rearrange'.  Every input then gets a
-- 'SOAC.Rearrange' by that permutation added with 'SOAC.addTransform'.  The
-- added permutation is truncated to the input's rank.
--
-- Returns the permutation together with the updated inputs.  Returns
-- 'Nothing' if no such permutation exists, or if some input's rank is
-- smaller than the permutation's 'rearrangeReach'.
fixupInputs :: [VName] -> [SOAC.Input] -> Maybe ([Int], [SOAC.Input])
fixupInputs inpIds inps =
  case mapMaybe inputRearrange $ filter exposable inps of
    perm : _ -> do
      inps' <- mapM (fixupInput (rearrangeReach perm) perm) inps
      return (perm, inps')
    [] -> Nothing
  where
    exposable = (`elem` inpIds) . SOAC.inputArray

    -- The permutation of the input's last transform, if it is a rearrange.
    inputRearrange (SOAC.Input ts _ _)
      | _ SOAC.:> SOAC.Rearrange _ perm <- SOAC.viewl ts = Just perm
    inputRearrange _ = Nothing

    -- Add the (rank-truncated) permutation to an input.  The input must have
    -- at least @d@ dimensions.
    fixupInput d perm inp
      | r >= d =
          Just $ SOAC.addTransform (SOAC.Rearrange mempty $ take r perm) inp
      | otherwise = Nothing
      where
        r = SOAC.inputRank inp
-- | Try to pull a leading 'SOAC.Reshape' output transform into a @map@ SOAC.
--
-- The rewrite applies only when three conditions hold:
--
-- * the SOAC is a @map@ ('Futhark.isMapSOAC');
-- * its lambda returns only primitive types;
-- * the first output transform (as seen by 'SOAC.viewf') is a reshape.
--
-- The reshape is then applied to every input as a 'SOAC.ReshapeOuter', and
-- the map is rebuilt as a nest of maps, one per dimension of the new shape.
-- The innermost map runs the original lambda over the last new dimension.
-- The remaining output transforms are returned alongside the new SOAC.
-- Fails with "Cannot pull reshape" otherwise.
pullReshape :: SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)
pullReshape (SOAC.Screma _ form inps) ots
  | Just maplam <- Futhark.isMapSOAC form,
    SOAC.Reshape cs shape SOAC.:< ots' <- SOAC.viewf ots,
    all primType $ lambdaReturnType maplam = do
      let dims = newDims shape

          -- Width of the innermost map: the last dimension of the new shape
          -- (zero if the new shape has no dimensions).
          mapw' = case reverse dims of
            [] -> intConst Int64 0
            d : _ -> d

          inputs' = map (SOAC.addTransform $ SOAC.ReshapeOuter cs shape) inps
          inputTypes = map SOAC.inputType inputs'

          -- Wrap a SOAC constructor in a map of width @w@.  The wrapped map
          -- returns the original lambda's result types with @outershape@
          -- added as outer dimensions.
          outersoac ::
            ([SOAC.Input] -> SOAC) ->
            (SubExp, [SubExp]) ->
            TryFusion ([SOAC.Input] -> SOAC)
          outersoac inner (w, outershape) = do
            let addDims t = arrayOf t (Shape outershape) NoUniqueness
                retTypes = map addDims $ lambdaReturnType maplam

            ps <- forM inputTypes $ \inpt ->
              newParam "pullReshape_param" $
                stripArray (length shape - length outershape) inpt

            inner_body <-
              runBodyBinder $
                eBody [SOAC.toExp $ inner $ map (SOAC.identInput . paramIdent) ps]
            let inner_fun =
                  Lambda
                    { lambdaParams = ps,
                      lambdaReturnType = retTypes,
                      lambdaBody = inner_body
                    }
            return $ SOAC.Screma w $ Futhark.mapSOAC inner_fun

          -- Pair every dimension except the last with the dimensions that
          -- follow it.  Pairs are listed from the second-to-last dimension
          -- outwards, so the fold wraps innermost levels first.
          outerLevels =
            zip (drop 1 $ reverse dims) $
              drop 1 $ reverse $ drop 1 $ tails dims

      op' <-
        foldM outersoac (SOAC.Screma mapw' $ Futhark.mapSOAC maplam) outerLevels
      return (op' inputs', ots')
pullReshape _ _ = fail "Cannot pull reshape"
-- | Try to rewrite the kernel so that the inputs named by @inpIds@ are
-- exposed, meaning they carry no array transforms.  Returns the modified
-- kernel together with the transforms that 'commonTransforms' factored out.
--
-- Three strategies are combined with '<|>', so the first that succeeds wins:
--
-- 1. push a rearrange out of the SOAC ('pushRearrange'), then expose;
-- 2. apply 'pullRearrange' to the output transforms, requiring that none
--    remain, then expose;
-- 3. expose directly on the unchanged kernel.
exposeInputs ::
  [VName] ->
  FusedKer ->
  TryFusion (FusedKer, SOAC.ArrayTransforms)
exposeInputs inpIds ker =
  (exposeInputs' =<< pushRearrange')
    <|> (exposeInputs' =<< pullRearrange')
    <|> exposeInputs' ker
  where
    ot = outputTransform ker

    pushRearrange' = do
      (soac', ot') <- pushRearrange inpIds (fsoac ker) ot
      return ker {fsoac = soac', outputTransform = ot'}

    pullRearrange' = do
      (soac', ot') <- pullRearrange (fsoac ker) ot
      -- This strategy only applies if every output transform was absorbed.
      unless (SOAC.nullTransforms ot') $
        fail "pullRearrange was not enough"
      return ker {fsoac = soac', outputTransform = SOAC.noTransforms}

    -- Factor out transforms on the kernel inputs with 'commonTransforms'.
    -- Succeed only if every resulting input is 'exposed'.
    exposeInputs' ker' =
      case commonTransforms inpIds $ inputs ker' of
        (ot', inps')
          | all exposed inps' ->
              return (ker' {fsoac = inps' `SOAC.setInputs` fsoac ker'}, ot')
        _ -> fail "Cannot expose"

    -- An input counts as exposed if it has no transforms, or if its array
    -- is not one of those named in @inpIds@.
    exposed (SOAC.Input ts _ _)
      | SOAC.nullTransforms ts = True
    exposed inp = SOAC.inputArray inp `notElem` inpIds
-- | The output-transform pullers tried by 'pullOutputTransforms', in the
-- order they are attempted.
outputTransformPullers ::
  [SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)]
outputTransformPullers = [pullRearrange, pullReshape]
-- | Pull as many output transforms as possible into the SOAC.
--
-- The pullers in 'outputTransformPullers' are tried in order:
--
-- * If a puller succeeds and leaves no transforms, that result is returned
--   with 'SOAC.noTransforms'.
-- * If it succeeds but leaves some transforms, we recursively try to pull
--   those as well.  If the recursion fails, we fall back to the partial
--   result.
-- * If the puller itself fails, the next puller is tried.
--
-- Fails with "Cannot pull anything" when no puller applies.
pullOutputTransforms ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullOutputTransforms = attempt outputTransformPullers
  where
    attempt [] _ _ = fail "Cannot pull anything"
    attempt (p : ps) soac ots = viaPuller <|> attempt ps soac ots
      where
        -- Only fails if @p@ fails: the 'else' branch always has the partial
        -- result to fall back to.
        viaPuller = do
          (soac', ots') <- p soac ots
          if SOAC.nullTransforms ots'
            then return (soac', SOAC.noTransforms)
            else pullOutputTransforms soac' ots' <|> return (soac', ots')