{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TupleSections #-}
-- {-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE UndecidableInstances #-} -- tmp show
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-|
   An e-graph efficiently represents a congruence relation over many expressions.

   Based on \"egg: Fast and Extensible Equality Saturation\" https://arxiv.org/abs/2004.03082.
 -}
module Data.Equality.Graph
    (
      -- * Definition of e-graph
      EGraph(..)

    , Memo, Worklist

      -- * Functions on e-graphs
    , emptyEGraph

      -- ** Transformations
    , add, merge, rebuild
    -- , repair, repairAnal

      -- ** Querying
    , find, canonicalize

      -- * Re-exports
    , module Data.Equality.Graph.Classes
    , module Data.Equality.Graph.Nodes
    , module Data.Equality.Language
    ) where

-- import GHC.Conc

import Data.Function

import Data.Functor.Classes

import qualified Data.IntMap.Strict as IM
import qualified Data.Set    as S

import Data.Equality.Graph.ReprUnionFind
import Data.Equality.Graph.Classes
import Data.Equality.Graph.Nodes
import Data.Equality.Analysis
import Data.Equality.Language
import Data.Equality.Graph.Lens

-- | E-graph representing terms of language @l@.
--
-- Intuitively, an e-graph is a set of equivalence classes (e-classes). Each e-class is a
-- set of e-nodes representing equivalent terms from a given language, and an e-node is a function
-- symbol paired with a list of children e-classes.
--
-- All fields are strict.
data EGraph l = EGraph
    { unionFind        :: !ReprUnionFind           -- ^ Union-find-like structure used to find the canonical representative of an e-class id
    , classes          :: !(ClassIdMap (EClass l)) -- ^ Map from canonical e-class ids to their e-classes
    , memo             :: !(Memo l)                -- ^ Hashcons mapping canonical e-nodes to their e-class ids
    , worklist         :: !(Worklist l)            -- ^ Pending (e-node, e-class id) entries that need to be upward merged by 'rebuild'
    , analysisWorklist :: !(Worklist l)            -- ^ Like 'worklist', but for entries whose analysis data needs repairing
    }

-- | The hashcons 𝐻: a map from e-nodes to their e-class ids
type Memo l = NodeMap l ClassId

-- | Worklist of pending entries, each an e-node paired with its e-class id,
-- whose e-classes need to be “upward merged”
type Worklist l = NodeMap l ClassId

-- ROMES:TODO: join e-graphs built in parallel?
-- instance Ord1 l => Semigroup (EGraph l) where
--     (<>) eg1 eg2 = undefined -- not so easy
-- instance Ord1 l => Monoid (EGraph l) where
--     mempty = EGraph emptyUF mempty mempty mempty

-- | Debug rendering of every component of the e-graph, one labelled section
-- per field, separated by blank lines.
instance (Show (Domain l), Show1 l) => Show (EGraph l) where
    show (EGraph uf cls hashcons wl awl) = concat
        [ "UnionFind: ",        show uf
        , "\n\nE-Classes: ",    show cls
        , "\n\nHashcons: ",     show hashcons
        , "\n\nWorklist: ",     show wl
        , "\n\nAnalWorklist: ", show awl
        ]


-- | Add an e-node to the e-graph
--
-- The e-node is canonicalized first. If it is already represented in this
-- e-graph, the (canonical) class-id of the class it's already represented in
-- is returned together with the unchanged e-graph.
--
-- Otherwise a fresh singleton e-class is created for it. The node is then
-- recorded in the hashcons, in the parents of each of its children's
-- e-classes, and on the worklist, and 'modifyA' is applied to the new e-class.
add :: forall l. Language l => ENode l -> EGraph l -> (ClassId, EGraph l)
add uncanon_e egr =
    let !new_en = {-# SCC "-2" #-} canonicalize uncanon_e egr
     in case {-# SCC "-1" #-} lookupNM new_en (memo egr) of
          Just canon_enode_id ->
              -- Already represented: nothing to add
              {-# SCC "0" #-} (find canon_enode_id egr, egr)
          Nothing ->
              let
                  -- Make new equivalence class with a new id in the union-find
                  (new_eclass_id, new_uf) = makeNewSet (unionFind egr)

                  -- New singleton e-class stores the e-node and its analysis data
                  new_eclass = EClass new_eclass_id (S.singleton new_en) (makeA new_en egr) mempty

                  -- TODO:Performance: All updates can be done to the map first? Parallelize?
                  --
                  -- Register the new e-node (with its e-class id) as a parent of
                  -- every child e-class, then add the new e-class itself
                  new_parents = insertNM new_en new_eclass_id
                  new_classes = {-# SCC "2" #-}
                      IM.insert new_eclass_id new_eclass $
                          foldr (IM.adjust (_parents %~ new_parents))
                                (classes egr)
                                (unNode new_en)

                  -- Also put the new e-node on the worklist. egg asks whether
                  -- this is needed; here it is required for correct results:
                  -- without it, analyses that prune classes through 'modifyA'
                  -- (e.g. math pruning) misbehave. Neither having 'modifyA'
                  -- enqueue the pruned class's parents nor using the analysis
                  -- worklist instead fixed every case. The cost is extra work
                  -- (the hashcons invariant tests slow down noticeably, while
                  -- the saturation tests seem mostly fine). In exchange, users
                  -- never have to manage the worklist themselves.
                  new_worklist = {-# SCC "4" #-} insertNM new_en new_eclass_id (worklist egr)

                  -- Map the e-node to its new e-class id in the hashcons
                  new_memo = {-# SCC "5" #-} insertNM new_en new_eclass_id (memo egr)

                  new_egr = egr { unionFind = new_uf
                                , classes   = new_classes
                                , worklist  = new_worklist
                                , memo      = new_memo
                                }
               in ( new_eclass_id
                    -- Modify created node according to analysis
                  , new_egr & {-# SCC "6" #-} modifyA new_eclass_id
                  )
{-# SCC add #-}

-- | Merge the e-classes of two e-class ids.
--
-- Returns the canonical id of the merged e-class and the updated e-graph. If
-- both ids already share a representative, the e-graph is returned unchanged.
--
-- Congruence is not restored here. The parents of the absorbed e-class are
-- put on the 'worklist', and the parents of any e-class whose analysis data
-- changed are put on the 'analysisWorklist'; both are processed by 'rebuild'.
merge :: forall l. Language l => ClassId -> ClassId -> EGraph l -> (ClassId, EGraph l)
merge a b egr0
  | a' == b'  = (a', egr0)
  | otherwise = (new_id, new_egr)
  where
    -- Use canonical ids
    a' = find a egr0
    b' = find b egr0

    -- Get classes being merged
    class_a = egr0 ^. _class a'
    class_b = egr0 ^. _class b'

    -- Leader is the class with more parents (a' wins ties), so the smaller
    -- parent set is the one pushed onto the worklist
    (leader, leader_class, sub, sub_class) =
      if sizeNM (class_a ^. _parents) < sizeNM (class_b ^. _parents)
        then (b', class_b, a', class_a) -- b is leader
        else (a', class_a, b', class_b) -- a is leader

    -- Make leader the leader in the union find
    -- NOTE(review): the updated class is stored under 'leader' while
    -- 'modifyA' is applied to 'new_id'; this assumes 'unionSets' returns
    -- 'leader' — confirm against ReprUnionFind.
    (new_id, new_uf) = unionSets leader sub (unionFind egr0)

    -- Joined analysis data of both classes
    new_data = joinA @l (leader_class ^. _data) (sub_class ^. _data)

    -- Leader absorbs all e-nodes and parents of the subsumed class and takes
    -- the joined analysis data
    updatedLeader = leader_class & _parents %~ (<> sub_class ^. _parents)
                                 & _nodes   %~ (<> sub_class ^. _nodes)
                                 & _data    .~ new_data

    -- Store the updated leader and drop the subsumed class
    new_classes = IM.insert leader updatedLeader (IM.delete sub (classes egr0))

    -- Add all subsumed parents to the worklist. We can do this instead of
    -- adding the new e-class itself to the worklist because that would end
    -- up adding its parents anyway.
    new_worklist = sub_class ^. _parents <> worklist egr0

    -- The parents of a class whose data differs from the joined data must be
    -- put on the analysisWorklist
    parentsIfChanged cls =
      if new_data /= (cls ^. _data)
        then cls ^. _parents
        else mempty
    new_analysis_worklist =
      parentsIfChanged leader_class <> parentsIfChanged sub_class <> analysisWorklist egr0

    -- ROMES:TODO: The code that makes the -1 * cos test pass when some other things are tweaked
    -- new_memo = foldr (`insertNM` leader) (memo egr0) (sub_class^._nodes)

    -- Build new e-graph, then modify it according to the analysis
    new_egr = egr0 { unionFind        = new_uf
                   , classes          = new_classes
                   -- , memo          = new_memo
                   , worklist         = new_worklist
                   , analysisWorklist = new_analysis_worklist
                   }
              & modifyA new_id
{-# SCC merge #-}
            

-- | The rebuild operation processes the e-graph's current 'Worklist',
-- restoring the invariants of deduplication and congruence. Rebuilding is
-- similar to other approaches in how it restores congruence; but it uniquely
-- allows the client to choose when to restore invariants in the context of a
-- larger algorithm like equality saturation.
--
-- Each round clears both worklists. It then applies 'repair' to every pending
-- worklist entry, and 'repairAnal' to every pending analysis-worklist entry.
-- Repairs may enqueue new entries, so rounds repeat until both worklists are
-- empty.
rebuild :: Language l => EGraph l -> EGraph l
rebuild egr0 =
  let
    -- Empty worklists before repairing, so only newly enqueued work remains
    cleared  = egr0 { worklist = mempty, analysisWorklist = mempty }
    -- Repair deduplicated e-classes
    repaired = foldrWithKeyNM' repair cleared (worklist egr0)
    -- Repair analysis data
    egr'     = foldrWithKeyNM' repairAnal repaired (analysisWorklist egr0)
  in
  -- Loop until both worklists are completely empty
  if null (worklist egr') && null (analysisWorklist egr')
     then egr'
     else rebuild egr'

{-# SCC rebuild #-}

-- ROMES:TODO: find repair_id could be shared between repair and repairAnal?

-- | Repair a single worklist entry.
--
-- Removes the e-node's old (possibly uncanonical) key from the hashcons and
-- re-inserts its canonical form, mapped to the canonical id of @repair_id@.
-- If the canonical form was already present, a congruence was found, and the
-- existing class is merged with @repair_id@.
-- NOTE(review): relies on 'insertLookupNM' returning the previous entry (if
-- any) together with the updated map — confirm in Data.Equality.Graph.Nodes.
repair :: forall l. Language l => ENode l -> ClassId -> EGraph l -> EGraph l
repair node repair_id egr =
   -- TODO: I seem to really need it. Is find needed? (they don't use it)
   case insertLookupNM (canonicalize node egr) (find repair_id egr) (deleteNM node (memo egr)) of
      -- Return new memo but delete uncanonicalized node
      (Nothing, memo2) -> egr { memo = memo2 }
      -- Canonical node already known: the two classes are congruent
      (Just existing_class, memo2) -> snd (merge existing_class repair_id (egr { memo = memo2 }))
{-# SCC repair #-}

-- | Repair a single analysis-worklist entry.
--
-- Joins the analysis data computed for the e-node ('makeA') into the data of
-- its canonical e-class. If the joined data differs from the existing data:
-- the e-class's parents are added to the 'analysisWorklist', the e-class's
-- data is replaced, and 'modifyA' is applied to the e-class. Otherwise the
-- e-graph is returned unchanged.
repairAnal :: forall l. Language l => ENode l -> ClassId -> EGraph l -> EGraph l
repairAnal node repair_id egr
  -- Merge result differs from the original class data: update the class
  | old_data /= new_data =
      egr { analysisWorklist = c ^. _parents <> analysisWorklist egr }
        & _class canon_id . _data .~ new_data
        & modifyA canon_id
  | otherwise = egr
  where
    canon_id = find repair_id egr
    c        = egr ^. _class canon_id
    old_data = c ^. _data
    new_data = joinA @l old_data (makeA node egr)
{-# SCC repairAnal #-}

-- | Canonicalize an e-node
--
-- Two e-nodes are equal when their canonical form is equal. Canonicalization
-- replaces every e-class id held by the e-node with its canonical id.
-- So two seemingly different e-nodes are equal once their e-class ids
-- turn out to have the same canonical representatives.
--
-- canonicalize(𝑓(𝑎,𝑏,𝑐,...)) = 𝑓((find 𝑎), (find 𝑏), (find 𝑐),...)
canonicalize :: Functor l => ENode l -> EGraph l -> ENode l
canonicalize (Node enode) eg = Node (fmap (\cid -> find cid eg) enode)
{-# SCC canonicalize #-}

-- | Find the canonical representation of an e-class id in the e-graph, by
-- looking it up in the e-graph's union-find.
--
-- Invariant: The e-class id always exists.
find :: ClassId -> EGraph l -> ClassId
find cid = findRepr cid . unionFind
{-# INLINE find #-}

-- | The empty e-graph. Nothing is represented in it yet: an empty union-find,
-- no e-classes, an empty hashcons and empty worklists.
emptyEGraph :: Language l => EGraph l
emptyEGraph = EGraph
    { unionFind        = emptyUF
    , classes          = mempty
    , memo             = mempty
    , worklist         = mempty
    , analysisWorklist = mempty
    }
{-# INLINE emptyEGraph #-}