{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Equality.Graph
(
EGraph(..)
, Memo, Worklist
, emptyEGraph
, add, merge, rebuild
, find, canonicalize
, module Data.Equality.Graph.Classes
, module Data.Equality.Graph.Nodes
, module Data.Equality.Language
) where
import Data.Function
import Data.Functor.Classes
import qualified Data.IntMap.Strict as IM
import qualified Data.Set as S
import Data.Equality.Graph.ReprUnionFind
import Data.Equality.Graph.Classes
import Data.Equality.Graph.Nodes
import Data.Equality.Analysis
import Data.Equality.Language
import Data.Equality.Graph.Lens
-- | An e-graph over the language @l@.
--
-- All fields are strict.
data EGraph l = EGraph
    { unionFind        :: !ReprUnionFind
      -- ^ Union-find over e-class ids; 'find' resolves an id to its current
      -- representative.
    , classes          :: !(ClassIdMap (EClass l))
      -- ^ E-classes keyed by e-class id.
    , memo             :: !(Memo l)
      -- ^ Hashcons: maps e-nodes to the id of the e-class containing them.
      -- 'add' consults it before creating a new e-class.
    , worklist         :: !(Worklist l)
      -- ^ E-nodes (with their e-class ids) whose hashcons entries 'rebuild'
      -- must re-canonicalize.
    , analysisWorklist :: !(Worklist l)
      -- ^ E-nodes whose analysis data 'rebuild' must re-join into their
      -- class.
    }

-- | Hashcons table from e-nodes to e-class ids.
type Memo l = NodeMap l ClassId

-- | Pending e-nodes paired with the e-class id they belong to.
type Worklist l = NodeMap l ClassId
-- | Render every field of the e-graph as a labelled section; sections are
-- separated by a blank line.
instance (Show (Domain l), Show1 l) => Show (EGraph l) where
    show (EGraph uf cls hashcons wl awl) =
        mconcat
            [ "UnionFind: ",        show uf
            , "\n\nE-Classes: ",    show cls
            , "\n\nHashcons: ",     show hashcons
            , "\n\nWorklist: ",     show wl
            , "\n\nAnalWorklist: ", show awl
            ]
-- | Add an e-node to the e-graph.
--
-- The node is canonicalized first. If the canonical node is already in the
-- hashcons, the graph is returned unchanged together with the canonical id
-- of the e-class that holds it. Otherwise a new singleton e-class is created:
-- the node is recorded as a parent of each of its children, inserted into the
-- hashcons and the worklist, and 'modifyA' is run on the new class.
add :: forall l. Language l => ENode l -> EGraph l -> (ClassId, EGraph l)
add uncanon_e egr =
    let !new_en = {-# SCC "-2" #-} canonicalize uncanon_e egr
     in case {-# SCC "-1" #-} lookupNM new_en (memo egr) of
          Just canon_enode_id ->
              {-# SCC "0" #-} (find canon_enode_id egr, egr)
          Nothing ->
              let
                  (new_eclass_id, new_uf) = makeNewSet (unionFind egr)

                  new_eclass = EClass new_eclass_id
                                      (S.singleton new_en)
                                      (makeA new_en egr)
                                      mempty

                  -- Registers the new node as a parent of a child class.
                  new_parents = insertNM new_en new_eclass_id

                  new_classes = {-# SCC "2" #-}
                      IM.insert new_eclass_id new_eclass $
                          foldr (IM.adjust (_parents %~ new_parents))
                                (classes egr)
                                (unNode new_en)

                  new_worklist = {-# SCC "4" #-}
                      insertNM new_en new_eclass_id (worklist egr)

                  new_memo = {-# SCC "5" #-}
                      insertNM new_en new_eclass_id (memo egr)
               in ( new_eclass_id
                  , egr { unionFind = new_uf
                        , classes   = new_classes
                        , worklist  = new_worklist
                        , memo      = new_memo
                        }
                      & {-# SCC "6" #-} modifyA new_eclass_id
                  )
{-# SCC add #-}
-- | Merge the e-classes with the given ids and return the id of the
-- resulting class.
--
-- If both ids already share a representative, the graph is returned
-- unchanged. Otherwise the class with more parents becomes the leader and
-- absorbs the other ("sub") class:
--
-- * parents and nodes are unioned into the leader;
-- * the analysis data is combined with 'joinA';
-- * the sub class is removed from 'classes'.
--
-- The sub class's parents go on the worklist, and the parents of any class
-- whose analysis data changed go on the analysis worklist. Finally 'modifyA'
-- is run on the merged class. Invariants are restored later by 'rebuild'.
merge :: forall l. Language l => ClassId -> ClassId -> EGraph l -> (ClassId, EGraph l)
merge a b egr0
    | a' == b'  = (a', egr0)
    | otherwise = (new_id, new_egr)
  where
    a' = find a egr0
    b' = find b egr0

    class_a = egr0 ^. _class a'
    class_b = egr0 ^. _class b'

    -- The class with fewer parents is absorbed, so fewer parent entries are
    -- pushed onto the worklist below.
    (leader, leader_class, sub, sub_class) =
        if sizeNM (class_a ^. _parents) < sizeNM (class_b ^. _parents)
            then (b', class_b, a', class_a)
            else (a', class_a, b', class_b)

    -- NOTE(review): 'classes' is re-keyed under 'leader' while the union-find
    -- reports 'new_id'; this assumes 'unionSets' makes its first argument the
    -- representative (new_id == leader). Confirm in ReprUnionFind.
    (new_id, new_uf) = unionSets leader sub (unionFind egr0)

    new_data = joinA @l (leader_class ^. _data) (sub_class ^. _data)

    updatedLeader = leader_class
        & _parents %~ (<> (sub_class ^. _parents))
        & _nodes   %~ (<> (sub_class ^. _nodes))
        & _data    .~ new_data

    new_classes = IM.insert leader updatedLeader (IM.delete sub (classes egr0))

    -- Hashcons entries of the absorbed class's parents may now be stale.
    new_worklist = (sub_class ^. _parents) <> worklist egr0

    -- Parents of a class whose analysis data changed need re-analysis.
    parentsIfChanged cls =
        if new_data /= (cls ^. _data)
            then cls ^. _parents
            else mempty

    new_analysis_worklist =
        parentsIfChanged leader_class
            <> parentsIfChanged sub_class
            <> analysisWorklist egr0

    new_egr = egr0 { unionFind        = new_uf
                   , classes          = new_classes
                   , worklist         = new_worklist
                   , analysisWorklist = new_analysis_worklist
                   }
              & modifyA new_id
{-# SCC merge #-}
-- | Restore the e-graph invariants after calls to 'add' / 'merge'.
--
-- Each round:
--
-- 1. empties both worklists;
-- 2. re-canonicalizes every queued hashcons entry with 'repair', which may
--    merge classes and queue more work;
-- 3. re-runs the analysis for every queued node with 'repairAnal'.
--
-- Rounds repeat until one ends with both worklists empty.
rebuild :: Language l => EGraph l -> EGraph l
rebuild egr
    | null (worklist egr'') && null (analysisWorklist egr'') = egr''
    | otherwise                                               = rebuild egr''
  where
    -- Detach the pending work; the repairs below may enqueue fresh entries.
    -- A record update (rather than a positional constructor match) keeps
    -- this independent of the field order of 'EGraph'.
    cleared = egr { worklist = mempty, analysisWorklist = mempty }
    egr'    = foldrWithKeyNM' repair cleared (worklist egr)
    egr''   = foldrWithKeyNM' repairAnal egr' (analysisWorklist egr)
{-# SCC rebuild #-}
-- | Re-canonicalize one worklist entry in the hashcons.
--
-- The stale key @node@ is removed from 'memo'. The canonicalized node is then
-- inserted, mapped to the canonical id of @repair_id@. If 'insertLookupNM'
-- reports an existing entry for the canonical node, that class is merged with
-- @repair_id@.
repair :: forall l. Language l => ENode l -> ClassId -> EGraph l -> EGraph l
repair node repair_id egr =
    case insertLookupNM canon_node canon_id memo_without_node of
        (Nothing, memo2)             -> egr { memo = memo2 }
        (Just existing_class, memo2) ->
            snd (merge existing_class repair_id (egr { memo = memo2 }))
  where
    canon_node        = node `canonicalize` egr
    canon_id          = find repair_id egr
    memo_without_node = deleteNM node (memo egr)
{-# SCC repair #-}
-- | Re-run the analysis for one analysis-worklist entry.
--
-- The node's 'makeA' result is joined into the data of its canonical class.
-- If this changes the data:
--
-- * the class data is updated;
-- * the class's parents are queued on the analysis worklist;
-- * 'modifyA' is run on the class.
--
-- Otherwise the graph is returned unchanged.
repairAnal :: forall l. Language l => ENode l -> ClassId -> EGraph l -> EGraph l
repairAnal node repair_id egr
    | old_data /= new_data =
        egr { analysisWorklist = (c ^. _parents) <> analysisWorklist egr }
            & _class canon_id . _data .~ new_data
            & modifyA canon_id
    | otherwise = egr
  where
    canon_id = find repair_id egr
    c        = egr ^. _class canon_id
    old_data = c ^. _data
    new_data = joinA @l old_data (makeA node egr)
{-# SCC repairAnal #-}
-- | Replace every child e-class id of an e-node with its current
-- representative, as given by 'find'.
canonicalize :: Functor l => ENode l -> EGraph l -> ENode l
canonicalize (Node enode) eg = Node (fmap (`find` eg) enode)
{-# SCC canonicalize #-}
-- | Resolve an e-class id to its current representative in the e-graph's
-- union-find.
--
-- Kept with syntactic arity 1 so the INLINE pragma fires on the partial
-- applications used elsewhere.
find :: ClassId -> EGraph l -> ClassId
find cid = findRepr cid . unionFind
{-# INLINE find #-}
-- | An e-graph with an empty union-find ('emptyUF'), no e-classes, and empty
-- hashcons and worklists.
emptyEGraph :: Language l => EGraph l
emptyEGraph = EGraph
    { unionFind        = emptyUF
    , classes          = mempty
    , memo             = mempty
    , worklist         = mempty
    , analysisWorklist = mempty
    }
{-# INLINE emptyEGraph #-}