{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fprof-auto-top #-}
-- | Dataflow analysis and rewriting over Cmm graphs.
--
-- This module exports the lattice vocabulary ('DataflowLattice',
-- 'OldFact', 'NewFact', 'JoinedFact'), the transfer/rewrite function types,
-- worklist-driven fixpoint analyses in both directions ('analyzeCmmFwd',
-- 'analyzeCmmBwd'), a backward rewriting pass ('rewriteCmmBwd'), helpers for
-- joining facts, and folds over the nodes of @O O@ blocks.
--
-- Several exported names (for example 'lastNode' and 'entryLabel') are not
-- defined in this module; they are re-exported from the imported Hoopl
-- modules.
module Hoopl.Dataflow
  ( C, O, Block
  , lastNode, entryLabel
  , foldNodesBwdOO
  , foldRewriteNodesBwdOO
  , DataflowLattice(..), OldFact(..), NewFact(..), JoinedFact(..)
  , TransferFun, RewriteFun
  , Fact, FactBase
  , getFact, mkFactBase
  , analyzeCmmFwd, analyzeCmmBwd
  , rewriteCmmBwd
  , changedIf
  , joinOutFacts
  , joinFacts
  )
where
import GhcPrelude
import Cmm
import UniqSupply
import Data.Array
import Data.List
import Data.Maybe
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Hoopl.Block
import Hoopl.Graph
import Hoopl.Collections
import Hoopl.Label
-- | The representation of dataflow information, selected by the shape tag
-- @x@: a whole 'FactBase' for 'C' and a single fact for 'O'.
type family Fact x f :: *
-- For the 'C' tag the information is a map of facts ('FactBase').
type instance Fact C f = FactBase f
-- For the 'O' tag the information is just the fact itself.
type instance Fact O f = f
-- | Tags a fact as the \"old\" operand of a join (first argument of a
-- 'JoinFun'), so the two operands cannot be swapped by accident.
newtype OldFact a = OldFact a
-- | Tags a fact as the \"new\" operand of a join (second argument of a
-- 'JoinFun').
newtype NewFact a = NewFact a
-- | The outcome of joining two facts: the resulting fact together with a
-- flag (encoded in the constructor) saying whether the join changed
-- anything. Both payloads are strict.
data JoinedFact a
    = Changed !a
      -- ^ The join produced a fact that must be recorded.
    | NotChanged !a
      -- ^ The join left the existing information as it was.

-- | Extract the fact from a 'JoinedFact', discarding the changed flag.
getJoined :: JoinedFact a -> a
getJoined joined =
    case joined of
        Changed fact -> fact
        NotChanged fact -> fact

-- | Pick the 'JoinedFact' constructor from a boolean: 'Changed' for 'True',
-- 'NotChanged' for 'False'.
changedIf :: Bool -> a -> JoinedFact a
changedIf hasChanged = if hasChanged then Changed else NotChanged
-- | A join combines a fact tagged 'OldFact' with one tagged 'NewFact' and
-- reports the result as 'Changed' or 'NotChanged'.
type JoinFun a = OldFact a -> NewFact a -> JoinedFact a
-- | The operations an analysis must supply for its fact type.
data DataflowLattice a = DataflowLattice
  { -- | Default fact: 'getFact' returns it for labels with no entry, and
    -- 'joinFacts' / 'joinOutFacts' use it as the seed of their folds.
    fact_bot :: a
    -- | How two facts are combined.
  , fact_join :: JoinFun a
  }
