{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns -Wno-incomplete-patterns -Wno-incomplete-uni-patterns -Wno-incomplete-record-updates #-}
module Futhark.Pass.ExtractKernels.DistributeNests
( MapLoop (..),
mapLoopStm,
bodyContainsParallelism,
lambdaContainsParallelism,
determineReduceOp,
histKernel,
DistEnv (..),
DistAcc (..),
runDistNestT,
DistNestT,
liftInner,
distributeMap,
distribute,
distributeSingleStm,
distributeMapBodyStms,
addStmsToAcc,
addStmToAcc,
permutationAndMissing,
addPostStms,
postStm,
inNesting,
)
where
import Control.Arrow (first)
import Control.Monad.Identity
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.Trans.Maybe
import Control.Monad.Writer.Strict
import Data.Function ((&))
import Data.List (find, partition, tails)
import qualified Data.Map as M
import Data.Maybe
import Futhark.IR
import Futhark.IR.SOACS (SOACS)
import qualified Futhark.IR.SOACS as SOACS
import Futhark.IR.SOACS.SOAC hiding (HistOp, histDest)
import Futhark.IR.SOACS.Simplify (simpleSOACS, simplifyStms)
import Futhark.IR.SegOp
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.Interchange
import Futhark.Tools
import Futhark.Transform.CopyPropagate
import qualified Futhark.Transform.FirstOrderTransform as FOT
import Futhark.Transform.Rename
import Futhark.Util
import Futhark.Util.Log
-- | Reinterpret a scope of some lore as a scope of 'SOACS' lore via
-- 'castScope'; requires @SameScope lore SOACS@.
scopeForSOACs :: SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs scope = castScope scope
-- | A @map@ loop over SOACS, split into its parts: the pattern it binds,
-- the statement's auxiliary information, the loop width, the mapped
-- lambda, and the input arrays.  'mapLoopStm' reassembles these into a
-- statement.
data MapLoop
  = MapLoop
      SOACS.Pattern
      (StmAux ())
      SubExp
      SOACS.Lambda
      [VName]
-- | Rebuild the SOACS statement for a 'MapLoop': a 'Let' binding the
-- loop's pattern to a 'Screma' over the input arrays whose form is
-- 'mapSOAC' of the lambda.
mapLoopStm :: MapLoop -> Stm SOACS
mapLoopStm (MapLoop pat aux w lam arrs) =
  let screma = Screma w arrs (mapSOAC lam)
   in Let pat aux (Op screma)
-- | Read-only environment of the distribution monad 'DistNestT'.
data DistEnv lore m = DistEnv
  { distNest :: Nestings
  -- ^ The current loop nestings (extended by 'withStm', consulted by
  -- 'runDistNestT').
  , distScope :: Scope lore
  -- ^ Names in scope, in the target lore (exposed via 'HasScope').
  , distOnTopLevelStms :: Stms SOACS -> DistNestT lore m (Stms lore)
  -- ^ Converts a sequence of SOACS statements to the target lore.
  , distOnInnerMap ::
      MapLoop ->
      DistAcc lore ->
      DistNestT lore m (DistAcc lore)
  -- ^ Handles a nested 'MapLoop', threading the 'DistAcc'.
  , distOnSOACSStms :: Stm SOACS -> Binder lore (Stms lore)
  -- ^ Converts a single SOACS statement (used by 'addStmToAcc').
  , distOnSOACSLambda :: Lambda SOACS -> Binder lore (Lambda lore)
  -- ^ Converts a SOACS lambda (used by 'soacsLambda').
  , distSegLevel :: MkSegLevel lore m
  -- ^ Segment-level constructor for generated segmented operations.
  }
-- | Accumulator threaded through distribution.
data DistAcc lore = DistAcc
  { distTargets :: Targets
  -- ^ The current distribution targets.
  , distStms :: Stms lore
  -- ^ Statements generated so far; 'addStmsToAcc' and 'addStmToAcc'
  -- prepend new statements to this sequence.
  }
-- | Writer output of 'DistNestT'.
data DistRes lore = DistRes
  { accPostStms :: PostStms lore
  -- ^ Statements recorded with 'addPostStms' / 'postStm'; they head the
  -- result of 'runDistNestT'.
  , accLog :: Log
  -- ^ Log messages, forwarded to the base monad by 'runDistNestT'.
  }
-- | Combines the two components independently, each with its own '<>'.
instance Semigroup (DistRes lore) where
  DistRes stms1 msgs1 <> DistRes stms2 msgs2 =
    DistRes combined_stms combined_msgs
    where
      combined_stms = stms1 <> stms2
      combined_msgs = msgs1 <> msgs2
-- | No post-statements and an empty log.
instance Monoid (DistRes lore) where
  mempty = DistRes {accPostStms = mempty, accLog = mempty}
-- | Statements recorded with 'postStm'.  Note that its 'Semigroup'
-- instance concatenates in reverse operand order.
newtype PostStms lore = PostStms
  { unPostStms :: Stms lore
  }
-- | The right operand's statements come first:
-- @PostStms xs <> PostStms ys == PostStms (ys <> xs)@.  Consequently,
-- post-statements recorded later through the writer end up earlier in
-- the combined sequence.
instance Semigroup (PostStms lore) where
  x <> y = PostStms (unPostStms y <> unPostStms x)
-- | The empty statement sequence.
instance Monoid (PostStms lore) where
  mempty = PostStms {unPostStms = mempty}
-- | The scope bound by the pattern of the accumulator's outermost
-- distribution target.
typeEnvFromDistAcc :: DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc acc =
  let (outer_pat, _) = outerTarget (distTargets acc)
   in scopeOfPattern outer_pat
-- | Prepend statements to the accumulator's 'distStms'.
addStmsToAcc :: Stms lore -> DistAcc lore -> DistAcc lore
addStmsToAcc new_stms acc@DistAcc {distStms = old_stms} =
  acc {distStms = new_stms <> old_stms}
-- | Convert a SOACS statement with the environment's 'distOnSOACSStms'
-- callback and prepend the result to the accumulator's statements.
--
-- Only the statements the callback /returns/ are kept; anything it adds
-- to the 'Binder' itself (the second component of 'runBinder') is
-- dropped.
-- NOTE(review): confirm no 'distOnSOACSStms' implementation relies on
-- 'addStm' here.
addStmToAcc ::
  (MonadFreshNames m, DistLore lore) =>
  Stm SOACS ->
  DistAcc lore ->
  DistNestT lore m (DistAcc lore)
addStmToAcc stm acc = do
  convert <- asks distOnSOACSStms
  (converted, _discarded) <- runBinder (convert stm)
  pure $ addStmsToAcc converted acc
-- | Convert a SOACS lambda to the target lore with the environment's
-- 'distOnSOACSLambda' callback.  Statements the callback adds to the
-- 'Binder' are dropped; only the returned lambda is kept.
soacsLambda ::
  (MonadFreshNames m, DistLore lore) =>
  Lambda SOACS ->
  DistNestT lore m (Lambda lore)
soacsLambda lam = do
  convert <- asks distOnSOACSLambda
  (lam', _discarded) <- runBinder (convert lam)
  pure lam'
-- | The distribution monad: a reader over the 'DistEnv' and a strict
-- writer accumulating 'DistRes' (post-statements and log messages),
-- layered over the base monad @m@.
newtype DistNestT lore m a
  = DistNestT (ReaderT (DistEnv lore m) (WriterT (DistRes lore) m) a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadReader (DistEnv lore m)
    , MonadWriter (DistRes lore)
    )
-- | Run an action of the base monad inside 'DistNestT'.  The action runs
-- with the base monad's scope extended by every binding of the
-- 'DistNestT' scope whose name the base monad does not already bind.
liftInner :: (LocalScope lore m, DistLore lore) => m a -> DistNestT lore m a
liftInner action = do
  outer_scope <- askScope
  let in_base = do
        inner_scope <- askScope
        -- Only the bindings missing from the base monad's own scope.
        localScope (M.difference outer_scope inner_scope) action
  DistNestT (lift (lift in_base))
-- | Name sources are delegated, through the writer layer, to the base
-- monad.
instance MonadFreshNames m => MonadFreshNames (DistNestT lore m) where
  getNameSource = DistNestT (lift getNameSource)
  putNameSource src = DistNestT (lift (putNameSource src))
-- | The scope is the environment's 'distScope'.
instance (Monad m, ASTLore lore) => HasScope lore (DistNestT lore m) where
  askScope = distScope <$> ask
-- | For the duration of the action, the given bindings are combined with
-- '<>' in front of the environment's 'distScope'; since the scope is a
-- 'Data.Map' (left-biased union), they shadow existing bindings of the
-- same name.
instance (Monad m, ASTLore lore) => LocalScope lore (DistNestT lore m) where
  localScope types = local extend
    where
      extend env = env {distScope = types <> distScope env}
-- | Log messages go into the writer output's 'accLog'; 'runDistNestT'
-- forwards them to the base monad.
instance Monad m => MonadLogger (DistNestT lore m) where
  addLog msgs = tell (DistRes mempty msgs)
-- | Run a distribution computation in the base monad.
--
-- The accumulated log is forwarded with 'addLog'.  The result is the
-- accumulated post-statements followed by one statement per value left
-- in the outermost distribution target.  Each such value is either:
--
-- * a 'Copy' of the corresponding input array, when it is a parameter of
--   the outermost loop nesting, or
--
-- * a 'Replicate' across the width of that nesting otherwise.
--
-- NOTE(review): these leftover targets presumably correspond to
-- identity-mapped values of the nest; confirm against callers.
runDistNestT ::
  (MonadLogger m, DistLore lore) =>
  DistEnv lore m ->
  DistNestT lore m (DistAcc lore) ->
  m (Stms lore)
runDistNestT env (DistNestT action) = do
  (final_acc, res) <- runWriterT (runReaderT action env)
  addLog (accLog res)
  let leftover = outerTarget (distTargets final_acc)
  pure $ unPostStms (accPostStms res) <> identityStms leftover
  where
    -- With no further nestings, the lone nesting is the outermost;
    -- otherwise the head of the list is used.
    (only_nest, other_nests) = distNest env
    outermost = nestingLoop (fromMaybe only_nest (listToMaybe other_nests))

    params_to_arrs =
      [(paramName p, arr) | (p, arr) <- loopNestingParamsAndArrs outermost]

    identityStms (rem_pat, ses) =
      stmsFromList $ zipWith identityStm (patternValueElements rem_pat) ses

    identityStm pe se =
      Let (Pattern [] [pe]) (defAux ()) $
        BasicOp $
          case se of
            Var v
              | Just arr <- lookup v params_to_arrs -> Copy arr
            _ -> Replicate (Shape [loopNestingWidth outermost]) se
-- | Record post-statements in the writer output (with an empty log).
addPostStms :: Monad m => PostStms lore -> DistNestT lore m ()
addPostStms ks = tell (DistRes ks mempty)
-- | Record plain statements as post-statements; shorthand for
-- @'addPostStms' . 'PostStms'@.
postStm :: Monad m => Stms lore -> DistNestT lore m ()
postStm stms = addPostStms (PostStms stms)
-- | Run an action in an environment extended by a SOACS statement.
--
-- * The statement's bindings, cast to the target lore, are prepended to
--   'distScope'.
--
-- * The names bound by its pattern are passed to 'letBindInInnerNesting'
--   on 'distNest'.
withStm ::
  (Monad m, DistLore lore) =>
  Stm SOACS ->
  DistNestT lore m a ->
  DistNestT lore m a
withStm stm = local extend
  where
    provided = namesFromList (patternNames (stmPattern stm))
    extend env =
      env
        { distScope = castScope (scopeOf stm) <> distScope env
        , distNest = letBindInInnerNesting provided (distNest env)
        }
leavingNesting ::
(MonadFreshNames m, DistLore lore) =>
DistAcc lore ->
DistNestT lore m (DistAcc lore)
leavingNesting :: DistAcc lore -> DistNestT lore m (DistAcc lore)
leavingNesting DistAcc lore
acc =
case Targets -> Maybe ((PatternT Type, Result), Targets)
popInnerTarget (Targets -> Maybe ((PatternT Type, Result), Targets))
-> Targets -> Maybe ((PatternT Type, Result), Targets)
forall a b. (a -> b) -> a -> b
$ DistAcc lore -> Targets
forall lore. DistAcc lore -> Targets
distTargets DistAcc lore
acc of
Maybe ((PatternT Type, Result), Targets)
Nothing ->
[Char] -> DistNestT lore m (DistAcc lore)
forall a. HasCallStack => [Char] -> a
error [Char]
"The kernel targets list is unexpectedly small"
Just ((PatternT Type
pat, Result
res), Targets
newtargets)
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Seq (Stm lore) -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (Seq (Stm lore) -> Bool) -> Seq (Stm lore) -> Bool
forall a b. (a -> b) -> a -> b
$ DistAcc lore -> Seq (Stm lore)
forall lore. DistAcc lore -> Stms lore
distStms DistAcc lore
acc -> do
(Nesting Names
_ LoopNesting
inner, [Nesting]
_) <- (DistEnv lore m -> Nestings) -> DistNestT lore m Nestings
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks DistEnv lore m -> Nestings
forall lore (m :: * -> *). DistEnv lore m -> Nestings
distNest
let MapNesting PatternT Type
_ StmAux ()
aux SubExp
w [(Param Type, VName)]
params_and_arrs = LoopNesting
inner
body :: BodyT lore
body = BodyDec lore -> Seq (Stm lore) -> Result -> BodyT lore
forall lore. BodyDec lore -> Stms lore -> Result -> BodyT lore
Body () (DistAcc lore -> Seq (Stm lore)
forall lore. DistAcc lore -> Stms lore
distStms DistAcc lore
acc) Result
res
used_in_body :: Names
used_in_body = BodyT lore -> Names
forall a. FreeIn a => a -> Names
freeIn BodyT lore
body
([Param Type]
used_params, [VName]
used_arrs) =
[(Param Type, VName)] -> ([Param Type], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Param Type, VName)] -> ([Param Type], [VName]))
-> [(Param Type, VName)] -> ([Param Type], [VName])
forall a b. (a -> b) -> a -> b
$
((Param Type, VName) -> Bool)
-> [(Param Type, VName)] -> [(Param Type, VName)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName -> Names -> Bool
`nameIn` Names
used_in_body) (VName -> Bool)
-> ((Param Type, VName) -> VName) -> (Param Type, VName) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName (Param Type -> VName)
-> ((Param Type, VName) -> Param Type)
-> (Param Type, VName)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param Type, VName) -> Param Type
forall a b. (a, b) -> a
fst) [(Param Type, VName)]
params_and_arrs
lam' :: LambdaT lore
lam' =
Lambda :: forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda
{ lambdaParams :: [LParam lore]
lambdaParams = [Param Type]
[LParam lore]
used_params,
lambdaBody :: BodyT lore
lambdaBody = BodyT lore
body,
lambdaReturnType :: [Type]
lambdaReturnType = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [Type]
forall dec. Typed dec => PatternT dec -> [Type]
patternTypes PatternT Type
pat
}
Seq (Stm lore)
stms <-
Binder lore () -> DistNestT lore m (Seq (Stm lore))
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (Binder lore () -> DistNestT lore m (Seq (Stm lore)))
-> (SOAC lore -> Binder lore ())
-> SOAC lore
-> DistNestT lore m (Seq (Stm lore))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StmAux () -> Binder lore () -> Binder lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux ()
aux (Binder lore () -> Binder lore ())
-> (SOAC lore -> Binder lore ()) -> SOAC lore -> Binder lore ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pattern (Lore (BinderT lore (State VNameSource)))
-> SOAC (Lore (BinderT lore (State VNameSource))) -> Binder lore ()
forall (m :: * -> *).
Transformer m =>
Pattern (Lore m) -> SOAC (Lore m) -> m ()
FOT.transformSOAC PatternT Type
Pattern (Lore (BinderT lore (State VNameSource)))
pat (SOAC lore -> DistNestT lore m (Seq (Stm lore)))
-> SOAC lore -> DistNestT lore m (Seq (Stm lore))
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> ScremaForm lore -> SOAC lore
forall lore. SubExp -> [VName] -> ScremaForm lore -> SOAC lore
Screma SubExp
w [VName]
used_arrs (ScremaForm lore -> SOAC lore) -> ScremaForm lore -> SOAC lore
forall a b. (a -> b) -> a -> b
$ LambdaT lore -> ScremaForm lore
forall lore. Lambda lore -> ScremaForm lore
mapSOAC LambdaT lore
lam'
DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (DistAcc lore -> DistNestT lore m (DistAcc lore))
-> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ DistAcc lore
acc {distTargets :: Targets
distTargets = Targets
newtargets, distStms :: Seq (Stm lore)
distStms = Seq (Stm lore)
stms}
| Bool
otherwise -> do
(Nesting Names
_ LoopNesting
inner_nesting, [Nesting]
_) <- (DistEnv lore m -> Nestings) -> DistNestT lore m Nestings
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks DistEnv lore m -> Nestings
forall lore (m :: * -> *). DistEnv lore m -> Nestings
distNest
let w :: SubExp
w = LoopNesting -> SubExp
loopNestingWidth LoopNesting
inner_nesting
aux :: StmAux ()
aux = LoopNesting -> StmAux ()
loopNestingAux LoopNesting
inner_nesting
inps :: [(Param Type, VName)]
inps = LoopNesting -> [(Param Type, VName)]
loopNestingParamsAndArrs LoopNesting
inner_nesting
remnantStm :: PatElemT Type -> SubExp -> Stm lore
remnantStm PatElemT Type
pe (Var VName
v)
| Just (Param Type
_, VName
arr) <- ((Param Type, VName) -> Bool)
-> [(Param Type, VName)] -> Maybe (Param Type, VName)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v) (VName -> Bool)
-> ((Param Type, VName) -> VName) -> (Param Type, VName) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName (Param Type -> VName)
-> ((Param Type, VName) -> Param Type)
-> (Param Type, VName)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param Type, VName) -> Param Type
forall a b. (a, b) -> a
fst) [(Param Type, VName)]
inps =
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let ([PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT Type
pe]) StmAux ()
StmAux (ExpDec lore)
aux (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
arr
remnantStm PatElemT Type
pe SubExp
se =
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let ([PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT Type
pe]) StmAux ()
StmAux (ExpDec lore)
aux (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate (Result -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
se
stms :: Seq (Stm lore)
stms =
[Stm lore] -> Seq (Stm lore)
forall lore. [Stm lore] -> Stms lore
stmsFromList ([Stm lore] -> Seq (Stm lore)) -> [Stm lore] -> Seq (Stm lore)
forall a b. (a -> b) -> a -> b
$ (PatElemT Type -> SubExp -> Stm lore)
-> [PatElemT Type] -> Result -> [Stm lore]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith PatElemT Type -> SubExp -> Stm lore
remnantStm (PatternT Type -> [PatElemT Type]
forall dec. PatternT dec -> [PatElemT dec]
patternElements PatternT Type
pat) Result
res
DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (DistAcc lore -> DistNestT lore m (DistAcc lore))
-> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ DistAcc lore
acc {distTargets :: Targets
distTargets = Targets
newtargets, distStms :: Seq (Stm lore)
distStms = Seq (Stm lore)
stms}
-- | Run a distribution action inside one additional map nesting.
--
-- A 'MapNesting' built from @pat@, @aux@, @w@ and the lambda parameters
-- (paired with @arrs@) is pushed as the new innermost nesting, and the
-- lambda's scope is added to the distribution scope. The inner action is
-- run in that extended environment, and its accumulator is then handed
-- to 'leavingNesting' (still inside the extended environment).
mapNesting ::
  (MonadFreshNames m, DistLore lore) =>
  PatternT Type ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  DistNestT lore m (DistAcc lore) ->
  DistNestT lore m (DistAcc lore)
mapNesting pat aux w lam arrs inner =
  local enterMap (inner >>= leavingNesting)
  where
    -- Each lambda parameter is bound to the corresponding input array.
    paramsAndArrs = zip (lambdaParams lam) arrs

    newNesting :: Nesting
    newNesting = Nesting mempty (MapNesting pat aux w paramsAndArrs)

    enterMap env =
      env
        { distNest = pushInnerNesting newNesting (distNest env),
          distScope = castScope (scopeOf lam) <> distScope env
        }
-- | Run an action with the given kernel nest installed as the current
-- nesting stack.
--
-- The last loop nesting in @nests@ (or @outer@ when @nests@ is empty)
-- becomes the innermost nesting. All remaining loop nestings, starting
-- with @outer@, form the enclosing list. The scopes of every loop nesting
-- in the kernel nest are added to the distribution scope.
inNesting ::
  (Monad m, DistLore lore) =>
  KernelNest ->
  DistNestT lore m a ->
  DistNestT lore m a
inNesting (outer, nests) = local withKernelNest
  where
    withKernelNest env =
      env
        { distNest = (innermost, enclosing),
          distScope =
            foldMap scopeOfLoopNesting (outer : nests) <> distScope env
        }

    (innermost, enclosing) =
      case reverse nests of
        [] ->
          (asNesting outer, [])
        lastNest : revRest ->
          (asNesting lastNest, map asNesting (outer : reverse revRest))

    -- Wrap a loop nesting with an empty set of names.
    asNesting :: LoopNesting -> Nesting
    asNesting = Nesting mempty
-- | Does the body contain at least one statement that exposes parallelism?
--
-- A statement counts when its expression is an 'Op' (a SOAC), or a
-- 'ForLoop' 'DoLoop' whose own body (recursively) contains parallelism.
-- It must also not carry the @"sequential"@ attribute. Other loop forms
-- and all other expressions do not count.
bodyContainsParallelism :: Body SOACS -> Bool
bodyContainsParallelism body = any isParallelStm (bodyStms body)
  where
    isParallelStm :: Stm SOACS -> Bool
    isParallelStm stm =
      hasParallelExp (stmExp stm) && not (isSequential stm)

    -- Statements explicitly marked sequential are never treated as parallel.
    isSequential :: Stm SOACS -> Bool
    isSequential stm = "sequential" `inAttrs` stmAuxAttrs (stmAux stm)

    hasParallelExp :: Exp SOACS -> Bool
    hasParallelExp e = case e of
      Op {} -> True
      DoLoop _ _ ForLoop {} loopBody -> bodyContainsParallelism loopBody
      _ -> False
-- | Does the lambda's body contain parallelism, as defined by
-- 'bodyContainsParallelism'?
lambdaContainsParallelism :: Lambda SOACS -> Bool
lambdaContainsParallelism lam = bodyContainsParallelism (lambdaBody lam)
-- | Distribute the statements of a map body, then call 'distribute' on
-- the resulting accumulator.
--
-- Statements are handled back to front: the tail is processed first,
-- and then 'maybeDistributeStm' decides what to do with the current
-- statement. A 'Sequential' 'Stream' is not distributed as-is. It is first
-- expanded with 'sequentialStreamWholeArray', copy-propagated, and given
-- the stream's certificates. The resulting statements are then spliced in
-- ahead of the remaining ones and processed like any others.
distributeMapBodyStms ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  DistAcc lore ->
  Stms SOACS ->
  DistNestT lore m (DistAcc lore)
distributeMapBodyStms orig_acc body_stms =
  onStms orig_acc (stmsToList body_stms) >>= distribute
  where
    onStms acc [] = pure acc
    onStms acc (Let pat (StmAux cs _ _) (Op (Stream w arrs Sequential accs lam)) : rest) = do
      types <- asksScope scopeForSOACs
      (_, seq_stms) <-
        runBinderT (sequentialStreamWholeArray pat w accs lam arrs) types
      (_, seq_stms') <-
        runReaderT (copyPropagateInStms simpleSOACS types seq_stms) types
      -- The expanded statements inherit the stream's certificates.
      onStms acc (stmsToList (certify cs <$> seq_stms') ++ rest)
    onStms acc (stm : rest) =
      -- NOTE(review): withStm presumably puts stm's bindings in scope for
      -- both the recursive call and maybeDistributeStm (so stm is even in
      -- scope of itself); confirm against withStm's definition.
      withStm stm $ do
        acc' <- onStms acc rest
        maybeDistributeStm stm acc'
-- | Handle an inner map loop by delegating to the 'distOnInnerMap'
-- callback stored in the distribution environment.
onInnerMap :: Monad m => MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore)
onInnerMap loop acc =
  asks distOnInnerMap >>= \handler -> handler loop acc
-- | Transform SOACS statements with the 'distOnTopLevelStms' callback from
-- the distribution environment, and emit the result with 'postStm'.
onTopLevelStms :: Monad m => Stms SOACS -> DistNestT lore m ()
onTopLevelStms stms = do
  handler <- asks distOnTopLevelStms
  transformed <- handler stms
  postStm transformed
maybeDistributeStm ::
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
Stm SOACS ->
DistAcc lore ->
DistNestT lore m (DistAcc lore)
maybeDistributeStm :: Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
maybeDistributeStm Stm SOACS
stm DistAcc lore
acc
| Attr
"sequential" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs (Stm SOACS -> StmAux (ExpDec SOACS)
forall lore. Stm lore -> StmAux (ExpDec lore)
stmAux Stm SOACS
stm) =
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
stm DistAcc lore
acc
maybeDistributeStm (Let Pattern
pat StmAux (ExpDec SOACS)
aux (Op Op SOACS
soac)) DistAcc lore
acc
| Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux =
DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
distributeMapBodyStms DistAcc lore
acc (Stms SOACS -> DistNestT lore m (DistAcc lore))
-> (Stms SOACS -> Stms SOACS)
-> Stms SOACS
-> DistNestT lore m (DistAcc lore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> Stms SOACS -> Stms SOACS
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> Stm SOACS -> Stm SOACS
forall lore. Certificates -> Stm lore -> Stm lore
certify (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux))
(Stms SOACS -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (Stms SOACS) -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Binder SOACS () -> DistNestT lore m (Stms SOACS)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (Pattern (Lore (BinderT SOACS (State VNameSource)))
-> SOAC (Lore (BinderT SOACS (State VNameSource)))
-> Binder SOACS ()
forall (m :: * -> *).
Transformer m =>
Pattern (Lore m) -> SOAC (Lore m) -> m ()
FOT.transformSOAC Pattern (Lore (BinderT SOACS (State VNameSource)))
Pattern
pat Op SOACS
SOAC (Lore (BinderT SOACS (State VNameSource)))
soac)
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pattern
pat StmAux (ExpDec SOACS)
_ (Op (Screma w arrs form))) DistAcc lore
acc
| Just Lambda
lam <- ScremaForm SOACS -> Maybe Lambda
forall lore. ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC ScremaForm SOACS
form =
DistAcc lore -> DistNestT lore m (Maybe (DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore -> DistNestT lore m (Maybe (DistAcc lore))
distributeIfPossible DistAcc lore
acc DistNestT lore m (Maybe (DistAcc lore))
-> (Maybe (DistAcc lore) -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Maybe (DistAcc lore)
Nothing -> Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
stm DistAcc lore
acc
Just DistAcc lore
acc' -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore -> DistNestT lore m (DistAcc lore)
distribute (DistAcc lore -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
Monad m =>
MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore)
onInnerMap (Pattern -> StmAux () -> SubExp -> Lambda -> [VName] -> MapLoop
MapLoop Pattern
pat (Stm SOACS -> StmAux (ExpDec SOACS)
forall lore. Stm lore -> StmAux (ExpDec lore)
stmAux Stm SOACS
stm) SubExp
w Lambda
lam [VName]
arrs) DistAcc lore
acc'
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat StmAux (ExpDec SOACS)
aux (DoLoop [] [(FParam SOACS, SubExp)]
val form :: LoopForm SOACS
form@ForLoop {} Body SOACS
body)) DistAcc lore
acc
| [PatElemT Type] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (PatternT Type -> [PatElemT Type]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements PatternT Type
Pattern
pat),
Body SOACS -> Bool
bodyContainsParallelism Body SOACS
body =
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
|
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$
(LoopForm SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm SOACS
form Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> StmAux () -> Names
forall a. FreeIn a => a -> Names
freeIn StmAux ()
StmAux (ExpDec SOACS)
aux)
Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
Scope SOACS
types <- (Scope lore -> Scope SOACS) -> DistNestT lore m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope lore -> Scope SOACS
forall lore. SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs
Stms SOACS
stms <-
(ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
-> Scope SOACS -> DistNestT lore m (Stms SOACS)
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
`runReaderT` Scope SOACS
types) (ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
-> DistNestT lore m (Stms SOACS))
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
-> DistNestT lore m (Stms SOACS)
forall a b. (a -> b) -> a -> b
$
((SymbolTable (Wise SOACS), Stms SOACS) -> Stms SOACS)
-> ReaderT
(Scope SOACS)
(DistNestT lore m)
(SymbolTable (Wise SOACS), Stms SOACS)
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SymbolTable (Wise SOACS), Stms SOACS) -> Stms SOACS
forall a b. (a, b) -> b
snd (ReaderT
(Scope SOACS)
(DistNestT lore m)
(SymbolTable (Wise SOACS), Stms SOACS)
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS))
-> (Stms SOACS
-> ReaderT
(Scope SOACS)
(DistNestT lore m)
(SymbolTable (Wise SOACS), Stms SOACS))
-> Stms SOACS
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms SOACS
-> ReaderT
(Scope SOACS)
(DistNestT lore m)
(SymbolTable (Wise SOACS), Stms SOACS)
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Stms SOACS -> m (SymbolTable (Wise SOACS), Stms SOACS)
simplifyStms
(Stms SOACS
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS))
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< KernelNest
-> SeqLoop -> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall (m :: * -> *).
(MonadFreshNames m, HasScope SOACS m) =>
KernelNest -> SeqLoop -> m (Stms SOACS)
interchangeLoops KernelNest
nest' ([Int]
-> Pattern
-> [(FParam SOACS, SubExp)]
-> LoopForm SOACS
-> Body SOACS
-> SeqLoop
SeqLoop [Int]
perm Pattern
pat [(FParam SOACS, SubExp)]
val LoopForm SOACS
form Body SOACS
body)
Stms SOACS -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms SOACS -> DistNestT lore m ()
onTopLevelStms Stms SOACS
stms
DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'
Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pattern
pat StmAux (ExpDec SOACS)
_ (If SubExp
cond Body SOACS
tbranch Body SOACS
fbranch IfDec (BranchType SOACS)
ret)) DistAcc lore
acc
| [PatElemT Type] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (PatternT Type -> [PatElemT Type]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements PatternT Type
Pattern
pat),
Body SOACS -> Bool
bodyContainsParallelism Body SOACS
tbranch Bool -> Bool -> Bool
|| Body SOACS -> Bool
bodyContainsParallelism Body SOACS
fbranch
Bool -> Bool -> Bool
|| Bool -> Bool
not ((TypeBase ExtShape NoUniqueness -> Bool)
-> [TypeBase ExtShape NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase ExtShape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (IfDec (TypeBase ExtShape NoUniqueness)
-> [TypeBase ExtShape NoUniqueness]
forall rt. IfDec rt -> [rt]
ifReturns IfDec (TypeBase ExtShape NoUniqueness)
IfDec (BranchType SOACS)
ret)) =
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
stm DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$
(SubExp -> Names
forall a. FreeIn a => a -> Names
freeIn SubExp
cond Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> IfDec (TypeBase ExtShape NoUniqueness) -> Names
forall a. FreeIn a => a -> Names
freeIn IfDec (TypeBase ExtShape NoUniqueness)
IfDec (BranchType SOACS)
ret) Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
Scope SOACS
types <- (Scope lore -> Scope SOACS) -> DistNestT lore m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope lore -> Scope SOACS
forall lore. SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs
let branch :: Branch
branch = [Int]
-> Pattern
-> SubExp
-> Body SOACS
-> Body SOACS
-> IfDec (BranchType SOACS)
-> Branch
Branch [Int]
perm Pattern
pat SubExp
cond Body SOACS
tbranch Body SOACS
fbranch IfDec (BranchType SOACS)
ret
Stms SOACS
stms <-
(ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
-> Scope SOACS -> DistNestT lore m (Stms SOACS)
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
`runReaderT` Scope SOACS
types) (ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
-> DistNestT lore m (Stms SOACS))
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
-> DistNestT lore m (Stms SOACS)
forall a b. (a -> b) -> a -> b
$
((SymbolTable (Wise SOACS), Stms SOACS) -> Stms SOACS)
-> ReaderT
(Scope SOACS)
(DistNestT lore m)
(SymbolTable (Wise SOACS), Stms SOACS)
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SymbolTable (Wise SOACS), Stms SOACS) -> Stms SOACS
forall a b. (a, b) -> b
snd (ReaderT
(Scope SOACS)
(DistNestT lore m)
(SymbolTable (Wise SOACS), Stms SOACS)
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS))
-> (Stms SOACS
-> ReaderT
(Scope SOACS)
(DistNestT lore m)
(SymbolTable (Wise SOACS), Stms SOACS))
-> Stms SOACS
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms SOACS
-> ReaderT
(Scope SOACS)
(DistNestT lore m)
(SymbolTable (Wise SOACS), Stms SOACS)
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Stms SOACS -> m (SymbolTable (Wise SOACS), Stms SOACS)
simplifyStms
(Stms SOACS
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS))
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< KernelNest
-> Branch -> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall (m :: * -> *).
(MonadFreshNames m, HasScope SOACS m) =>
KernelNest -> Branch -> m (Stms SOACS)
interchangeBranch KernelNest
nest' Branch
branch
Stms SOACS -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms SOACS -> DistNestT lore m ()
onTopLevelStms Stms SOACS
stms
DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'
Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
stm DistAcc lore
acc
maybeDistributeStm (Let Pattern
pat StmAux (ExpDec SOACS)
aux (Op (Screma w arrs form))) DistAcc lore
acc
| Just [Reduce Commutativity
comm Lambda
lam Result
nes] <- ScremaForm SOACS -> Maybe [Reduce SOACS]
forall lore. ScremaForm lore -> Maybe [Reduce lore]
isReduceSOAC ScremaForm SOACS
form,
Just BinderT SOACS (DistNestT lore m) ()
m <- Pattern
-> SubExp
-> Commutativity
-> Lambda
-> [(SubExp, VName)]
-> Maybe (BinderT SOACS (DistNestT lore m) ())
forall (m :: * -> *).
(MonadBinder m, Lore m ~ SOACS) =>
Pattern
-> SubExp
-> Commutativity
-> Lambda
-> [(SubExp, VName)]
-> Maybe (m ())
irwim Pattern
pat SubExp
w Commutativity
comm Lambda
lam ([(SubExp, VName)] -> Maybe (BinderT SOACS (DistNestT lore m) ()))
-> [(SubExp, VName)] -> Maybe (BinderT SOACS (DistNestT lore m) ())
forall a b. (a -> b) -> a -> b
$ Result -> [VName] -> [(SubExp, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip Result
nes [VName]
arrs = do
Scope SOACS
types <- (Scope lore -> Scope SOACS) -> DistNestT lore m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope lore -> Scope SOACS
forall lore. SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs
(()
_, Stms SOACS
bnds) <- BinderT SOACS (DistNestT lore m) ()
-> Scope SOACS -> DistNestT lore m ((), Stms SOACS)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT (StmAux ()
-> BinderT SOACS (DistNestT lore m) ()
-> BinderT SOACS (DistNestT lore m) ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux ()
StmAux (ExpDec SOACS)
aux BinderT SOACS (DistNestT lore m) ()
m) Scope SOACS
types
DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
distributeMapBodyStms DistAcc lore
acc Stms SOACS
bnds
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat (StmAux Certificates
cs Attrs
_ ExpDec SOACS
_) (Op (Scatter w lam ivs as))) DistAcc lore
acc =
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
| Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
Lambda lore
lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
lam
PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
Stms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms lore -> DistNestT lore m ()
postStm (Stms lore -> DistNestT lore m ())
-> DistNestT lore m (Stms lore) -> DistNestT lore m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< KernelNest
-> [Int]
-> PatternT Type
-> Certificates
-> SubExp
-> Lambda lore
-> [VName]
-> [(Shape, Int, VName)]
-> DistNestT lore m (Stms lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
KernelNest
-> [Int]
-> PatternT Type
-> Certificates
-> SubExp
-> Lambda lore
-> [VName]
-> [(Shape, Int, VName)]
-> DistNestT lore m (Stms lore)
segmentedScatterKernel KernelNest
nest' [Int]
perm PatternT Type
Pattern
pat Certificates
cs SubExp
w Lambda lore
lam' [VName]
ivs [(Shape, Int, VName)]
as
DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'
Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat (StmAux Certificates
cs Attrs
_ ExpDec SOACS
_) (Op (Hist w ops lam as))) DistAcc lore
acc =
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
| Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
Lambda lore
lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
lam
KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
Stms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms lore -> DistNestT lore m ()
postStm (Stms lore -> DistNestT lore m ())
-> DistNestT lore m (Stms lore) -> DistNestT lore m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< KernelNest
-> [Int]
-> Certificates
-> SubExp
-> [HistOp SOACS]
-> Lambda lore
-> [VName]
-> DistNestT lore m (Stms lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
KernelNest
-> [Int]
-> Certificates
-> SubExp
-> [HistOp SOACS]
-> Lambda lore
-> [VName]
-> DistNestT lore m (Stms lore)
segmentedHistKernel KernelNest
nest' [Int]
perm Certificates
cs SubExp
w [HistOp SOACS]
ops Lambda lore
lam' [VName]
as
DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'
Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc
maybeDistributeStm
stm :: Stm SOACS
stm@( Let
(Pattern [] [PatElemT (LetDec SOACS)
pe])
StmAux (ExpDec SOACS)
aux
(BasicOp (Index VName
arr Slice SubExp
slice))
)
DistAcc lore
acc
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Result -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (Result -> Bool) -> Result -> Bool
forall a b. (a -> b) -> a -> b
$ Slice SubExp -> Result
forall d. Slice d -> [d]
sliceDims Slice SubExp
slice,
VName -> SubExp
Var (PatElemT Type -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT Type
PatElemT (LetDec SOACS)
pe) SubExp -> Result -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` (PatternT Type, Result) -> Result
forall a b. (a, b) -> b
snd (Targets -> (PatternT Type, Result)
innerTarget (DistAcc lore -> Targets
forall lore. DistAcc lore -> Targets
distTargets DistAcc lore
acc)) =
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
stm DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Just (PostStms lore
kernels, Result
_res, KernelNest
nest, DistAcc lore
acc') ->
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
Stms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms lore -> DistNestT lore m ()
postStm (Stms lore -> DistNestT lore m ())
-> DistNestT lore m (Stms lore) -> DistNestT lore m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< KernelNest
-> Certificates
-> VName
-> Slice SubExp
-> DistNestT lore m (Stms lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
KernelNest
-> Certificates
-> VName
-> Slice SubExp
-> DistNestT lore m (Stms lore)
segmentedGatherKernel KernelNest
nest (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) VName
arr Slice SubExp
slice
DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'
Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
stm DistAcc lore
acc
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat (StmAux Certificates
cs Attrs
_ ExpDec SOACS
_) (Op (Screma w arrs form))) DistAcc lore
acc
| Just ([Scan SOACS]
scans, Lambda
map_lam) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda)
forall lore. ScremaForm lore -> Maybe ([Scan lore], Lambda lore)
isScanomapSOAC ScremaForm SOACS
form,
Scan Lambda
lam Result
nes <- [Scan SOACS] -> Scan SOACS
forall lore. Bindable lore => [Scan lore] -> Scan lore
singleScan [Scan SOACS]
scans =
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
| Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
Lambda lore
map_lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
map_lam
Lambda lore
lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
lam
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$
KernelNest
-> [Int]
-> SubExp
-> Lambda lore
-> Lambda lore
-> Result
-> [VName]
-> DistNestT lore m (Maybe (Stms lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
KernelNest
-> [Int]
-> SubExp
-> Lambda lore
-> Lambda lore
-> Result
-> [VName]
-> DistNestT lore m (Maybe (Stms lore))
segmentedScanomapKernel KernelNest
nest' [Int]
perm SubExp
w Lambda lore
lam' Lambda lore
map_lam' Result
nes [VName]
arrs
DistNestT lore m (Maybe (Stms lore))
-> (Maybe (Stms lore) -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Certificates
-> Stm SOACS
-> DistAcc lore
-> PostStms lore
-> DistAcc lore
-> Maybe (Stms lore)
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Certificates
-> Stm SOACS
-> DistAcc lore
-> PostStms lore
-> DistAcc lore
-> Maybe (Stms lore)
-> DistNestT lore m (DistAcc lore)
kernelOrNot Certificates
cs Stm SOACS
bnd DistAcc lore
acc PostStms lore
kernels DistAcc lore
acc'
Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat (StmAux Certificates
cs Attrs
_ ExpDec SOACS
_) (Op (Screma w arrs form))) DistAcc lore
acc
| Just ([Reduce SOACS]
reds, Lambda
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda)
forall lore. ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
isRedomapSOAC ScremaForm SOACS
form,
Reduce Commutativity
comm Lambda
lam Result
nes <- [Reduce SOACS] -> Reduce SOACS
forall lore. Bindable lore => [Reduce lore] -> Reduce lore
singleReduce [Reduce SOACS]
reds =
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
| Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
Lambda lore
lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
lam
Lambda lore
map_lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
map_lam
let comm' :: Commutativity
comm'
| Lambda -> Bool
forall lore. Lambda lore -> Bool
commutativeLambda Lambda
lam = Commutativity
Commutative
| Bool
otherwise = Commutativity
comm
KernelNest
-> [Int]
-> SubExp
-> Commutativity
-> Lambda lore
-> Lambda lore
-> Result
-> [VName]
-> DistNestT lore m (Maybe (Stms lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
KernelNest
-> [Int]
-> SubExp
-> Commutativity
-> Lambda lore
-> Lambda lore
-> Result
-> [VName]
-> DistNestT lore m (Maybe (Stms lore))
regularSegmentedRedomapKernel KernelNest
nest' [Int]
perm SubExp
w Commutativity
comm' Lambda lore
lam' Lambda lore
map_lam' Result
nes [VName]
arrs
DistNestT lore m (Maybe (Stms lore))
-> (Maybe (Stms lore) -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Certificates
-> Stm SOACS
-> DistAcc lore
-> PostStms lore
-> DistAcc lore
-> Maybe (Stms lore)
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Certificates
-> Stm SOACS
-> DistAcc lore
-> PostStms lore
-> DistAcc lore
-> Maybe (Stms lore)
-> DistNestT lore m (DistAcc lore)
kernelOrNot Certificates
cs Stm SOACS
bnd DistAcc lore
acc PostStms lore
kernels DistAcc lore
acc'
Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc
maybeDistributeStm (Let Pattern
pat (StmAux Certificates
cs Attrs
_ ExpDec SOACS
_) (Op (Screma w arrs form))) DistAcc lore
acc = do
Scope SOACS
scope <- (Scope lore -> Scope SOACS) -> DistNestT lore m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope lore -> Scope SOACS
forall lore. SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs
DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
distributeMapBodyStms DistAcc lore
acc (Stms SOACS -> DistNestT lore m (DistAcc lore))
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> DistNestT lore m (DistAcc lore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> Stms SOACS -> Stms SOACS
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> Stm SOACS -> Stm SOACS
forall lore. Certificates -> Stm lore -> Stm lore
certify Certificates
cs) (Stms SOACS -> Stms SOACS)
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> Stms SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((), Stms SOACS) -> Stms SOACS
forall a b. (a, b) -> b
snd
(((), Stms SOACS) -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m ((), Stms SOACS)
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT SOACS (DistNestT lore m) ()
-> Scope SOACS -> DistNestT lore m ((), Stms SOACS)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT (Pattern (Lore (BinderT SOACS (DistNestT lore m)))
-> SubExp
-> ScremaForm (Lore (BinderT SOACS (DistNestT lore m)))
-> [VName]
-> BinderT SOACS (DistNestT lore m) ()
forall (m :: * -> *).
(MonadBinder m, Op (Lore m) ~ SOAC (Lore m), Bindable (Lore m)) =>
Pattern (Lore m)
-> SubExp -> ScremaForm (Lore m) -> [VName] -> m ()
dissectScrema Pattern (Lore (BinderT SOACS (DistNestT lore m)))
Pattern
pat SubExp
w ScremaForm (Lore (BinderT SOACS (DistNestT lore m)))
ScremaForm SOACS
form [VName]
arrs) Scope SOACS
scope
maybeDistributeStm (Let Pattern
pat StmAux (ExpDec SOACS)
aux (BasicOp (Replicate (Shape (SubExp
d : Result
ds)) SubExp
v))) DistAcc lore
acc
| [Type
t] <- PatternT Type -> [Type]
forall dec. Typed dec => PatternT dec -> [Type]
patternTypes PatternT Type
Pattern
pat = do
VName
tmp <- [Char] -> DistNestT lore m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"tmp"
let rowt :: Type
rowt = Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType Type
t
newbnd :: Stm SOACS
newbnd = Pattern -> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let Pattern
pat StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ Op SOACS -> Exp SOACS
forall lore. Op lore -> ExpT lore
Op (Op SOACS -> Exp SOACS) -> Op SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall lore. SubExp -> [VName] -> ScremaForm lore -> SOAC lore
Screma SubExp
d [] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda -> ScremaForm SOACS
forall lore. Lambda lore -> ScremaForm lore
mapSOAC Lambda
lam
tmpbnd :: Stm SOACS
tmpbnd =
Pattern -> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let ([PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [VName -> Type -> PatElemT Type
forall dec. VName -> dec -> PatElemT dec
PatElem VName
tmp Type
rowt]) StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate (Result -> Shape
forall d. [d] -> ShapeBase d
Shape Result
ds) SubExp
v
lam :: Lambda
lam =
Lambda :: forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda
{ lambdaReturnType :: [Type]
lambdaReturnType = [Type
rowt],
lambdaParams :: [LParam SOACS]
lambdaParams = [],
lambdaBody :: Body SOACS
lambdaBody = Stms SOACS -> Result -> Body SOACS
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody (Stm SOACS -> Stms SOACS
forall lore. Stm lore -> Stms lore
oneStm Stm SOACS
tmpbnd) [VName -> SubExp
Var VName
tmp]
}
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
maybeDistributeStm Stm SOACS
newbnd DistAcc lore
acc
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pattern
_ StmAux (ExpDec SOACS)
aux (BasicOp (Copy VName
stm_arr))) DistAcc lore
acc =
DistAcc lore
-> Stm SOACS
-> VName
-> (KernelNest
-> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> VName
-> (KernelNest
-> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
distributeSingleUnaryStm DistAcc lore
acc Stm SOACS
stm VName
stm_arr ((KernelNest
-> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore))
-> (KernelNest
-> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ \KernelNest
_ PatternT Type
outerpat VName
arr ->
Stms lore -> DistNestT lore m (Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms lore -> DistNestT lore m (Stms lore))
-> Stms lore -> DistNestT lore m (Stms lore)
forall a b. (a -> b) -> a -> b
$ Stm lore -> Stms lore
forall lore. Stm lore -> Stms lore
oneStm (Stm lore -> Stms lore) -> Stm lore -> Stms lore
forall a b. (a -> b) -> a -> b
$ Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let PatternT Type
Pattern lore
outerpat StmAux (ExpDec lore)
StmAux (ExpDec SOACS)
aux (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
arr
maybeDistributeStm stm :: Stm SOACS
stm@(Let (Pattern [] [PatElemT (LetDec SOACS)
pe]) StmAux (ExpDec SOACS)
aux (BasicOp (Opaque (Var VName
stm_arr)))) DistAcc lore
acc
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ PatElemT Type -> Type
forall t. Typed t => t -> Type
typeOf PatElemT Type
PatElemT (LetDec SOACS)
pe =
DistAcc lore
-> Stm SOACS
-> VName
-> (KernelNest
-> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> VName
-> (KernelNest
-> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
distributeSingleUnaryStm DistAcc lore
acc Stm SOACS
stm VName
stm_arr ((KernelNest
-> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore))
-> (KernelNest
-> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ \KernelNest
_ PatternT Type
outerpat VName
arr ->
Stms lore -> DistNestT lore m (Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms lore -> DistNestT lore m (Stms lore))
-> Stms lore -> DistNestT lore m (Stms lore)
forall a b. (a -> b) -> a -> b
$ Stm lore -> Stms lore
forall lore. Stm lore -> Stms lore
oneStm (Stm lore -> Stms lore) -> Stm lore -> Stms lore
forall a b. (a -> b) -> a -> b
$ Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let PatternT Type
Pattern lore
outerpat StmAux (ExpDec lore)
StmAux (ExpDec SOACS)
aux (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
arr
-- A 'Rearrange' of an array that is perfectly mapped by the current nest
-- becomes one rearrangement of the whole array: the dimensions introduced by
-- the nest keep their positions, and the original permutation is shifted
-- past them.
maybeDistributeStm stm@(Let _ aux (BasicOp (Rearrange perm stm_arr))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr -> do
    let depth = 1 + length (snd nest)
        perm' = [0 .. depth - 1] ++ map (depth +) perm
    -- The rearrangement is applied to a fresh copy of arr rather than to arr
    -- itself (presumably so that the result does not alias arr).
    arr_copy <- newVName (baseString arr)
    arr_t <- lookupType arr
    pure $
      stmsFromList
        [ Let (Pattern [] [PatElem arr_copy arr_t]) aux $ BasicOp $ Copy arr,
          Let outerpat aux $ BasicOp $ Rearrange perm' arr_copy
        ]
-- A 'Reshape' of a perfectly mapped array becomes a reshape of the whole
-- array, with the widths of the map nest prepended as new outer dimensions.
maybeDistributeStm stm@(Let _ aux (BasicOp (Reshape shapechange stm_arr))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr ->
    let full_dims = kernelNestWidths nest ++ newDims shapechange
     in pure $ oneStm $ Let outerpat aux $ BasicOp $ Reshape (map DimNew full_dims) arr
-- A 'Rotate' of a perfectly mapped array becomes a rotation of the whole
-- array, using a rotation amount of zero for every dimension introduced by
-- the map nest.
maybeDistributeStm stm@(Let _ aux (BasicOp (Rotate rots stm_arr))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr ->
    let no_rotation = replicate (length (kernelNestWidths nest)) (intConst Int64 0)
     in pure $ oneStm $ Let outerpat aux $ BasicOp $ Rotate (no_rotation ++ rots) arr
-- An 'Update' that writes the array 'v' into a slice of 'arr' with at least
-- one non-fixed dimension is turned into a segmented update kernel.  This
-- happens only if the distributed result is exactly the names bound by the
-- pattern and 'permutationAndMissing' succeeds.  The kernel nest is first
-- expanded with the unused pattern elements.  Otherwise the statement is
-- added to the accumulator.  If the slice has no non-fixed dimensions, this
-- equation does not apply and later equations are tried.
maybeDistributeStm stm@(Let pat aux (BasicOp (Update arr slice (Var v)))) acc
  | not (null (sliceDims slice)) = do
    distributed <- distributeSingleStm acc stm
    case distributed of
      Just (kernels, res, nest, acc')
        | res == map Var (patternNames pat),
          Just (perm, pat_unused) <- permutationAndMissing pat res -> do
          addPostStms kernels
          localScope (typeEnvFromDistAcc acc') $ do
            nest' <- expandKernelNest pat_unused nest
            update_stms <-
              segmentedUpdateKernel nest' perm (stmAuxCerts aux) arr slice v
            postStm update_stms
            pure acc'
      _ -> addStmToAcc stm acc
-- An in-place update of a single element of a one-dimensional array is
-- rewritten into a 'Scatter' of width one.  The scatter writes value 'v' at
-- index 'i' of 'arr' and is handed back to 'maybeDistributeStm'.  The
-- rewrite is skipped if any statement already in the accumulator is a
-- 'DoLoop' or an 'Op' (the cases accepted by 'amortises').  In that case,
-- or if the pattern is not a single rank-1 array, later equations apply.
maybeDistributeStm (Let pat aux (BasicOp (Update arr [DimFix i] v))) acc
  | [t] <- patternTypes pat,
    arrayRank t == 1,
    not (any (amortises . stmExp) (distStms acc)) =
    let write_lam =
          Lambda
            { lambdaParams = [],
              lambdaReturnType = [Prim int64, stripArray 1 t],
              lambdaBody = mkBody mempty [i, v]
            }
        dest = (Shape [arraySize 0 t], 1, arr)
     in maybeDistributeStm
          (Let pat aux $ Op $ Scatter (intConst Int64 1) write_lam [] [dest])
          acc
  where
    amortises e = case e of
      DoLoop {} -> True
      Op {} -> True
      _ -> False
-- A 'Concat' is distributed by rebuilding it as a segmented operation over
-- the kernel nest (see 'segmentedConcat').  The concatenation dimension is
-- shifted past the dimensions introduced by the nest.  The outcome is passed
-- to 'kernelOrNot' together with the original statement and both
-- accumulators.  If distribution fails, the statement is added to the
-- accumulator.
maybeDistributeStm stm@(Let _ aux (BasicOp (Concat d x xs w))) acc = do
  distributed <- distributeSingleStm acc stm
  case distributed of
    Nothing -> addStmToAcc stm acc
    Just (kernels, _, nest, acc') ->
      localScope (typeEnvFromDistAcc acc') $ do
        maybe_concat <- segmentedConcat nest
        kernelOrNot (stmAuxCerts aux) stm acc kernels acc' maybe_concat
  where
    -- Emit the concatenation of the (renamed) inputs inside the segmented
    -- context.  The dimension index is offset by the depth of the nest.
    segmentedConcat knest =
      isSegmentedOp knest [0] mempty mempty [] (x : xs) $
        \pat _ _ _ (x' : xs') ->
          let shifted_d = d + 1 + length (snd knest)
           in addStm $ Let pat aux $ BasicOp $ Concat shifted_d x' xs' w
-- Any statement not matched by an earlier equation is not distributed by
-- itself; it is simply added to the accumulator.
maybeDistributeStm stm acc =
  addStmToAcc stm acc
-- | Distribute a statement @stm@ that applies a unary array operation to
-- @stm_arr@, where @stm_arr@ is reached through the current map nest.
--
-- The statement is first distributed with 'distributeSingleStm'.  The
-- special treatment applies only when all of the following hold:
--
--   * the distributed result is exactly the names bound by @stm@;
--   * the outermost nesting maps over exactly one array;
--   * the only name from 'boundInKernelNest' that @stm@ refers to is that
--     outermost nesting's parameter;
--   * every nesting level maps over the parameter of the level outside it,
--     ending at @stm_arr@ (see @perfectlyMapped@).
--
-- When these hold, the kernels produced by distribution are posted.  The
-- callback is then invoked with the nest, the outermost nesting's pattern
-- and the outermost mapped array, and the statements it returns are posted
-- as well.  In every other case @stm@ is added to the accumulator.
distributeSingleUnaryStm ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  DistAcc lore ->
  Stm SOACS ->
  VName ->
  (KernelNest -> PatternT Type -> VName -> DistNestT lore m (Stms lore)) ->
  DistNestT lore m (DistAcc lore)
distributeSingleUnaryStm acc stm stm_arr mk_stms = do
  distributed <- distributeSingleStm acc stm
  case distributed of
    Just (kernels, res, nest, acc')
      | res == map Var (patternNames (stmPattern stm)),
        (outer, _) <- nest,
        [(arr_p, arr)] <- loopNestingParamsAndArrs outer,
        boundInKernelNest nest `namesIntersection` freeIn stm
          == oneName (paramName arr_p),
        perfectlyMapped arr nest -> do
        addPostStms kernels
        localScope (typeEnvFromDistAcc acc') $ do
          new_stms <- mk_stms nest (loopNestingPattern outer) arr
          postStm new_stms
          pure acc'
    _ -> addStmToAcc stm acc
  where
    -- Walk inward from the outermost nesting.  Each level must map over
    -- exactly one array, equal to the array (or parameter) handed down from
    -- the level outside it.  The innermost parameter must be 'stm_arr'.
    perfectlyMapped :: VName -> KernelNest -> Bool
    perfectlyMapped outer_arr (level, inner_levels) =
      case loopNestingParamsAndArrs level of
        [(p, level_arr)]
          | outer_arr == level_arr ->
            case inner_levels of
              [] -> paramName p == stm_arr
              next : rest -> perfectlyMapped (paramName p) (next, rest)
        _ -> False
-- | Distribute the accumulated statements if 'distributeIfPossible'
-- succeeds.  Otherwise return the accumulator unchanged.
distribute ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  DistAcc lore ->
  DistNestT lore m (DistAcc lore)
distribute acc =
  fmap (fromMaybe acc) (distributeIfPossible acc)
-- | Adapt the environment's 'distSegLevel' for use inside 'DistNestT'.
--
-- The inner level-maker is run with 'runBinderT'' in the underlying monad
-- (via 'liftInner').  Any statements it produces are added to the
-- enclosing binder before the level is returned.
mkSegLevel ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  DistNestT lore m (MkSegLevel lore (DistNestT lore m))
mkSegLevel = do
  mk_inner <- asks distSegLevel
  pure $ \ws desc recommendation -> do
    (lvl, lvl_stms) <-
      lift . liftInner . runBinderT' $ mk_inner ws desc recommendation
    lvl <$ addStms lvl_stms
-- | Try to distribute the accumulated statements over the current nesting
-- using 'tryDistribute'.
--
-- On success, the resulting kernel statements are posted.  The function then
-- returns a fresh accumulator holding the new targets and no statements.
-- On failure it returns 'Nothing'.
distributeIfPossible ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  DistAcc lore ->
  DistNestT lore m (Maybe (DistAcc lore))
distributeIfPossible acc = do
  nest <- asks distNest
  mk_lvl <- mkSegLevel
  attempt <- tryDistribute mk_lvl nest (distTargets acc) (distStms acc)
  traverse postKernel attempt
  where
    -- Post the distributed kernel and start over with an empty accumulator.
    postKernel (targets, kernel) = do
      postStm kernel
      pure DistAcc {distTargets = targets, distStms = mempty}
-- | Distribute the accumulated statements, then try to distribute one more
-- statement over the resulting targets.
--
-- The accumulator is first distributed with 'tryDistribute'.  The given
-- statement is then distributed with 'tryDistributeStm'.  If both succeed,
-- the result contains:
--
--   * the statements from the first step, wrapped as 'PostStms' but not
--     posted here;
--   * the result and kernel nest from the second step;
--   * a fresh accumulator with the new targets and no statements.
--
-- If either step fails, the result is 'Nothing'.
distributeSingleStm ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  DistAcc lore ->
  Stm SOACS ->
  DistNestT
    lore
    m
    ( Maybe
        ( PostStms lore,
          Result,
          KernelNest,
          DistAcc lore
        )
    )
distributeSingleStm acc stm = do
  nest <- asks distNest
  mk_lvl <- mkSegLevel
  first_attempt <- tryDistribute mk_lvl nest (distTargets acc) (distStms acc)
  case first_attempt of
    Nothing -> pure Nothing
    Just (targets, distributed_stms) -> do
      second_attempt <- tryDistributeStm nest targets stm
      pure $ fmap (bundle distributed_stms) second_attempt
  where
    -- Package the outcome of both distribution steps.
    bundle distributed_stms (res, targets', new_kernel_nest) =
      ( PostStms distributed_stms,
        res,
        new_kernel_nest,
        DistAcc {distTargets = targets', distStms = mempty}
      )
-- | Flatten a scatter that is nested inside the map nest @nest@ into a
-- single map-like 'SegOp' (built with 'mapKernel') whose results are
-- written in place through 'WriteReturns'.
--
-- * @nest@ – the enclosing kernel nest.
-- * @perm@ – permutation applied (via 'rearrangeShape') to the value
--   elements of the outermost nesting pattern before binding the kernel.
-- * @scatter_pat@ – pattern of the scatter itself.
-- * @cs@ – certificates of the scatter; attached to the synthetic nesting
--   and, when non-empty, to every index computation.
-- * @scatter_w@ – width of the scatter.
-- * @lam@ – the scatter lambda. Its results are split at @sum indexes@:
--   the first part are indices, the remainder are values.
-- * @ivs@ – the scatter's input arrays, paired with the lambda parameters.
-- * @dests@ – one @(shape, n, array)@ triple per destination array.
--
-- Returns the (renamed) statements that bind the kernel.
segmentedScatterKernel ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  KernelNest ->
  [Int] ->
  PatternT Type ->
  Certificates ->
  SubExp ->
  Lambda lore ->
  [VName] ->
  [(Shape, Int, VName)] ->
  DistNestT lore m (Stms lore)
segmentedScatterKernel nest perm scatter_pat cs scatter_w lam ivs dests = do
  -- Treat the scatter itself as the innermost map nesting, so that
  -- 'flatKernel' yields an index space that also covers @scatter_w@.
  let nesting =
        MapNesting scatter_pat (StmAux cs mempty ()) scatter_w $
          zip (lambdaParams lam) ivs
      nest' =
        pushInnerKernelNesting
          (scatter_pat, bodyResult $ lambdaBody lam)
          nesting
          nest
  (ispace, kernel_inps) <- flatKernel nest'

  let (as_ws, as_ns, as) = unzip3 dests
      -- Number of index results contributed by each destination:
      -- n writes, each indexing every dimension of its shape.
      indexes = zipWith (*) as_ns $ map length as_ws

  -- Every destination array must be a kernel input; otherwise the
  -- nested scatter was ill-typed (see 'bad').
  as_inps <- mapM (findInput kernel_inps) as

  mk_lvl <- mkSegLevel

  let -- One result type per destination: the first value type of each
      -- chunk of @n@ value results.
      rts =
        concatMap (take 1) $
          chunks as_ns $
            drop (sum indexes) $
              lambdaReturnType lam
      (is, vs) = splitAt (sum indexes) $ bodyResult $ lambdaBody lam

  -- Inline the lambda body; when the scatter carries certificates, bind
  -- each index through a certified copy so the certificates are kept.
  (is', k_body_stms) <- runBinder $ do
    addStms $ bodyStms $ lambdaBody lam
    forM is $ \i ->
      if cs == mempty
        then return i
        else certifying cs $ letSubExp "scatter_i" $ BasicOp $ SubExp i

  let k_body =
        groupScatterResults (zip3 as_ws as_ns as_inps) (is' ++ vs)
          & map (inPlaceReturn ispace)
          & KernelBody () k_body_stms

  (k, k_bnds) <- mapKernel mk_lvl ispace kernel_inps rts k_body

  traverse renameStm <=< runBinder_ $ do
    addStms k_bnds

    let pat =
          Pattern [] $
            rearrangeShape perm $
              patternValueElements $
                loopNestingPattern $
                  fst nest

    letBind pat $ Op $ segOp k
  where
    -- Locate the kernel input whose name matches a destination array.
    findInput kernel_inps a =
      maybe bad return $ find ((== a) . kernelInputName) kernel_inps

    bad = error "Ill-typed nested scatter encountered."

    -- Build the in-place write for one destination. The written shape is
    -- every index-space dimension except the innermost (the scatter
    -- width), followed by the destination's own dimensions.
    -- NOTE(review): 'init' is partial; @ispace@ is assumed non-empty
    -- because it always includes the scatter nesting pushed above.
    inPlaceReturn ispace (aw, inp, is_vs) =
      WriteReturns
        (Shape (init ws ++ shapeDims aw))
        (kernelInputArray inp)
        [(map DimFix $ map Var (init gtids) ++ is, v) | (is, v) <- is_vs]
      where
        (gtids, ws) = unzip ispace
-- | Flatten an in-place update of @arr@ at @slice@ with the value @v@,
-- nested inside @nest@, into a map-like 'SegOp' (built with 'mapKernel').
--
-- The kernel's index space is the flattened space of @nest@ extended with
-- one fresh thread dimension per dimension of @slice@ ('sliceDims').
-- Each thread does the following:
--
-- * reads one element of @v@, indexed by its slice thread indices;
-- * writes that element to the kernel input corresponding to @arr@, at
--   the outer thread indices followed by the slice-translated indices
--   ('fixSlice').
--
-- @perm@ is applied (via 'rearrangeShape') to the value elements of the
-- outermost nesting pattern before the kernel is bound. @cs@ certifies
-- the read of @v@.
segmentedUpdateKernel ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  KernelNest ->
  [Int] ->
  Certificates ->
  VName ->
  Slice SubExp ->
  VName ->
  DistNestT lore m (Stms lore)
segmentedUpdateKernel nest perm cs arr slice v = do
  (base_ispace, kernel_inps) <- flatKernel nest
  let slice_dims = sliceDims slice
  slice_gtids <- replicateM (length slice_dims) (newVName "gtid_slice")

  let ispace = base_ispace ++ zip slice_gtids slice_dims

  ((res_t, res), kstms) <- runBinder $ do
    -- The element of v this thread is responsible for.
    v' <-
      certifying cs $
        letSubExp "v" $
          BasicOp $
            Index v $
              map (DimFix . Var) slice_gtids
    -- Translate the slice-local thread indices into indices of the
    -- full array.
    slice_is <-
      traverse (toSubExp "index") $
        fixSlice (map (fmap pe64) slice) $
          map (pe64 . Var) slice_gtids

    let write_is = map (Var . fst) base_ispace ++ slice_is
        -- The updated array must be a kernel input of the nest.
        arr' =
          maybe (error "incorrectly typed Update") kernelInputArray $
            find ((== arr) . kernelInputName) kernel_inps
    arr_t <- lookupType arr'
    v_t <- subExpType v'
    return
      ( v_t,
        WriteReturns (arrayShape arr_t) arr' [(map DimFix write_is, v')]
      )

  mk_lvl <- mkSegLevel
  (k, prestms) <-
    mapKernel mk_lvl ispace kernel_inps [res_t] $
      KernelBody () kstms [res]

  traverse renameStm <=< runBinder_ $ do
    addStms prestms

    let pat =
          Pattern [] $
            rearrangeShape perm $
              patternValueElements $
                loopNestingPattern $
                  fst nest

    letBind pat $ Op $ segOp k
-- | Flatten an index of @arr@ by @slice@, nested inside @nest@, into a
-- map-like 'SegOp' (built with 'mapKernel').
--
-- The kernel's index space is the flattened space of @nest@ extended with
-- one fresh thread dimension per dimension of @slice@ ('sliceDims').
-- Each thread reads one element of @arr@, through @slice@ composed
-- ('sliceSlice') with its own slice thread indices, and returns it with
-- 'Returns' 'ResultMaySimplify'.
--
-- @cs@ certifies the read. The kernel is bound to the value elements of
-- the outermost nesting pattern; no permutation is applied.
segmentedGatherKernel ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  KernelNest ->
  Certificates ->
  VName ->
  Slice SubExp ->
  DistNestT lore m (Stms lore)
segmentedGatherKernel nest cs arr slice = do
  let slice_dims = sliceDims slice
  slice_gtids <- replicateM (length slice_dims) (newVName "gtid_slice")

  (base_ispace, kernel_inps) <- flatKernel nest
  let ispace = base_ispace ++ zip slice_gtids slice_dims

  ((res_t, res), kstms) <- runBinder $ do
    -- Compose the original slice with this thread's position inside it
    -- to get a fully fixed index into arr.
    slice'' <-
      subExpSlice $
        sliceSlice (primExpSlice slice) $
          primExpSlice $
            map (DimFix . Var) slice_gtids
    v' <- certifying cs $ letSubExp "v" $ BasicOp $ Index arr slice''
    v_t <- subExpType v'
    return (v_t, Returns ResultMaySimplify v')

  mk_lvl <- mkSegLevel
  (k, prestms) <-
    mapKernel mk_lvl ispace kernel_inps [res_t] $
      KernelBody () kstms [res]

  traverse renameStm <=< runBinder_ $ do
    addStms prestms

    let pat =
          Pattern [] $
            patternValueElements $
              loopNestingPattern $
                fst nest

    letBind pat $ Op $ segOp k
segmentedHistKernel ::
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
KernelNest ->
[Int] ->
Certificates ->
SubExp ->
[SOACS.HistOp SOACS] ->
Lambda lore ->
[VName] ->
DistNestT lore m (Stms lore)
segmentedHistKernel :: KernelNest
-> [Int]
-> Certificates
-> SubExp
-> [HistOp SOACS]
-> Lambda lore
-> [VName]
-> DistNestT lore m (Stms lore)
segmentedHistKernel KernelNest
nest [Int]
perm Certificates
cs SubExp
hist_w [HistOp SOACS]
ops Lambda lore
lam [VName]
arrs = do
([(VName, SubExp)]
ispace, [KernelInput]
inputs) <- KernelNest -> DistNestT lore m ([(VName, SubExp)], [KernelInput])
forall (m :: * -> *).
MonadFreshNames m =>
KernelNest -> m ([(VName, SubExp)], [KernelInput])
flatKernel KernelNest
nest
let orig_pat :: PatternT Type
orig_pat =
[PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] ([PatElemT Type] -> PatternT Type)
-> [PatElemT Type] -> PatternT Type
forall a b. (a -> b) -> a -> b
$
[Int] -> [PatElemT Type] -> [PatElemT Type]
forall a. [Int] -> [a] -> [a]
rearrangeShape [Int]
perm ([PatElemT Type] -> [PatElemT Type])
-> [PatElemT Type] -> [PatElemT Type]
forall a b. (a -> b) -> a -> b
$
PatternT Type -> [PatElemT Type]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements (PatternT Type -> [PatElemT Type])
-> PatternT Type -> [PatElemT Type]
forall a b. (a -> b) -> a -> b
$ LoopNesting -> PatternT Type
loopNestingPattern (LoopNesting -> PatternT Type) -> LoopNesting -> PatternT Type
forall a b. (a -> b) -> a -> b
$ KernelNest -> LoopNesting
forall a b. (a, b) -> a
fst KernelNest
nest
[HistOp SOACS]
ops' <- [HistOp SOACS]
-> (HistOp SOACS -> DistNestT lore m (HistOp SOACS))
-> DistNestT lore m [HistOp SOACS]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp SOACS]
ops ((HistOp SOACS -> DistNestT lore m (HistOp SOACS))
-> DistNestT lore m [HistOp SOACS])
-> (HistOp SOACS -> DistNestT lore m (HistOp SOACS))
-> DistNestT lore m [HistOp SOACS]
forall a b. (a -> b) -> a -> b
$ \(SOACS.HistOp SubExp
num_bins SubExp
rf [VName]
dests Result
nes Lambda
op) ->
SubExp -> SubExp -> [VName] -> Result -> Lambda -> HistOp SOACS
forall lore.
SubExp -> SubExp -> [VName] -> Result -> Lambda lore -> HistOp lore
SOACS.HistOp SubExp
num_bins SubExp
rf
([VName] -> Result -> Lambda -> HistOp SOACS)
-> DistNestT lore m [VName]
-> DistNestT lore m (Result -> Lambda -> HistOp SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (VName -> DistNestT lore m VName)
-> [VName] -> DistNestT lore m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((KernelInput -> VName)
-> DistNestT lore m KernelInput -> DistNestT lore m VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap KernelInput -> VName
kernelInputArray (DistNestT lore m KernelInput -> DistNestT lore m VName)
-> (VName -> DistNestT lore m KernelInput)
-> VName
-> DistNestT lore m VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [KernelInput] -> VName -> DistNestT lore m KernelInput
forall (m :: * -> *) (t :: * -> *).
(Monad m, Foldable t) =>
t KernelInput -> VName -> m KernelInput
findInput [KernelInput]
inputs) [VName]
dests
DistNestT lore m (Result -> Lambda -> HistOp SOACS)
-> DistNestT lore m Result
-> DistNestT lore m (Lambda -> HistOp SOACS)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistNestT lore m Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure Result
nes
DistNestT lore m (Lambda -> HistOp SOACS)
-> DistNestT lore m Lambda -> DistNestT lore m (HistOp SOACS)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Lambda -> DistNestT lore m Lambda
forall (f :: * -> *) a. Applicative f => a -> f a
pure Lambda
op
Result
-> [Char]
-> ThreadRecommendation
-> BinderT lore m (SegOpLevel lore)
mk_lvl <- (DistEnv lore m
-> Result
-> [Char]
-> ThreadRecommendation
-> BinderT lore m (SegOpLevel lore))
-> DistNestT
lore
m
(Result
-> [Char]
-> ThreadRecommendation
-> BinderT lore m (SegOpLevel lore))
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks DistEnv lore m
-> Result
-> [Char]
-> ThreadRecommendation
-> BinderT lore m (SegOpLevel lore)
forall lore (m :: * -> *). DistEnv lore m -> MkSegLevel lore m
distSegLevel
Lambda -> Binder lore (Lambda lore)
onLambda <- (DistEnv lore m -> Lambda -> Binder lore (Lambda lore))
-> DistNestT lore m (Lambda -> Binder lore (Lambda lore))
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks DistEnv lore m -> Lambda -> Binder lore (Lambda lore)
forall lore (m :: * -> *).
DistEnv lore m -> Lambda -> Binder lore (Lambda lore)
distOnSOACSLambda
let onLambda' :: Lambda -> BinderT lore m (Lambda lore)
onLambda' = ((Lambda lore, Stms lore) -> Lambda lore)
-> BinderT lore m (Lambda lore, Stms lore)
-> BinderT lore m (Lambda lore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Lambda lore, Stms lore) -> Lambda lore
forall a b. (a, b) -> a
fst (BinderT lore m (Lambda lore, Stms lore)
-> BinderT lore m (Lambda lore))
-> (Lambda -> BinderT lore m (Lambda lore, Stms lore))
-> Lambda
-> BinderT lore m (Lambda lore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Binder lore (Lambda lore)
-> BinderT lore m (Lambda lore, Stms lore)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore a -> m (a, Stms lore)
runBinder (Binder lore (Lambda lore)
-> BinderT lore m (Lambda lore, Stms lore))
-> (Lambda -> Binder lore (Lambda lore))
-> Lambda
-> BinderT lore m (Lambda lore, Stms lore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda -> Binder lore (Lambda lore)
onLambda
m (Stms lore) -> DistNestT lore m (Stms lore)
forall lore (m :: * -> *) a.
(LocalScope lore m, DistLore lore) =>
m a -> DistNestT lore m a
liftInner (m (Stms lore) -> DistNestT lore m (Stms lore))
-> m (Stms lore) -> DistNestT lore m (Stms lore)
forall a b. (a -> b) -> a -> b
$
BinderT lore m () -> m (Stms lore)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
BinderT lore m a -> m (Stms lore)
runBinderT'_ (BinderT lore m () -> m (Stms lore))
-> BinderT lore m () -> m (Stms lore)
forall a b. (a -> b) -> a -> b
$ do
SegOpLevel lore
lvl <- Result
-> [Char]
-> ThreadRecommendation
-> BinderT lore m (SegOpLevel lore)
mk_lvl (SubExp
hist_w SubExp -> Result -> Result
forall a. a -> [a] -> [a]
: ((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(VName, SubExp)]
ispace) [Char]
"seghist" (ThreadRecommendation -> BinderT lore m (SegOpLevel lore))
-> ThreadRecommendation -> BinderT lore m (SegOpLevel lore)
forall a b. (a -> b) -> a -> b
$ SegVirt -> ThreadRecommendation
NoRecommendation SegVirt
SegNoVirt
Stms lore -> BinderT lore m ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms
(Stms lore -> BinderT lore m ())
-> BinderT lore m (Stms lore) -> BinderT lore m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Lambda -> BinderT lore m (Lambda (Lore (BinderT lore m))))
-> SegOpLevel (Lore (BinderT lore m))
-> PatternT Type
-> [(VName, SubExp)]
-> [KernelInput]
-> Certificates
-> SubExp
-> [HistOp SOACS]
-> Lambda (Lore (BinderT lore m))
-> [VName]
-> BinderT lore m (Stms (Lore (BinderT lore m)))
forall (m :: * -> *).
(MonadBinder m, DistLore (Lore m)) =>
(Lambda -> m (Lambda (Lore m)))
-> SegOpLevel (Lore m)
-> PatternT Type
-> [(VName, SubExp)]
-> [KernelInput]
-> Certificates
-> SubExp
-> [HistOp SOACS]
-> Lambda (Lore m)
-> [VName]
-> m (Stms (Lore m))
histKernel Lambda -> BinderT lore m (Lambda lore)
Lambda -> BinderT lore m (Lambda (Lore (BinderT lore m)))
onLambda' SegOpLevel lore
SegOpLevel (Lore (BinderT lore m))
lvl PatternT Type
orig_pat [(VName, SubExp)]
ispace [KernelInput]
inputs Certificates
cs SubExp
hist_w [HistOp SOACS]
ops' Lambda lore
Lambda (Lore (BinderT lore m))
lam [VName]
arrs
where
findInput :: t KernelInput -> VName -> m KernelInput
findInput t KernelInput
kernel_inps VName
a =
m KernelInput
-> (KernelInput -> m KernelInput)
-> Maybe KernelInput
-> m KernelInput
forall b a. b -> (a -> b) -> Maybe a -> b
maybe m KernelInput
forall a. a
bad KernelInput -> m KernelInput
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe KernelInput -> m KernelInput)
-> Maybe KernelInput -> m KernelInput
forall a b. (a -> b) -> a -> b
$ (KernelInput -> Bool) -> t KernelInput -> Maybe KernelInput
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
a) (VName -> Bool) -> (KernelInput -> VName) -> KernelInput -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelInput -> VName
kernelInputName) t KernelInput
kernel_inps
bad :: a
bad = [Char] -> a
forall a. HasCallStack => [Char] -> a
error [Char]
"Ill-typed nested Hist encountered."
histKernel ::
(MonadBinder m, DistLore (Lore m)) =>
(Lambda SOACS -> m (Lambda (Lore m))) ->
SegOpLevel (Lore m) ->
PatternT Type ->
[(VName, SubExp)] ->
[KernelInput] ->
Certificates ->
SubExp ->
[SOACS.HistOp SOACS] ->
Lambda (Lore m) ->
[VName] ->
m (Stms (Lore m))
histKernel :: (Lambda -> m (Lambda (Lore m)))
-> SegOpLevel (Lore m)
-> PatternT Type
-> [(VName, SubExp)]
-> [KernelInput]
-> Certificates
-> SubExp
-> [HistOp SOACS]
-> Lambda (Lore m)
-> [VName]
-> m (Stms (Lore m))
histKernel Lambda -> m (Lambda (Lore m))
onLambda SegOpLevel (Lore m)
lvl PatternT Type
orig_pat [(VName, SubExp)]
ispace [KernelInput]
inputs Certificates
cs SubExp
hist_w [HistOp SOACS]
ops Lambda (Lore m)
lam [VName]
arrs = BinderT (Lore m) m () -> m (Stms (Lore m))
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
BinderT lore m a -> m (Stms lore)
runBinderT'_ (BinderT (Lore m) m () -> m (Stms (Lore m)))
-> BinderT (Lore m) m () -> m (Stms (Lore m))
forall a b. (a -> b) -> a -> b
$ do
[HistOp (Lore m)]
ops' <- [HistOp SOACS]
-> (HistOp SOACS -> BinderT (Lore m) m (HistOp (Lore m)))
-> BinderT (Lore m) m [HistOp (Lore m)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp SOACS]
ops ((HistOp SOACS -> BinderT (Lore m) m (HistOp (Lore m)))
-> BinderT (Lore m) m [HistOp (Lore m)])
-> (HistOp SOACS -> BinderT (Lore m) m (HistOp (Lore m)))
-> BinderT (Lore m) m [HistOp (Lore m)]
forall a b. (a -> b) -> a -> b
$ \(SOACS.HistOp SubExp
num_bins SubExp
rf [VName]
dests Result
nes Lambda
op) -> do
(Lambda
op', Result
nes', Shape
shape) <- Lambda -> Result -> BinderT (Lore m) m (Lambda, Result, Shape)
forall (m :: * -> *).
MonadBinder m =>
Lambda -> Result -> m (Lambda, Result, Shape)
determineReduceOp Lambda
op Result
nes
Lambda (Lore m)
op'' <- m (Lambda (Lore m)) -> BinderT (Lore m) m (Lambda (Lore m))
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m (Lambda (Lore m)) -> BinderT (Lore m) m (Lambda (Lore m)))
-> m (Lambda (Lore m)) -> BinderT (Lore m) m (Lambda (Lore m))
forall a b. (a -> b) -> a -> b
$ Lambda -> m (Lambda (Lore m))
onLambda Lambda
op'
HistOp (Lore m) -> BinderT (Lore m) m (HistOp (Lore m))
forall (m :: * -> *) a. Monad m => a -> m a
return (HistOp (Lore m) -> BinderT (Lore m) m (HistOp (Lore m)))
-> HistOp (Lore m) -> BinderT (Lore m) m (HistOp (Lore m))
forall a b. (a -> b) -> a -> b
$ SubExp
-> SubExp
-> [VName]
-> Result
-> Shape
-> Lambda (Lore m)
-> HistOp (Lore m)
forall lore.
SubExp
-> SubExp
-> [VName]
-> Result
-> Shape
-> Lambda lore
-> HistOp lore
HistOp SubExp
num_bins SubExp
rf [VName]
dests Result
nes' Shape
shape Lambda (Lore m)
op''
let isDest :: VName -> Bool
isDest = (VName -> [VName] -> Bool) -> [VName] -> VName -> Bool
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
elem ([VName] -> VName -> Bool) -> [VName] -> VName -> Bool
forall a b. (a -> b) -> a -> b
$ (HistOp (Lore m) -> [VName]) -> [HistOp (Lore m)] -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap HistOp (Lore m) -> [VName]
forall lore. HistOp lore -> [VName]
histDest [HistOp (Lore m)]
ops'
inputs' :: [KernelInput]
inputs' = (KernelInput -> Bool) -> [KernelInput] -> [KernelInput]
forall a. (a -> Bool) -> [a] -> [a]
filter (Bool -> Bool
not (Bool -> Bool) -> (KernelInput -> Bool) -> KernelInput -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Bool
isDest (VName -> Bool) -> (KernelInput -> VName) -> KernelInput -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelInput -> VName
kernelInputArray) [KernelInput]
inputs
Certificates -> BinderT (Lore m) m () -> BinderT (Lore m) m ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs (BinderT (Lore m) m () -> BinderT (Lore m) m ())
-> BinderT (Lore m) m () -> BinderT (Lore m) m ()
forall a b. (a -> b) -> a -> b
$
Stms (Lore m) -> BinderT (Lore m) m ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms (Lore m) -> BinderT (Lore m) m ())
-> BinderT (Lore m) m (Stms (Lore m)) -> BinderT (Lore m) m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stm (Lore m) -> BinderT (Lore m) m (Stm (Lore m)))
-> Stms (Lore m) -> BinderT (Lore m) m (Stms (Lore m))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse Stm (Lore m) -> BinderT (Lore m) m (Stm (Lore m))
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Stm lore -> m (Stm lore)
renameStm
(Stms (Lore m) -> BinderT (Lore m) m (Stms (Lore m)))
-> BinderT (Lore m) m (Stms (Lore m))
-> BinderT (Lore m) m (Stms (Lore m))
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel (Lore m)
-> Pattern (Lore m)
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp (Lore m)]
-> Lambda (Lore m)
-> [VName]
-> BinderT (Lore m) m (Stms (Lore m))
forall lore (m :: * -> *).
(DistLore lore, MonadFreshNames m, HasScope lore m) =>
SegOpLevel lore
-> Pattern lore
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp lore]
-> Lambda lore
-> [VName]
-> m (Stms lore)
segHist SegOpLevel (Lore m)
lvl PatternT Type
Pattern (Lore m)
orig_pat SubExp
hist_w [(VName, SubExp)]
ispace [KernelInput]
inputs' [HistOp (Lore m)]
ops' Lambda (Lore m)
lam [VName]
arrs
determineReduceOp ::
MonadBinder m =>
Lambda SOACS ->
[SubExp] ->
m (Lambda SOACS, [SubExp], Shape)
determineReduceOp :: Lambda -> Result -> m (Lambda, Result, Shape)
determineReduceOp Lambda
lam Result
nes =
case (SubExp -> Maybe VName) -> Result -> Maybe [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> Maybe VName
subExpVar Result
nes of
Just [VName]
ne_vs' -> do
let (Shape
shape, Lambda
lam') = Lambda -> (Shape, Lambda)
isVectorMap Lambda
lam
Result
nes' <- [VName] -> (VName -> m SubExp) -> m Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
ne_vs' ((VName -> m SubExp) -> m Result)
-> (VName -> m SubExp) -> m Result
forall a b. (a -> b) -> a -> b
$ \VName
ne_v -> do
Type
ne_v_t <- VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
ne_v
[Char] -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m SubExp
letSubExp [Char]
"hist_ne" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> BasicOp
Index VName
ne_v (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
Type -> Slice SubExp -> Slice SubExp
fullSlice Type
ne_v_t (Slice SubExp -> Slice SubExp) -> Slice SubExp -> Slice SubExp
forall a b. (a -> b) -> a -> b
$
Int -> DimIndex SubExp -> Slice SubExp
forall a. Int -> a -> [a]
replicate (Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank Shape
shape) (DimIndex SubExp -> Slice SubExp)
-> DimIndex SubExp -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0
(Lambda, Result, Shape) -> m (Lambda, Result, Shape)
forall (m :: * -> *) a. Monad m => a -> m a
return (Lambda
lam', Result
nes', Shape
shape)
Maybe [VName]
Nothing ->
(Lambda, Result, Shape) -> m (Lambda, Result, Shape)
forall (m :: * -> *) a. Monad m => a -> m a
return (Lambda
lam, Result
nes, Shape
forall a. Monoid a => a
mempty)
isVectorMap :: Lambda SOACS -> (Shape, Lambda SOACS)
isVectorMap :: Lambda -> (Shape, Lambda)
isVectorMap Lambda
lam
| [Let (Pattern [] [PatElemT (LetDec SOACS)]
pes) StmAux (ExpDec SOACS)
_ (Op (Screma w arrs form))] <-
Stms SOACS -> [Stm SOACS]
forall lore. Stms lore -> [Stm lore]
stmsToList (Stms SOACS -> [Stm SOACS]) -> Stms SOACS -> [Stm SOACS]
forall a b. (a -> b) -> a -> b
$ Body SOACS -> Stms SOACS
forall lore. BodyT lore -> Stms lore
bodyStms (Body SOACS -> Stms SOACS) -> Body SOACS -> Stms SOACS
forall a b. (a -> b) -> a -> b
$ Lambda -> Body SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda
lam,
Body SOACS -> Result
forall lore. BodyT lore -> Result
bodyResult (Lambda -> Body SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda
lam) Result -> Result -> Bool
forall a. Eq a => a -> a -> Bool
== (PatElemT Type -> SubExp) -> [PatElemT Type] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (PatElemT Type -> VName) -> PatElemT Type -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElemT Type -> VName
forall dec. PatElemT dec -> VName
patElemName) [PatElemT Type]
[PatElemT (LetDec SOACS)]
pes,
Just Lambda
map_lam <- ScremaForm SOACS -> Maybe Lambda
forall lore. ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC ScremaForm SOACS
form,
[VName]
arrs [VName] -> [VName] -> Bool
forall a. Eq a => a -> a -> Bool
== (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName (Lambda -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda
lam) =
let (Shape
shape, Lambda
lam') = Lambda -> (Shape, Lambda)
isVectorMap Lambda
map_lam
in (Result -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w] Shape -> Shape -> Shape
forall a. Semigroup a => a -> a -> a
<> Shape
shape, Lambda
lam')
| Bool
otherwise = (Shape
forall a. Monoid a => a
mempty, Lambda
lam)
segmentedScanomapKernel ::
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
KernelNest ->
[Int] ->
SubExp ->
Lambda lore ->
Lambda lore ->
[SubExp] ->
[VName] ->
DistNestT lore m (Maybe (Stms lore))
segmentedScanomapKernel :: KernelNest
-> [Int]
-> SubExp
-> Lambda lore
-> Lambda lore
-> Result
-> [VName]
-> DistNestT lore m (Maybe (Stms lore))
segmentedScanomapKernel KernelNest
nest [Int]
perm SubExp
segment_size Lambda lore
lam Lambda lore
map_lam Result
nes [VName]
arrs = do
Result
-> [Char]
-> ThreadRecommendation
-> BinderT lore m (SegOpLevel lore)
mk_lvl <- (DistEnv lore m
-> Result
-> [Char]
-> ThreadRecommendation
-> BinderT lore m (SegOpLevel lore))
-> DistNestT
lore
m
(Result
-> [Char]
-> ThreadRecommendation
-> BinderT lore m (SegOpLevel lore))
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks DistEnv lore m
-> Result
-> [Char]
-> ThreadRecommendation
-> BinderT lore m (SegOpLevel lore)
forall lore (m :: * -> *). DistEnv lore m -> MkSegLevel lore m
distSegLevel
KernelNest
-> [Int]
-> Names
-> Names
-> Result
-> [VName]
-> (PatternT Type
-> [(VName, SubExp)]
-> [KernelInput]
-> Result
-> [VName]
-> BinderT lore m ())
-> DistNestT lore m (Maybe (Stms lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
KernelNest
-> [Int]
-> Names
-> Names
-> Result
-> [VName]
-> (PatternT Type
-> [(VName, SubExp)]
-> [KernelInput]
-> Result
-> [VName]
-> BinderT lore m ())
-> DistNestT lore m (Maybe (Stms lore))
isSegmentedOp KernelNest
nest [Int]
perm (Lambda lore -> Names
forall a. FreeIn a => a -> Names
freeIn Lambda lore
lam) (Lambda lore -> Names
forall a. FreeIn a => a -> Names
freeIn Lambda lore
map_lam) Result
nes [] ((PatternT Type
-> [(VName, SubExp)]
-> [KernelInput]
-> Result
-> [VName]
-> BinderT lore m ())
-> DistNestT lore m (Maybe (Stms lore)))
-> (PatternT Type
-> [(VName, SubExp)]
-> [KernelInput]
-> Result
-> [VName]
-> BinderT lore m ())
-> DistNestT lore m (Maybe (Stms lore))
forall a b. (a -> b) -> a -> b
$
\PatternT Type
pat [(VName, SubExp)]
ispace [KernelInput]
inps Result
nes' [VName]
_ -> do
let scan_op :: SegBinOp lore
scan_op = Commutativity -> Lambda lore -> Result -> Shape -> SegBinOp lore
forall lore.
Commutativity -> Lambda lore -> Result -> Shape -> SegBinOp lore
SegBinOp Commutativity
Noncommutative Lambda lore
lam Result
nes' Shape
forall a. Monoid a => a
mempty
SegOpLevel lore
lvl <- Result
-> [Char]
-> ThreadRecommendation
-> BinderT lore m (SegOpLevel lore)
mk_lvl (SubExp
segment_size SubExp -> Result -> Result
forall a. a -> [a] -> [a]
: ((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(VName, SubExp)]
ispace) [Char]
"segscan" (ThreadRecommendation -> BinderT lore m (SegOpLevel lore))
-> ThreadRecommendation -> BinderT lore m (SegOpLevel lore)
forall a b. (a -> b) -> a -> b
$ SegVirt -> ThreadRecommendation
NoRecommendation SegVirt
SegNoVirt
Stms lore -> BinderT lore m ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms lore -> BinderT lore m ())
-> BinderT lore m (Stms lore) -> BinderT lore m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stm lore -> BinderT lore m (Stm lore))
-> Stms lore -> BinderT lore m (Stms lore)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse Stm lore -> BinderT lore m (Stm lore)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Stm lore -> m (Stm lore)
renameStm
(Stms lore -> BinderT lore m (Stms lore))
-> BinderT lore m (Stms lore) -> BinderT lore m (Stms lore)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel lore
-> Pattern lore
-> SubExp
-> [SegBinOp lore]
-> Lambda lore
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BinderT lore m (Stms lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore, HasScope lore m) =>
SegOpLevel lore
-> Pattern lore
-> SubExp
-> [SegBinOp lore]
-> Lambda lore
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms lore)
segScan SegOpLevel lore
lvl PatternT Type
Pattern lore
pat SubExp
segment_size [SegBinOp lore
scan_op] Lambda lore
map_lam [VName]
arrs [(VName, SubExp)]
ispace [KernelInput]
inps
regularSegmentedRedomapKernel ::
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
KernelNest ->
[Int] ->
SubExp ->
Commutativity ->
Lambda lore ->
Lambda lore ->
[SubExp] ->
[VName] ->
DistNestT lore m (Maybe (Stms lore))
regularSegmentedRedomapKernel :: KernelNest
-> [Int]
-> SubExp
-> Commutativity
-> Lambda lore
-> Lambda lore
-> Result
-> [VName]
-> DistNestT lore m (Maybe (Stms lore))
regularSegmentedRedomapKernel KernelNest
nest [Int]
perm SubExp
segment_size Commutativity
comm Lambda lore
lam Lambda lore
map_lam Result
nes [VName]
arrs = do
Result
-> [Char]
-> ThreadRecommendation
-> BinderT lore m (SegOpLevel lore)
mk_lvl <- (DistEnv lore m
-> Result
-> [Char]
-> ThreadRecommendation
-> BinderT lore m (SegOpLevel lore))
-> DistNestT
lore
m
(Result
-> [Char]
-> ThreadRecommendation
-> BinderT lore m (SegOpLevel lore))
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks DistEnv lore m
-> Result
-> [Char]
-> ThreadRecommendation
-> BinderT lore m (SegOpLevel lore)
forall lore (m :: * -> *). DistEnv lore m -> MkSegLevel lore m
distSegLevel
KernelNest
-> [Int]
-> Names
-> Names
-> Result
-> [VName]
-> (PatternT Type
-> [(VName, SubExp)]
-> [KernelInput]
-> Result
-> [VName]
-> BinderT lore m ())
-> DistNestT lore m (Maybe (Stms lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
KernelNest
-> [Int]
-> Names
-> Names
-> Result
-> [VName]
-> (PatternT Type
-> [(VName, SubExp)]
-> [KernelInput]
-> Result
-> [VName]
-> BinderT lore m ())
-> DistNestT lore m (Maybe (Stms lore))
isSegmentedOp KernelNest
nest [Int]
perm (Lambda lore -> Names
forall a. FreeIn a => a -> Names
freeIn Lambda lore
lam) (Lambda lore -> Names
forall a. FreeIn a => a -> Names
freeIn Lambda lore
map_lam) Result
nes [] ((PatternT Type
-> [(VName, SubExp)]
-> [KernelInput]
-> Result
-> [VName]
-> BinderT lore m ())
-> DistNestT lore m (Maybe (Stms lore)))
-> (PatternT Type
-> [(VName, SubExp)]
-> [KernelInput]
-> Result
-> [VName]
-> BinderT lore m ())
-> DistNestT lore m (Maybe (Stms lore))
forall a b. (a -> b) -> a -> b
$
\PatternT Type
pat [(VName, SubExp)]
ispace [KernelInput]
inps Result
nes' [VName]
_ -> do
let red_op :: SegBinOp lore
red_op = Commutativity -> Lambda lore -> Result -> Shape -> SegBinOp lore
forall lore.
Commutativity -> Lambda lore -> Result -> Shape -> SegBinOp lore
SegBinOp Commutativity
comm Lambda lore
lam Result
nes' Shape
forall a. Monoid a => a
mempty
SegOpLevel lore
lvl <- Result
-> [Char]
-> ThreadRecommendation
-> BinderT lore m (SegOpLevel lore)
mk_lvl (SubExp
segment_size SubExp -> Result -> Result
forall a. a -> [a] -> [a]
: ((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(VName, SubExp)]
ispace) [Char]
"segred" (ThreadRecommendation -> BinderT lore m (SegOpLevel lore))
-> ThreadRecommendation -> BinderT lore m (SegOpLevel lore)
forall a b. (a -> b) -> a -> b
$ SegVirt -> ThreadRecommendation
NoRecommendation SegVirt
SegNoVirt
Stms lore -> BinderT lore m ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms lore -> BinderT lore m ())
-> BinderT lore m (Stms lore) -> BinderT lore m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stm lore -> BinderT lore m (Stm lore))
-> Stms lore -> BinderT lore m (Stms lore)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse Stm lore -> BinderT lore m (Stm lore)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Stm lore -> m (Stm lore)
renameStm
(Stms lore -> BinderT lore m (Stms lore))
-> BinderT lore m (Stms lore) -> BinderT lore m (Stms lore)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel lore
-> Pattern lore
-> SubExp
-> [SegBinOp lore]
-> Lambda lore
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BinderT lore m (Stms lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore, HasScope lore m) =>
SegOpLevel lore
-> Pattern lore
-> SubExp
-> [SegBinOp lore]
-> Lambda lore
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms lore)
segRed SegOpLevel lore
lvl PatternT Type
Pattern lore
pat SubExp
segment_size [SegBinOp lore
red_op] Lambda lore
map_lam [VName]
arrs [(VName, SubExp)]
ispace [KernelInput]
inps
-- | Attempt to express an operator that sits inside the kernel nest
-- @nest@ as a segmented operation over the flattened nest space.
--
-- The attempt is abandoned (result 'Nothing') when any of these hold:
--
-- * a name in @free_in_op@ is bound by the nest;
-- * a neutral element in @nes@ is a variable bound by the nest;
-- * an array in @arrs@ is neither (a) a kernel input of the flattened
--   nest indexed by exactly the nest indices, nor (b) absent from the
--   kernel inputs and not bound by the nest.
--
-- Case (b) arrays are replicated over the widths of the flattened space
-- (binding a name with the suffix @_repd@).  On success the callback @m@
-- is run in a fresh binder.  It receives:
--
-- * the value elements of the outermost nesting's pattern, rearranged
--   by @perm@;
-- * the flattened space and its kernel inputs;
-- * the neutral elements;
-- * the prepared arrays.
--
-- The statements @m@ produces are returned in 'Just'.  The fourth
-- argument is accepted but not used here.
isSegmentedOp ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  KernelNest ->
  [Int] ->
  Names ->
  Names ->
  [SubExp] ->
  [VName] ->
  ( PatternT Type ->
    [(VName, SubExp)] ->
    [KernelInput] ->
    [SubExp] ->
    [VName] ->
    BinderT lore m ()
  ) ->
  DistNestT lore m (Maybe (Stms lore))
isSegmentedOp nest perm free_in_op _free_in_fold_op nes arrs m = runMaybeT $ do
  (ispace, kernel_inps) <- flatKernel nest

  -- The operator itself may not depend on anything the nest binds.
  when (free_in_op `namesIntersect` nest_bound) $
    fail "Non-fold lambda uses nest-bound parameters."

  checked_nes <- mapM checkNe nes
  -- Only the *recipes* for the arrays are computed here; they are run
  -- later inside the binder so that any replicates land in its output.
  arr_recipes <- mapM (arrRecipe ispace kernel_inps) arrs

  lift . liftInner . runBinderT'_ $ do
    prepared_arrs <- sequence arr_recipes
    m seg_pat ispace kernel_inps checked_nes prepared_arrs
  where
    nest_bound = boundInKernelNest nest

    seg_pat =
      Pattern [] . rearrangeShape perm . patternValueElements $
        loopNestingPattern (fst nest)

    -- A neutral element must not be a variable bound inside the nest.
    checkNe (Var v)
      | v `nameIn` nest_bound = fail "Neutral element bound in nest"
    checkNe ne = return ne

    -- Decide how to obtain an array input for the segmented operation.
    arrRecipe ispace kernel_inps arr =
      case find ((== arr) . kernelInputName) kernel_inps of
        Just inp
          | kernelInputIndices inp == map (Var . fst) ispace ->
            return . return $ kernelInputArray inp
        Nothing
          | not (arr `nameIn` nest_bound) ->
            -- Free in the nest: replicate it across the segment space.
            return . letExp (baseString arr ++ "_repd") . BasicOp $
              Replicate (Shape (map snd ispace)) (Var arr)
        _ -> fail "Input not free, perfectly mapped, or outermost."
-- | Relate a pattern to a list of results.
--
-- The pattern's value elements whose names are not free in @res@ are
-- the "missing" ones.  The results are extended with those missing
-- elements (as variables), and the pattern's value elements are checked
-- against this extended list with 'isPermutationOf'.  On success the
-- permutation and the missing elements are returned.  If
-- 'isPermutationOf' fails, the result is 'Nothing'.
permutationAndMissing :: PatternT Type -> [SubExp] -> Maybe ([Int], [PatElemT Type])
permutationAndMissing pat res = do
  perm <- pe_vars `isPermutationOf` (res ++ map asVar missing)
  pure (perm, missing)
  where
    asVar = Var . patElemName
    value_pes = patternValueElements pat
    pe_vars = map asVar value_pes
    res_names = freeIn res
    missing = filter (not . (`nameIn` res_names) . patElemName) value_pes
-- | Make every level of a kernel nest additionally produce the pattern
-- elements @pes@.
--
-- Each level gets a fresh copy of every element of @pes@: same base
-- name, new 'VName'.  The copy's type is the element's type wrapped in
-- the widths of that level and of all levels inside it.  So the outer
-- level uses every width, the first inner level every width except the
-- outer one, and so on.
--
-- The copies are appended after the level's existing pattern elements.
-- The rebuilt pattern is created with an empty first ('Pattern')
-- component.
expandKernelNest ::
  MonadFreshNames m =>
  [PatElemT Type] ->
  KernelNest ->
  m KernelNest
expandKernelNest pes (outer_nest, inner_nests) = do
  outer_nest' <- addResults (outer_width : inner_widths) outer_nest
  -- 'tails' yields one more list than there are inner nests; the extra
  -- (empty) suffix is dropped by zipping.
  inner_nests' <- zipWithM addResults (tails inner_widths) inner_nests
  pure (outer_nest', inner_nests')
  where
    outer_width = loopNestingWidth outer_nest
    inner_widths = map loopNestingWidth inner_nests

    -- Append freshly named, @dims@-expanded copies of @pes@ to one level.
    addResults dims level = do
      new_pes <- traverse (freshArrayOf dims) pes
      let old_pes = patternElements (loopNestingPattern level)
      pure level {loopNestingPattern = Pattern [] (old_pes <> new_pes)}

    -- Copy a pattern element under a fresh name, with its type turned
    -- into an array of shape @dims@.
    freshArrayOf dims pe = do
      name' <- newVName . baseString $ patElemName pe
      pure
        pe
          { patElemName = name',
            patElemDec = patElemType pe `arrayOfShape` Shape dims
          }
-- | Finish distributing statement @bnd@, depending on whether kernel
-- statements were produced for it.
--
-- * 'Nothing': @bnd@, certified with @cs@, is added to the accumulator
--   @acc@ via 'addStmToAcc'.  The post-statements and @acc'@ are ignored.
-- * 'Just' @bnds@: the post-statements @kernels@ are added, then @bnds@
--   (each certified with @cs@) is posted, and @acc'@ is returned
--   unchanged.  @bnd@ and @acc@ are ignored.
kernelOrNot ::
  (MonadFreshNames m, DistLore lore) =>
  Certificates ->
  Stm SOACS ->
  DistAcc lore ->
  PostStms lore ->
  DistAcc lore ->
  Maybe (Stms lore) ->
  DistNestT lore m (DistAcc lore)
kernelOrNot cs bnd acc kernels acc' = \case
  Nothing -> addStmToAcc (certify cs bnd) acc
  Just bnds -> do
    addPostStms kernels
    postStm (certify cs <$> bnds)
    pure acc'
-- | Distribute a @map@ loop.
--
-- A new inner target (the map's pattern paired with its lambda's body
-- result) is pushed onto the targets of @acc@, with an empty statement
-- list.  Inside the map's nesting, the lambda body's statements are
-- distributed against that accumulator and the outcome is distributed.
-- The accumulator returned from the nesting is distributed once more.
distributeMap ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  MapLoop ->
  DistAcc lore ->
  DistNestT lore m (DistAcc lore)
distributeMap (MapLoop pat aux w lam arrs) acc = do
  nested_acc <- mapNesting pat aux w lam arrs inside_nest
  distribute nested_acc
  where
    body = lambdaBody lam

    -- Runs within the nesting set up by 'mapNesting'.
    inside_nest =
      distributeMapBodyStms inner_acc (bodyStms body) >>= distribute

    inner_acc =
      DistAcc
        { distTargets = pushInnerTarget (pat, bodyResult body) (distTargets acc),
          distStms = mempty
        }