{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE BlockArguments #-}
module Data.Equality.Saturation
(
equalitySaturation, equalitySaturation', runEqualitySaturation
, Rewrite(..), RewriteCondition
, CostFunction
, Fix(..), cata
) where
import qualified Data.IntMap.Strict as IM
import Data.Bifunctor
import Control.Monad
import Data.Equality.Utils
import Data.Equality.Graph.Nodes
import Data.Equality.Graph.Lens
import qualified Data.Equality.Graph as G
import Data.Equality.Graph.Monad
import Data.Equality.Language
import Data.Equality.Analysis
import Data.Equality.Graph.Classes
import Data.Equality.Matching
import Data.Equality.Matching.Database
import Data.Equality.Extraction
import Data.Equality.Saturation.Rewrites
import Data.Equality.Saturation.Scheduler
-- | Equality saturation using the default scheduler.
--
-- Equivalent to calling 'equalitySaturation'' with 'defaultBackoffScheduler'
-- as the scheduler argument.
equalitySaturation :: forall a l cost
                    . (Analysis a l, Language l, Ord cost)
                   => Fix l               -- ^ Expression to run equality saturation on
                   -> [Rewrite a l]       -- ^ Rewrite rules to apply
                   -> CostFunction l cost -- ^ Cost function used to extract the best expression
                   -> (Fix l, EGraph a l) -- ^ Best equivalent expression and the resulting e-graph
equalitySaturation expr rewrites cost =
    equalitySaturation' defaultBackoffScheduler expr rewrites cost
-- | Run equality saturation on an expression, then extract the cheapest
-- equivalent expression according to the given cost function.
--
-- Every parameter, including the scheduler, is supplied by the caller. The
-- e-graph is built from scratch by 'egraph', so the returned e-graph contains
-- only the input expression and everything derived from it by the rewrites.
equalitySaturation' :: forall a l schd cost
                     . (Analysis a l, Language l, Scheduler schd, Ord cost)
                    => schd                -- ^ Scheduler to use
                    -> Fix l               -- ^ Expression to run equality saturation on
                    -> [Rewrite a l]       -- ^ Rewrite rules to apply
                    -> CostFunction l cost -- ^ Cost function used to extract the best expression
                    -> (Fix l, EGraph a l) -- ^ Best equivalent expression and the resulting e-graph
equalitySaturation' schd expr rewrites cost = egraph saturateThenExtract
  where
    saturateThenExtract :: EGraphM a l (Fix l)
    saturateThenExtract = do
        -- Remember the e-class of the input: extraction must start from it.
        rootClass <- represent expr
        runEqualitySaturation schd rewrites
        gets (\g -> extractBest g cost rootClass)
{-# INLINABLE equalitySaturation' #-}
-- | Run equality saturation on the e-graph held in the monadic state by
-- non-destructively applying the given rewrite rules.
--
-- Iterations repeat until one iteration leaves both the hashcons (memo) size
-- and the number of e-classes unchanged, or until 30 iterations have run.
--
-- Each rewrite rule is identified by its 1-based position in the list. The
-- scheduler keeps per-rule statistics under that key, and a rule it reports
-- as banned is not matched during that iteration.
--
-- NOTE(review): an iteration in which every productive rule was banned also
-- leaves the sizes unchanged and therefore stops the loop early. Confirm this
-- is intended.
runEqualitySaturation :: forall a l schd
                       . (Analysis a l, Language l, Scheduler schd)
                      => schd          -- ^ Scheduler to use
                      -> [Rewrite a l] -- ^ Rewrite rules to apply
                      -> EGraphM a l ()
runEqualitySaturation schd rewrites = runEqualitySaturation' 0 mempty where

    -- Rewrite rules paired with their scheduler ids (1-based list positions).
    numberedRewrites :: [(Int, Rewrite a l)]
    numberedRewrites = zip [1..] rewrites

    -- One iteration: match every rule (read-only), apply all matches
    -- (write-only), rebuild, and recurse unless the e-graph stopped growing.
    runEqualitySaturation' :: Int -> IM.IntMap (Stat schd) -> EGraphM a l ()
    runEqualitySaturation' 30 _ = return () -- Hard iteration limit
    runEqualitySaturation' i stats = do

        egr <- get

        let (beforeMemo, beforeClasses) = (egr^._memo, egr^._classes)
            db = eGraphToDatabase egr

            -- Read-only phase: all rules are matched against the same snapshot
            -- and judged against the previous iteration's stats.
            results = [ (rw_id, matchWithScheduler db i stats r)
                      | r@(rw_id, _) <- numberedRewrites ]
            matches = concatMap (fst . snd) results

            -- Each rule updates only its own entry. Merging whole per-rule maps
            -- with a left-biased union would let the first rule's map (which
            -- holds every old entry) shadow the other rules' updates.
            newStats = foldr (\(rw_id, (_, st)) -> IM.alter (const st) rw_id)
                             stats results

        -- Write-only phase: invariants are temporarily broken
        forM_ matches applyMatchesRhs

        -- Restore the invariants once per iteration
        rebuild

        (afterMemo, afterClasses) <- gets (\g -> (g^._memo, g^._classes))

        -- Saturated when neither the hashcons nor the class map changed size
        unless (G.sizeNM afterMemo == G.sizeNM beforeMemo
                  && IM.size afterClasses == IM.size beforeClasses)
            (runEqualitySaturation' (i+1) newStats)

    -- Match a single (numbered) rule.
    -- Returns the matches paired with the unconditional rewrite to apply,
    -- and the rule's own stat after this iteration.
    -- Conditional rules are matched on their inner rule, and the condition is
    -- re-attached to every match.
    matchWithScheduler :: Database l
                       -> Int
                       -> IM.IntMap (Stat schd)
                       -> (Int, Rewrite a l)
                       -> ([(Rewrite a l, Match)], Maybe (Stat schd))
    matchWithScheduler db i stats = \case
        (rw_id, rw :| cnd) ->
            first (map (first (:| cnd))) $ matchWithScheduler db i stats (rw_id, rw)
        (rw_id, lhs := rhs) ->
            case IM.lookup rw_id stats of
              -- Banned: don't match this rule; its stat stays as it was.
              Just s | isBanned @schd i s -> ([], Just s)

              -- Otherwise match the lhs and let the scheduler update the stat.
              x ->
                  let matches' = ematch db lhs
                      updated  = updateStats schd i rw_id x stats matches'
                   in (map (lhs := rhs,) matches', IM.lookup rw_id updated)

    -- Apply the right-hand side of a rewrite to one match.
    applyMatchesRhs :: (Rewrite a l, Match) -> EGraphM a l ()
    applyMatchesRhs = \case
        -- Conditional rewrite: apply the inner rewrite only when the condition
        -- holds for this substitution in the current e-graph.
        (rw :| cond, m@(Match subst _)) -> do
            egr <- get
            when (cond subst egr) $
                applyMatchesRhs (rw, m)

        -- rhs is a bare variable: merge the class it is bound to with the
        -- class where the lhs matched.
        (_ := VariablePattern v, Match subst eclass) ->
            case IM.lookup v subst of
              Nothing -> error "impossible: couldn't find v in subst"
              Just n  -> void (merge n eclass)

        -- rhs is a non-variable pattern: instantiate it with the substitution,
        -- add it to the e-graph, and merge it with the matched class.
        (_ := NonVariablePattern rhs, Match subst eclass) -> do
            eclass' <- reprPat subst rhs
            void (merge eclass eclass')

    -- Add a pattern to the e-graph, replacing pattern variables with the
    -- e-classes they are bound to in the substitution.
    reprPat :: Subst -> l (Pattern l) -> EGraphM a l ClassId
    reprPat subst = add . Node <=< traverse \case
        VariablePattern v ->
            case IM.lookup v subst of
              Nothing -> error "impossible: couldn't find v in subst?"
              Just i  -> return i
        NonVariablePattern p -> reprPat subst p
{-# INLINEABLE runEqualitySaturation #-}