module Futhark.Optimise.Fusion.Composing
( fuseMaps,
fuseRedomap,
mergeReduceOps,
)
where
import Data.List (mapAccumL)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Futhark.Analysis.HORep.SOAC as SOAC
import Futhark.Binder (Bindable (..), insertStm, insertStms, mkLet)
import Futhark.Construct (mapResult)
import Futhark.IR
import Futhark.Util (dropLast, splitAt3, takeLast)
-- | Fuse the producer lambda @lam1@ into the consumer lambda @lam2@.
--
-- The results of @lam1@'s body are bound, one 'mkLet' each, to the
-- identifiers chosen by 'fuseInputs'. The body of @lam2@ is placed after
-- those bindings. 'mapResult' is used to splice this in at the point
-- where @lam1@'s body produces its result.
--
-- The first @length (lambdaParams lam2) - length inp2@ parameters of
-- @lam2@ have no array input. They are kept, unchanged, at the front of
-- the fused parameter list. The remaining parameters are the keys of the
-- merged input map computed by 'fuseInputs'. The returned inputs are the
-- elements of that map, in the same (key) order.
--
-- Producer outputs named in @unfus_nms@ have the identifiers holding
-- their elements appended to the fused body's result. 'lambdaReturnType'
-- is /not/ extended for these extra results; it stays that of @lam2@.
-- Callers that keep unfused outputs must extend it themselves.
fuseMaps ::
  Bindable lore =>
  -- | Names of producer outputs that must still be returned.
  Names ->
  -- | Producer lambda (the one being fused in).
  Lambda lore ->
  -- | Array inputs of the producer.
  [SOAC.Input] ->
  -- | Producer outputs: the array name, paired with an identifier that
  -- can be bound to a single element of it.
  [(VName, Ident)] ->
  -- | Consumer lambda (the one being fused into).
  Lambda lore ->
  -- | Array inputs of the consumer.
  [SOAC.Input] ->
  -- | The fused lambda and the array inputs of the resulting SOAC.
  (Lambda lore, [SOAC.Input])