-- | Direction of an analysis or rewrite. It selects the block processing
-- order ('sortBlocks') and the dependency map ('mkDepBlocks').
data Direction = Fwd | Bwd
-- | Per-block transfer function: given a block and the current fact base,
-- produce the facts that the fixpoint loop then joins, label by label, into
-- its fact base.
type TransferFun f = CmmBlock -> FactBase f -> FactBase f
-- | Like 'TransferFun', but runs in 'UniqSM' and additionally returns the
-- block that should replace the input block in the rewritten graph.
type RewriteFun f = CmmBlock -> FactBase f -> UniqSM (CmmBlock, FactBase f)
-- | Run a dataflow analysis over a 'CmmGraph' to a fixed point, backwards
-- ('analyzeCmmBwd') or forwards ('analyzeCmmFwd').
--
-- Arguments, in order: the lattice, the per-block transfer function, the
-- graph, and the initial fact base. The result is the final fact base.
-- Both are 'analyzeCmm' with the direction fixed.
analyzeCmmBwd, analyzeCmmFwd
    :: DataflowLattice f
    -> TransferFun f
    -> CmmGraph
    -> FactBase f
    -> FactBase f
analyzeCmmBwd = analyzeCmm Bwd
analyzeCmmFwd = analyzeCmm Fwd
-- | Direction-parametric analysis driver: pulls the entry label and the
-- block map out of the 'CmmGraph' and hands them to 'fixpointAnalysis'
-- together with the initial fact base.
analyzeCmm
    :: Direction
    -> DataflowLattice f
    -> TransferFun f
    -> CmmGraph
    -> FactBase f
    -> FactBase f
analyzeCmm dir lattice transfer cmmGraph initFact =
    fixpointAnalysis dir lattice transfer startLabel graphBlocks initFact
  where
    startLabel = g_entry cmmGraph
    -- The graph body is a 'GMany' with no open entry or exit sequence; only
    -- its block map is needed.
    graphBlocks =
        case g_graph cmmGraph of
            GMany NothingO bm NothingO -> bm
-- | Worklist-driven fixpoint computation for an analysis.
--
-- The blocks are ordered by 'sortBlocks' and numbered by their position in
-- that order. The worklist is a set of block indices that initially holds
-- every block; the loop always takes the smallest index, runs the transfer
-- function on that block against the current fact base, and joins each
-- produced fact into the fact base with 'updateFact', which also puts the
-- blocks that depend on a changed label back on the worklist. The loop ends
-- when the worklist is empty, returning the accumulated fact base.
fixpointAnalysis
    :: forall f.
       Direction
    -> DataflowLattice f
    -> TransferFun f
    -> Label
    -> LabelMap CmmBlock
    -> FactBase f
    -> FactBase f
