{-# LANGUAGE TupleSections #-}
module Data.Equality.Graph.Monad
( egraph
, represent
, add
, merge
, rebuild
, EG.canonicalize
, EG.find
, EG.emptyEGraph
, EGraphM
, runEGraphM
, EG.EGraph
, modify, get, gets
) where
import Control.Monad ((>=>))
import Control.Monad.Trans.State.Strict
import Data.Equality.Utils (Fix, cata)
import Data.Equality.Graph (EGraph, ClassId, Language, ENode(..))
import qualified Data.Equality.Graph as EG
-- | A stateful computation whose state is an 'EGraph' over the language @s@.
--
-- This is a plain alias for 'State', so the usual state operations
-- ('get', 'gets', 'modify') work on it directly.
type EGraphM s = State (EGraph s)
-- | Run an e-graph computation starting from 'EG.emptyEGraph'.
--
-- Returns the computation's result paired with the final e-graph.
egraph :: Language l => EGraphM l a -> (a, EGraph l)
egraph = runEGraphM EG.emptyEGraph
{-# INLINE egraph #-}
-- | Represent an expression, given as the fixed point of the language
-- functor, in the e-graph.
--
-- The expression is folded bottom-up with 'cata': at each layer the
-- computations for the children are run in order ('sequence') to obtain
-- their class ids, and the resulting layer of class ids is wrapped in
-- 'Node' and inserted with 'add'. The result is the class id returned for
-- the outermost node.
represent :: Language l => Fix l -> EGraphM l ClassId
represent = cata (sequence >=> add . Node)
{-# INLINE represent #-}
-- | Add an e-node to the e-graph held in the state.
--
-- Lifts the pure 'EG.add' into 'EGraphM': the class id it returns becomes
-- the result of the computation and the e-graph it returns becomes the new
-- state.
add :: Language l => ENode l -> EGraphM l ClassId
add = state . EG.add
{-# INLINE add #-}
-- | Merge two e-classes of the e-graph held in the state.
--
-- Lifts the pure 'EG.merge' into 'EGraphM': the class id it returns
-- (presumably the id of the merged class -- confirm against 'EG.merge')
-- becomes the result and the e-graph it returns becomes the new state.
merge :: Language l => ClassId -> ClassId -> EGraphM l ClassId
merge a b = state (EG.merge a b)
{-# INLINE merge #-}
-- | Replace the e-graph held in the state with the result of applying
-- 'EG.rebuild' to it.
--
-- NOTE(review): presumably this restores the e-graph's invariants after a
-- series of merges -- confirm against the documentation of 'EG.rebuild'.
rebuild :: Language l => EGraphM l ()
rebuild = modify EG.rebuild
{-# INLINE rebuild #-}
-- | Run an e-graph computation on the given initial e-graph.
--
-- This is 'runState' with its arguments flipped, so that the initial
-- e-graph comes first. Returns the computation's result paired with the
-- final e-graph.
runEGraphM :: EGraph l -> EGraphM l a -> (a, EGraph l)
runEGraphM initial action = runState action initial
{-# INLINE runEGraphM #-}