fuseMaps unfus_nms lam1 inp1 out1 lam2 inp2 = (lam2', M.elems inputmap)
  where
    lam2' =
      lam2
        { lambdaParams =
            [ Param name t
              | Ident name t <- lam2redparams ++ M.keys inputmap
            ],
          lambdaBody = new_body2'
        }

    -- Producer body, whose result feeds let-bindings for the consumer's
    -- element variables, followed by the consumer body.
    new_body2 =
      let bnds res =
            [ mkLet [] [p] $ BasicOp $ SubExp e
              | (p, e) <- zip pat res
            ]
          bindLambda res =
            stmsFromList (bnds res) `insertStms` makeCopiesInner (lambdaBody lam2)
       in makeCopies $ mapResult bindLambda (lambdaBody lam1)

    -- Infusible producer outputs go after the consumer's own results.
    new_body2' =
      new_body2
        { bodyResult =
            bodyResult new_body2 ++ map (Var . identName) unfus_pat
        }

    (lam2redparams, unfus_pat, pat, inputmap, makeCopies, makeCopiesInner) =
      fuseInputs unfus_nms lam1 inp1 out1 lam2 inp2
-- | Work out how the parameters and inputs of @lam1@ (producer) and
-- @lam2@ (consumer) combine when @lam1@ is fused into @lam2@.
--
-- Returns a tuple of:
--
--   1. the leading parameters of @lam2@ that have no array input (the
--      first @length (lambdaParams lam2) - length inp2@ of them);
--   2. for each producer output whose name is in @unfus_nms@, the
--      identifier that holds its element in the fused body;
--   3. for each producer output, the identifier its element is bound to
--      (see 'filterOutParams');
--   4. the combined parameter-to-input map: producer and consumer array
--      parameters, minus the consumer parameters that read a producer
--      output, with duplicate inputs merged;
--   5. a body transformer re-binding the parameters dropped when merging
--      duplicates in (4);
--   6. a body transformer re-binding the parameters dropped when merging
--      duplicate inputs of @lam2@ alone.
fuseInputs ::
  Bindable lore =>
  Names ->
  Lambda lore ->
  [SOAC.Input] ->
  [(VName, Ident)] ->
  Lambda lore ->
  [SOAC.Input] ->
  ( [Ident],
    [Ident],
    [Ident],
    M.Map Ident SOAC.Input,
    Body lore -> Body lore,
    Body lore -> Body lore
  )
fuseInputs unfus_nms lam1 inp1 out1 lam2 inp2 =
  (lam2redparams, unfus_vars, outbnds, inputmap, makeCopies, makeCopiesInner)
  where
    lam1params = map paramIdent $ lambdaParams lam1
    lam2params = map paramIdent $ lambdaParams lam2

    -- Surplus leading parameters of lam2 do not correspond to any input.
    (lam2redparams, lam2arrparams) =
      splitAt (length lam2params - length inp2) lam2params

    lam1inputmap = M.fromList $ zip lam1params inp1
    lam2inputmap = M.fromList $ zip lam2arrparams inp2
    (lam2inputmap', makeCopiesInner) = removeDuplicateInputs lam2inputmap
    originputmap = lam1inputmap `M.union` lam2inputmap'

    -- Consumer parameters whose input is one of the producer's outputs.
    -- These disappear from the fused input list; they are bound directly
    -- to producer results instead.
    outins =
      uncurry (outParams $ map fst out1) $
        unzip $ M.toList lam2inputmap'
    outbnds = filterOutParams out1 outins
    (inputmap, makeCopies) =
      removeDuplicateInputs $ originputmap `M.difference` outins

    -- outins reversed: producer array name -> consumer parameter.
    -- Every entry of outins has a variable input by construction
    -- (see outParams), so no entry is expected to be skipped here.
    outinsrev =
      M.fromList
        [ (nm, p)
          | (p, inp) <- M.toList outins,
            Just nm <- [SOAC.isVarInput inp]
        ]

    -- Left-biased: prefer the consumer parameter that reads an output,
    -- falling back to the producer's own element identifier.  Built once
    -- rather than on every lookup.
    outnamemap = outinsrev `M.union` M.fromList out1

    unfusible outname
      | outname `nameIn` unfus_nms = M.lookup outname outnamemap
      | otherwise = Nothing

    unfus_vars = mapMaybe (unfusible . fst) out1
-- | Pair each consumer array parameter with its input. Keep only the
-- pairs whose input 'SOAC.isVarInput' maps to one of the producer output
-- names in the first argument.
outParams ::
  -- | Producer output array names.
  [VName] ->
  -- | Consumer array parameters.
  [Ident] ->
  -- | Consumer array inputs, positionally matching the parameters.
  [SOAC.Input] ->
  M.Map Ident SOAC.Input
outParams out1 lam2arrparams inp2 =
  M.fromList
    [ (p, inp)
      | (p, inp) <- zip lam2arrparams inp2,
        Just a <- [SOAC.isVarInput inp],
        a `elem` out1
    ]
-- | For each producer output @(arr, ident)@ in @out1@, in order, choose
-- the identifier that its element is bound to in the fused body.
--
-- If a not-yet-used consumer parameter in @outins@ has the variable
-- @arr@ as its input, that parameter is chosen. Otherwise @ident@ is
-- used. Each consumer parameter is handed out at most once.
filterOutParams ::
  [(VName, Ident)] ->
  M.Map Ident SOAC.Input ->
  [Ident]
filterOutParams out1 outins =
  snd $ mapAccumL checkUsed outUsage out1
  where
    -- Array name -> consumer parameters still available for it.
    outUsage = M.foldlWithKey' add M.empty outins
      where
        add m p inp =
          case SOAC.isVarInput inp of
            Just v -> M.insertWith (++) v [p] m
            Nothing -> m

    -- Consume one available parameter for this output, if any is left.
    checkUsed m (a, ra) =
      case M.lookup a m of
        Just (p : ps) -> (M.insert a ps m, p)
        _ -> (m, ra)
-- | Merge parameters that are bound to identical inputs.
--
-- The map is walked in ascending key order. The first parameter seen for
-- a given input is kept. Every later parameter with the same input is
-- dropped from the returned map. For each dropped parameter, the
-- returned body transformer inserts (via 'insertStm') a binding
-- @let dropped = kept@, so a body that still mentions it stays
-- well-scoped.
removeDuplicateInputs ::
  Bindable lore =>
  M.Map Ident SOAC.Input ->
  (M.Map Ident SOAC.Input, Body lore -> Body lore)
removeDuplicateInputs = fst . M.foldlWithKey' comb ((M.empty, id), M.empty)
  where
    -- State: ((kept parameter -> input, accumulated body transformer),
    --         input -> name of the parameter kept for it).
    comb ((parmap, inner), arrmap) par arr =
      case M.lookup arr arrmap of
        Nothing ->
          ( (M.insert par arr parmap, inner),
            M.insert arr (identName par) arrmap
          )
        Just par' ->
          ( (parmap, inner . forward par par'),
            arrmap
          )

    forward to from b =
      mkLet [] [to] (BasicOp $ SubExp $ Var from)
        `insertStm` b
-- | Fuse a producer lambda @p_lam@ into a consumer lambda @c_lam@.
--
-- Both lambdas return their scan results first, then their reduction
-- results, then their map results. The split points are the lengths of
-- the respective neutral-element lists.
--
-- Fusion proceeds in three steps:
--
--   (i) strip the producer's leading (non-array) parameters and its
--       scan/reduce results, leaving a plain map lambda;
--   (ii) fuse that map lambda into the consumer using 'fuseMaps';
--   (iii) put the producer's leading parameters and scan/reduce results
--         back in front.
--
-- The fused lambda returns, in order: producer scan results, consumer
-- scan results, producer reduce results, consumer reduce results,
-- consumer map results, and finally the producer map outputs named in
-- @unfus_nms@.
fuseRedomap ::
  Bindable lore =>
  -- | Producer outputs that must remain available after fusion.
  Names ->
  -- | Names of the producer outputs (assumed to line up with the
  -- producer's results).
  [VName] ->
  -- | Producer lambda.
  Lambda lore ->
  -- | Producer scan neutral elements.
  [SubExp] ->
  -- | Producer reduce neutral elements.
  [SubExp] ->
  -- | Producer array inputs.
  [SOAC.Input] ->
  -- | Producer outputs paired with element identifiers, as for 'fuseMaps'.
  [(VName, Ident)] ->
  -- | Consumer lambda.
  Lambda lore ->
  -- | Consumer scan neutral elements.
  [SubExp] ->
  -- | Consumer reduce neutral elements.
  [SubExp] ->
  -- | Consumer array inputs.
  [SOAC.Input] ->
  (Lambda lore, [SOAC.Input])
fuseRedomap
  unfus_nms
  outVars
  p_lam
  p_scan_nes
  p_red_nes
  p_inparr
  outPairs
  c_lam
  c_scan_nes
  c_red_nes
  c_inparr =
    let p_num_nes = length p_scan_nes + length p_red_nes
        unfus_arrs = filter (`nameIn` unfus_nms) outVars

        -- (i) Reduce the producer to a plain map lambda: keep only its
        --     array parameters and its map results.
        p_lam_body = lambdaBody p_lam
        (p_lam_scan_ts, p_lam_red_ts, p_lam_map_ts) =
          splitAt3 (length p_scan_nes) (length p_red_nes) $ lambdaReturnType p_lam
        (p_lam_scan_res, p_lam_red_res, p_lam_map_res) =
          splitAt3 (length p_scan_nes) (length p_red_nes) $ bodyResult p_lam_body
        p_lam_hacked =
          p_lam
            { lambdaParams = takeLast (length p_inparr) $ lambdaParams p_lam,
              lambdaBody = p_lam_body {bodyResult = p_lam_map_res},
              lambdaReturnType = p_lam_map_ts
            }

        -- (ii) Fuse the resulting map lambda into the consumer.
        (res_lam, new_inp) =
          fuseMaps
            (namesFromList unfus_arrs)
            p_lam_hacked
            p_inparr
            (drop p_num_nes outPairs)
            c_lam
            c_inparr
        (res_lam_scan_ts, res_lam_red_ts, res_lam_map_ts) =
          splitAt3 (length c_scan_nes) (length c_red_nes) $ lambdaReturnType res_lam

        -- Types of the producer map outputs that stay visible. 'fuseMaps'
        -- appends their values to the body result but leaves the return
        -- type alone, so they are added here.
        extra_map_ts =
          map snd $
            filter ((`elem` unfus_arrs) . fst) $
              zip (drop p_num_nes outVars) p_lam_map_ts

        -- (iii) Put the producer's leading parameters and its scan/reduce
        --       results back in front of the consumer's.
        accpars = dropLast (length p_inparr) $ lambdaParams p_lam
        res_body = lambdaBody res_lam
        (res_lam_scan_res, res_lam_red_res, res_lam_map_res) =
          splitAt3 (length c_scan_nes) (length c_red_nes) $ bodyResult res_body
        res_body' =
          res_body
            { bodyResult =
                p_lam_scan_res
                  ++ res_lam_scan_res
                  ++ p_lam_red_res
                  ++ res_lam_red_res
                  ++ res_lam_map_res
            }
        res_lam' =
          res_lam
            { lambdaParams = accpars ++ lambdaParams res_lam,
              lambdaBody = res_body',
              lambdaReturnType =
                p_lam_scan_ts
                  ++ res_lam_scan_ts
                  ++ p_lam_red_ts
                  ++ res_lam_red_ts
                  ++ res_lam_map_ts
                  ++ extra_map_ts
            }
     in (res_lam', new_inp)
-- | Combine two reduction operators into a single operator that computes
-- both side by side.
--
-- Merged body:
--
--   * statements: those of the first operator, then those of the second;
--   * result: the first operator's result, then the second's;
--   * body decoration: taken from the first operator.
--
-- Each operator's parameter list is assumed (not checked) to consist of
-- a first group of @length returnType@ parameters followed by a second
-- group. Presumably these are the two operands of the operator. The
-- merged parameters are both first groups, then both second groups.
mergeReduceOps :: Lambda lore -> Lambda lore -> Lambda lore
mergeReduceOps (Lambda par1 bdy1 rtp1) (Lambda par2 bdy2 rtp2) =
  Lambda merged_params merged_body (rtp1 ++ rtp2)
  where
    (par1_first, par1_second) = splitAt (length rtp1) par1
    (par2_first, par2_second) = splitAt (length rtp2) par2
    merged_params = par1_first ++ par2_first ++ par1_second ++ par2_second
    merged_body =
      Body
        (bodyDec bdy1)
        (bodyStms bdy1 <> bodyStms bdy2)
        (bodyResult bdy1 ++ bodyResult bdy2)