fixpointAnalysis direction lattice do_block entry blockmap = go worklist0
  where
    ordered = sortBlocks direction entry blockmap
    blockCount = length ordered
    -- Index -> block lookup table for the ordered blocks.
    blockTable =
        {-# SCC "block_arr" #-} listArray (0, blockCount - 1) ordered
    -- Every block is scheduled once to begin with.
    worklist0 =
        {-# SCC "start" #-} IntSet.fromDistinctAscList [0 .. blockCount - 1]
    -- Label -> indices of the blocks to revisit when that label's fact
    -- changes.
    dependents = {-# SCC "dep_blocks" #-} mkDepBlocks direction ordered
    joinFn = fact_join lattice

    go
        :: IntHeap
        -> FactBase f
        -> FactBase f
    go pending !facts =
        case IntSet.minView pending of
            -- Nothing left to process: the fact base is the result.
            Nothing -> facts
            Just (ix, rest) ->
                let produced =
                        {-# SCC "do_block" #-} do_block (blockTable ! ix) facts
                    -- Join every produced fact into the fact base,
                    -- rescheduling dependent blocks when something changes.
                    (pending', facts') =
                        {-# SCC "mapFoldWithKey" #-}
                        mapFoldlWithKey
                            (updateFact joinFn dependents) (rest, facts) produced
                in go pending' facts'
-- | Run a backward rewriting pass over a 'CmmGraph' to a fixed point.
--
-- Arguments, in order: the lattice, the per-block rewrite function, the
-- graph, and the initial fact base. The result, in 'UniqSM', is the
-- rewritten graph paired with the final fact base. This is 'rewriteCmm'
-- with the direction fixed to 'Bwd'.
rewriteCmmBwd
    :: DataflowLattice f
    -> RewriteFun f
    -> CmmGraph
    -> FactBase f
    -> UniqSM (CmmGraph, FactBase f)
rewriteCmmBwd = rewriteCmm Bwd
-- | Direction-parametric rewrite driver: extracts the entry label and block
-- map from the 'CmmGraph', runs 'fixpointRewrite' on them, and puts the
-- resulting block map back into the graph. Returns the updated graph and
-- the final fact base.
rewriteCmm
    :: Direction
    -> DataflowLattice f
    -> RewriteFun f
    -> CmmGraph
    -> FactBase f
    -> UniqSM (CmmGraph, FactBase f)
rewriteCmm dir lattice rwFun cmmGraph initFact = do
    (rewrittenBlocks, facts) <-
        fixpointRewrite dir lattice rwFun startLabel originalBlocks initFact
    -- Only the graph body is replaced; the other fields of the graph are
    -- carried over by the record update.
    let newBody = GMany NothingO rewrittenBlocks NothingO
    return (cmmGraph {g_graph = newBody}, facts)
  where
    startLabel = g_entry cmmGraph
    originalBlocks =
        case g_graph cmmGraph of
            GMany NothingO bm NothingO -> bm
-- | Worklist-driven fixpoint computation for a rewrite.
--
-- Works like 'fixpointAnalysis', but the per-block function runs in
-- 'UniqSM' and also returns a replacement block, which is stored in the
-- result map under its own entry label.
--
-- Each visit rewrites the block taken from the table built from the
-- /input/ block map, never the replacement stored earlier. A block that is
-- revisited because its facts changed is therefore rewritten afresh from
-- its original form. Blocks that 'sortBlocks' does not return are never
-- visited and stay in the result map as they were.
fixpointRewrite
    :: forall f.
       Direction
    -> DataflowLattice f
    -> RewriteFun f
    -> Label
    -> LabelMap CmmBlock
    -> FactBase f
    -> UniqSM (LabelMap CmmBlock, FactBase f)
fixpointRewrite dir lattice do_block entry blockmap = go worklist0 blockmap
  where
    ordered = sortBlocks dir entry blockmap
    blockCount = length ordered
    -- Index -> original block lookup table.
    blockTable =
        {-# SCC "block_arr_rewrite" #-}
        listArray (0, blockCount - 1) ordered
    -- Every block is scheduled once to begin with.
    worklist0 =
        {-# SCC "start_rewrite" #-}
        IntSet.fromDistinctAscList [0 .. blockCount - 1]
    -- Label -> indices of the blocks to revisit when that label's fact
    -- changes.
    dependents = {-# SCC "dep_blocks_rewrite" #-} mkDepBlocks dir ordered
    joinFn = fact_join lattice

    go
        :: IntHeap
        -> LabelMap CmmBlock
        -> FactBase f
        -> UniqSM (LabelMap CmmBlock, FactBase f)
    go pending !rewritten !facts =
        case IntSet.minView pending of
            -- Worklist exhausted: hand back the blocks and facts so far.
            Nothing -> return (rewritten, facts)
            Just (ix, rest) -> do
                (block', produced) <-
                    {-# SCC "do_block_rewrite" #-}
                    do_block (blockTable ! ix) facts
                let rewritten' =
                        mapInsert (entryLabel block') block' rewritten
                    -- Join the produced facts in and reschedule dependents.
                    (pending', facts') =
                        {-# SCC "mapFoldWithKey_rewrite" #-}
                        mapFoldlWithKey
                            (updateFact joinFn dependents) (rest, facts) produced
                go pending' rewritten' facts'
-- | Order the blocks for processing. A forward pass uses the order that
-- 'revPostorderFrom' yields starting at the entry label; a backward pass
-- uses that same list reversed.
sortBlocks
    :: NonLocal n
    => Direction -> Label -> LabelMap (Block n C C) -> [Block n C C]
sortBlocks Fwd entry blockmap = revPostorderFrom blockmap entry
sortBlocks Bwd entry blockmap = reverse (revPostorderFrom blockmap entry)
-- | Build the map from a label to the indices of the blocks that must be
-- reprocessed when the fact for that label changes. Indices are positions
-- in the given (already sorted) block list, starting at 0.
--
-- * 'Fwd': a block's own entry label maps to just that block's index.
-- * 'Bwd': every successor label of block @n@ maps to a set containing @n@;
--   sets are merged when several blocks share a successor.
mkDepBlocks :: Direction -> [CmmBlock] -> LabelMap IntSet
mkDepBlocks Fwd blocks = foldl' record mapEmpty (zip [0 ..] blocks)
  where
    record depMap (ix, block) =
        mapInsert (entryLabel block) (IntSet.singleton ix) depMap
mkDepBlocks Bwd blocks = foldl' record mapEmpty (zip [0 ..] blocks)
  where
    record depMap (ix, block) =
        foldl' (addDependent ix) depMap (successors block)
    -- Register block @ix@ as depending on the fact at @lbl@.
    addDependent ix depMap lbl =
        mapInsertWith IntSet.union lbl (IntSet.singleton ix) depMap
-- | Fold step used by the fixpoint loops: merge one produced fact into the
-- fact base and update the worklist accordingly.
--
-- * No fact stored for the label yet: the new fact is stored as is and the
--   label's dependent blocks are added to the worklist.
-- * A fact is stored and the join reports 'NotChanged': worklist and fact
--   base are returned untouched.
-- * A fact is stored and the join reports 'Changed': the joined fact
--   replaces the stored one and the dependents are added to the worklist.
updateFact
    :: JoinFun f
    -> LabelMap IntSet
    -> (IntHeap, FactBase f)
    -> Label
    -> f
    -> (IntHeap, FactBase f)
updateFact joinFn dep_blocks (todo, fbase) lbl new_fact =
    case lookupFact lbl fbase of
        Nothing -> store new_fact
        Just old_fact ->
            case joinFn (OldFact old_fact) (NewFact new_fact) of
                Changed joined -> store joined
                NotChanged _ -> (todo, fbase)
  where
    -- Record a fact for the label (forcing the updated map) and schedule
    -- the blocks that depend on it.
    store fact =
        let !fbase' = mapInsert lbl fact fbase
        in (rescheduled, fbase')
    -- Labels without an entry in the dependency map add nothing.
    rescheduled =
        IntSet.union todo (mapFindWithDefault IntSet.empty lbl dep_blocks)
-- | Look up the fact for a label, falling back to the lattice's 'fact_bot'
-- when the fact base has no entry for it.
getFact :: DataflowLattice f -> Label -> FactBase f -> f
getFact lat l fb = fromMaybe (fact_bot lat) (lookupFact l fb)
-- | Join the facts recorded for the successors of a node.
--
-- The successors of @nonLocal@ are looked up in @fact_base@; successors
-- without an entry are skipped. The facts found are folded from the left
-- with the lattice join, starting from 'fact_bot'. If no successor has a
-- fact, the result is 'fact_bot'.
joinOutFacts :: (NonLocal n) => DataflowLattice f -> n e C -> FactBase f -> f
joinOutFacts lattice nonLocal fact_base =
    foldl' combine (fact_bot lattice) successorFacts
  where
    -- Note the operand wiring: the running result is passed as the
    -- 'NewFact' and the successor's fact as the 'OldFact'.
    combine acc fact =
        getJoined $ fact_join lattice (OldFact fact) (NewFact acc)
    -- 'mapMaybe' keeps exactly the successors that have a fact, in
    -- successor order, without resorting to the partial 'fromJust'.
    successorFacts =
        mapMaybe (`lookupFact` fact_base) (successors nonLocal)
-- | Join a list of facts with the lattice join, folding from the left and
-- starting from 'fact_bot'. An empty list yields 'fact_bot'.
joinFacts :: DataflowLattice f -> [f] -> f
joinFacts lattice = foldl' combine (fact_bot lattice)
  where
    -- The running result is passed as the 'NewFact' and the list element as
    -- the 'OldFact'.
    combine acc fact =
        getJoined (fact_join lattice (OldFact fact) (NewFact acc))
-- | Build a 'FactBase' from an association list. When a label occurs more
-- than once, the facts for it are combined with the lattice join instead of
-- the later one overwriting the earlier one.
mkFactBase :: DataflowLattice f -> [(Label, f)] -> FactBase f
mkFactBase lattice = foldl' insertFact mapEmpty
  where
    insertFact acc (lbl, incoming) =
        -- The merged fact is forced before it is inserted.
        let !merged = maybe incoming (mergeWith incoming) (mapLookup lbl acc)
        in mapInsert lbl merged acc
    -- The incoming fact is passed as the 'OldFact' and the fact already in
    -- the map as the 'NewFact'.
    mergeWith incoming existing =
        getJoined (fact_join lattice (OldFact incoming) (NewFact existing))
-- | Fold a function over the nodes of an @O O@ block from the last node to
-- the first, threading an accumulator. Each intermediate accumulator is
-- forced before it is passed on.
foldNodesBwdOO :: (CmmNode O O -> f -> f) -> Block CmmNode O O -> f -> f
foldNodesBwdOO funOO = walk
  where
    walk BNil acc = acc
    walk (BMiddle node) acc = funOO node acc
    -- The tail is folded first, then the head node sees its result.
    walk (BCons node rest) acc =
        let !after = walk rest acc in funOO node after
    -- The trailing node is applied first, then the front of the block.
    walk (BSnoc front node) acc =
        let !after = funOO node acc in walk front after
    -- The second half is folded before the first half.
    walk (BCat front back) acc =
        let !after = walk back acc in walk front after
{-# INLINABLE foldNodesBwdOO #-}
-- | Rewrite the nodes of an @O O@ block from the last node to the first.
--
-- The rewrite function receives a node and the facts that hold after it,
-- and returns (in 'UniqSM') the block that replaces the node together with
-- the facts that hold before it. The replacement blocks are glued back
-- together in their original order with 'joinBlocksOO'. The result is the
-- rebuilt block and the facts at its start.
foldRewriteNodesBwdOO
    :: forall f.
       (CmmNode O O -> f -> UniqSM (Block CmmNode O O, f))
    -> Block CmmNode O O
    -> f
    -> UniqSM (Block CmmNode O O, f)
foldRewriteNodesBwdOO rewriteOO initBlock initFacts = walk initBlock initFacts
  where
    walk BNil !facts = return (BNil, facts)
    walk (BMiddle node) !facts = rewriteOO node facts
    walk (BCons node rest) !facts = chain (rewriteOO node) (walk rest) facts
    walk (BSnoc front node) !facts = chain (walk front) (rewriteOO node) facts
    walk (BCat front back) !facts = chain (walk front) (walk back) facts

    -- Run the rewrite for the later part first, feed its resulting facts to
    -- the rewrite for the earlier part, and concatenate the two resulting
    -- blocks (earlier part first).
    chain earlier later = \factsAfter -> do
        (laterBlock, factsBetween) <- later factsAfter
        (earlierBlock, !factsBefore) <- earlier factsBetween
        let !joined = joinBlocksOO earlierBlock laterBlock
        return (joined, factsBefore)
    {-# INLINE chain #-}
{-# INLINABLE foldRewriteNodesBwdOO #-}
-- | Concatenate two @O O@ blocks, avoiding a 'BCat' node where a simpler
-- form is available: an empty side is dropped, and a single-node side is
-- attached with 'blockCons' / 'blockSnoc'. The cases are tried in the order
-- listed, the first argument being inspected first.
joinBlocksOO :: Block n O O -> Block n O O -> Block n O O
joinBlocksOO front back =
    case (front, back) of
        (BNil, _) -> back
        (_, BNil) -> front
        (BMiddle node, _) -> blockCons node back
        (_, BMiddle node) -> blockSnoc front node
        _ -> BCat front back
-- | The worklist of block indices used by the fixpoint loops. It is a plain
-- 'IntSet'; the loops take the smallest index first via 'IntSet.minView'.
type IntHeap = IntSet