{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE BlockArguments #-}
module Data.Equality.Saturation
(
equalitySaturation, equalitySaturation'
, Rewrite(..), RewriteCondition
, CostFunction
, Fix(..), cata
) where
import qualified Data.IntMap.Strict as IM
import Data.Bifunctor
import Control.Monad
import Data.Proxy
import Data.Equality.Utils
import qualified Data.Equality.Graph as G
import Data.Equality.Graph.Monad
import Data.Equality.Language
import Data.Equality.Graph.Classes
import Data.Equality.Matching
import Data.Equality.Matching.Database
import Data.Equality.Extraction
import Data.Equality.Saturation.Rewrites
import Data.Equality.Saturation.Scheduler
-- | Run equality saturation on an expression using the default
-- 'BackoffScheduler', then extract the cheapest equivalent expression.
--
-- This is @equalitySaturation'@ with the scheduler fixed to
-- @'Proxy' \@'BackoffScheduler'@. Returns the extracted expression together
-- with the final e-graph.
equalitySaturation :: forall l. Language l
                   => Fix l            -- ^ Expression to optimise
                   -> [Rewrite l]      -- ^ Rewrite rules to saturate with
                   -> CostFunction l   -- ^ Cost function used for extraction
                   -> (Fix l, EGraph l)
equalitySaturation = equalitySaturation' (Proxy @BackoffScheduler)
-- | Run equality saturation with an explicitly chosen scheduler, then extract
-- the cheapest expression equivalent to the input.
--
-- The input expression is added to a fresh e-graph. Each iteration then does
-- three things:
--
-- 1. match every rewrite rule the scheduler allows against a database built
--    from the current e-graph;
-- 2. apply every match;
-- 3. rebuild the e-graph.
--
-- Iteration stops once an iteration leaves both the e-node memo size and the
-- number of e-classes unchanged, or after @iterationLimit@ iterations.
-- Finally the best expression in the input's e-class is extracted with the
-- given cost function.
--
-- The scheduler is selected purely by the type of the 'Proxy' argument.
equalitySaturation' :: forall l schd
                     . (Language l, Scheduler schd)
                    => Proxy schd       -- ^ Scheduler selection (value ignored)
                    -> Fix l            -- ^ Expression to optimise
                    -> [Rewrite l]      -- ^ Rewrite rules to saturate with
                    -> CostFunction l   -- ^ Cost function used for extraction
                    -> (Fix l, EGraph l)
equalitySaturation' _ expr rewrites cost = egraph $ do

    -- Insert the input and remember its e-class for the final extraction.
    origClass <- represent expr

    equalitySaturation'' 0 mempty

    gets $ \g -> extractBest g cost origClass

  where
    -- Hard upper bound on the number of saturation iterations.
    iterationLimit :: Int
    iterationLimit = 30

    -- Run saturation iterations starting at iteration @i@. @stats@ maps each
    -- rule id (its 1-based position in @rewrites@) to the scheduler's
    -- bookkeeping for that rule.
    equalitySaturation'' :: Int -> IM.IntMap (Stat schd) -> EGraphM l ()
    equalitySaturation'' i stats
      | i >= iterationLimit = return ()
      | otherwise = do
          egr@G.EGraph{ G.memo = beforeMemo, G.classes = beforeClasses } <- get

          let db = eGraphToDatabase egr

          -- Read-only phase: match all rules against the same database.
          let (!matches, newStats) = matchAll db i stats (zip [1..] rewrites)

          -- Write phase: apply every match, then restore invariants once.
          forM_ matches applyMatchesRhs

          rebuild

          G.EGraph{ G.memo = afterMemo, G.classes = afterClasses } <- get

          -- Stop when neither the memo nor the class map changed size.
          -- NOTE(review): this only compares sizes, and rules the scheduler
          -- skipped this iteration are not retried before stopping; confirm
          -- this is the intended saturation criterion.
          unless (G.sizeNM afterMemo == G.sizeNM beforeMemo
                    && IM.size afterClasses == IM.size beforeClasses)
              (equalitySaturation'' (i + 1) newStats)

    -- Match the rules in order and thread the scheduler statistics from one
    -- rule to the next. Each rule's result map replaces the accumulator, so
    -- every rule's update is kept. The previous @mconcat@ used IntMap's
    -- left-biased union, which let an earlier rule's copy of the old map
    -- shadow later rules' updates. Matches are returned in rule order.
    matchAll :: Database l
             -> Int
             -> IM.IntMap (Stat schd)
             -> [(Int, Rewrite l)]
             -> ([(Rewrite l, Match)], IM.IntMap (Stat schd))
    matchAll _  _ st [] = ([], st)
    matchAll db i st (r:rs) =
      case matchWithScheduler db i st r of
        (ms, st') -> first (ms ++) (matchAll db i st' rs)

    -- Match a single rule unless the scheduler has banned it at iteration
    -- @i@. Returns the matches, paired with the rule to apply, and the
    -- updated statistics.
    matchWithScheduler :: Database l
                       -> Int
                       -> IM.IntMap (Stat schd)
                       -> (Int, Rewrite l)
                       -> ([(Rewrite l, Match)], IM.IntMap (Stat schd))
    matchWithScheduler db i stats = \case
      (rw_id, rw :| cnd) ->
        -- Match the underlying rule, then reattach the condition so it is
        -- checked when the match is applied.
        first (map (first (:| cnd))) $ matchWithScheduler db i stats (rw_id, rw)
      (rw_id, lhs := rhs) ->
        case IM.lookup rw_id stats of
          -- Banned for this iteration: no matches, stats unchanged.
          Just s | isBanned @schd i s -> ([], stats)
          current ->
            let matches' = ematch db lhs
                newStats = updateStats @schd i rw_id current stats matches'
             in (map (lhs := rhs,) matches', newStats)

    -- Apply the right-hand side of a rule for one match.
    applyMatchesRhs :: (Rewrite l, Match) -> EGraphM l ()
    applyMatchesRhs = \case
      -- Conditional rule: apply only if the condition holds for this
      -- substitution in the e-graph as it is at this point.
      (rw :| cond, m@(Match subst _)) -> do
        egr <- get
        when (cond subst egr) $
          applyMatchesRhs (rw, m)
      -- RHS is a bare variable: merge its bound class with the matched class.
      (_ := VariablePattern v, Match subst eclass) ->
        case IM.lookup v subst of
          Nothing -> error "impossible: couldn't find v in subst"
          Just n  -> void (merge n eclass)
      -- RHS is a compound pattern: add its instantiation, then merge.
      (_ := NonVariablePattern rhs, Match subst eclass) -> do
        eclass' <- reprPat subst rhs
        void (merge eclass eclass')

    -- Add a pattern node, instantiated under @subst@, to the e-graph and
    -- return its e-class. Variables resolve to the classes bound in @subst@;
    -- sub-patterns are added recursively.
    reprPat :: Subst -> l (Pattern l) -> EGraphM l ClassId
    reprPat subst = add . G.Node <=< traverse \case
      VariablePattern v ->
        case IM.lookup v subst of
          Nothing -> error "impossible: couldn't find v in subst?"
          Just i  -> return i
      NonVariablePattern p -> reprPat subst p
{-# SCC equalitySaturation' #-}