{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE BlockArguments #-}
{-|
  Given an input program 𝑝, equality saturation constructs an e-graph 𝐸 that
  represents a large set of programs equivalent to 𝑝, and then extracts the
  “best” program from 𝐸.

  The e-graph is grown by repeatedly applying pattern-based rewrites.
  Critically, these rewrites only add information to the e-graph, eliminating
  the need for careful ordering.

  Upon reaching a fixed point (saturation), 𝐸 will represent all equivalent
  ways to express 𝑝 with respect to the given rewrites.

  After saturation (or timeout), a final extraction procedure analyzes 𝐸 and
  selects the optimal program according to a user-provided cost function.
 -}
module Data.Equality.Saturation
    (
      -- * Equality saturation
      equalitySaturation, equalitySaturation', runEqualitySaturation

      -- * Re-exports for equality saturation

      -- ** Writing rewrite rules
    , Rewrite(..), RewriteCondition

      -- ** Writing cost functions
      --
      -- | 'CostFunction' is re-exported from "Data.Equality.Extraction" since cost functions are required to run equality saturation.
    , CostFunction --, depthCost

      -- ** Writing expressions
      --
      -- | Expressions must be written in their fixed-point form ('Fix'), since
      -- the 'Language' must be given in its base functor form.
    , Fix(..), cata

    ) where

import qualified Data.IntMap.Strict as IM

import Data.Bifunctor
import Control.Monad

import Data.Proxy

import Data.Equality.Utils
import Data.Equality.Graph.Nodes
import Data.Equality.Graph.Lens
import qualified Data.Equality.Graph as G
import Data.Equality.Graph.Monad
import Data.Equality.Language
import Data.Equality.Graph.Classes
import Data.Equality.Matching
import Data.Equality.Matching.Database
import Data.Equality.Extraction

import Data.Equality.Saturation.Rewrites
import Data.Equality.Saturation.Scheduler

-- | Run equality saturation using the default scheduler ('BackoffScheduler').
--
-- This is 'equalitySaturation'' with the scheduler fixed to
-- 'BackoffScheduler'; the remaining arguments are passed through unchanged.
equalitySaturation :: forall l cost
                    . (Language l, Ord cost)
                   => Fix l               -- ^ Expression to run equality saturation on
                   -> [Rewrite l]         -- ^ List of rewrite rules
                   -> CostFunction l cost -- ^ Cost function to extract the best equivalent representation
                   -> (Fix l, EGraph l)   -- ^ Best equivalent expression and resulting e-graph
equalitySaturation expr rewrites cost =
    equalitySaturation' (Proxy :: Proxy BackoffScheduler) expr rewrites cost


-- | Run equality saturation on an expression given a list of rewrites, and
-- extract the best equivalent expression according to the given cost function.
--
-- Unlike 'equalitySaturation', this variant takes every argument explicitly,
-- including the scheduler, instead of using defaults.
--
-- Steps: the expression is represented in a fresh e-graph, the rewrites are
-- applied with 'runEqualitySaturation', and 'extractBest' is then asked for
-- the best expression in the e-class that was returned when the original
-- expression was represented.
equalitySaturation' :: forall l schd cost
                    . (Language l, Scheduler schd, Ord cost)
                    => Proxy schd          -- ^ Proxy for the scheduler to use
                    -> Fix l               -- ^ Expression to run equality saturation on
                    -> [Rewrite l]         -- ^ List of rewrite rules
                    -> CostFunction l cost -- ^ Cost function to extract the best equivalent representation
                    -> (Fix l, EGraph l)   -- ^ Best equivalent expression and resulting e-graph
equalitySaturation' proxy expr rewrites cost = egraph saturateAndExtract
  where
    -- Build, saturate and extract, all inside the e-graph monad.
    saturateAndExtract :: EGraphM l (Fix l)
    saturateAndExtract = do
      -- Root e-class of the original expression
      rootClass <- represent expr
      -- Grow the e-graph by non-destructively applying every rewrite
      runEqualitySaturation proxy rewrites
      -- Pick the best representative of the original expression's e-class
      gets (extractFrom rootClass)

    -- Extraction with the user's cost function, starting from the given class.
    extractFrom :: ClassId -> EGraph l -> Fix l
    extractFrom root g = extractBest g cost root
{-# INLINABLE equalitySaturation' #-}


-- | Run equality saturation on an e-graph by non-destructively applying all
-- given rewrite rules until saturation (using the given 'Scheduler').
--
-- Each iteration:
--
--   1. matches every rewrite rule against a read-only database built from
--      the current e-graph, skipping rules the scheduler reports as banned
--      for this iteration;
--   2. applies all resulting matches, temporarily breaking the e-graph
--      invariants;
--   3. calls 'rebuild' once to restore the invariants.
--
-- Iteration stops when an iteration leaves both the memo size and the number
-- of e-classes unchanged, or after a fixed number of iterations (30),
-- whichever comes first.
runEqualitySaturation :: forall l schd
                       . (Language l, Scheduler schd)
                      => Proxy schd          -- ^ Proxy for the scheduler to use
                      -> [Rewrite l]         -- ^ List of rewrite rules
                      -> EGraphM l ()
runEqualitySaturation _ rewrites = runEqualitySaturation' 0 mempty where -- Start at iteration 0

  -- Hard iteration cap (stand-in until a real timeout / node limit exists).
  maxIterations :: Int
  maxIterations = 30

  -- Main loop. Each rewrite rule is identified by an integer (its 1-based
  -- position in the list of rewrite rules); the map carries the scheduler's
  -- per-rule statistics between iterations (e.g. for backoff scheduling).
  runEqualitySaturation' :: Int -> IM.IntMap (Stat schd) -> EGraphM l ()
  runEqualitySaturation' i stats
    | i >= maxIterations = return () -- Stop after maxIterations iterations
    | otherwise = do

        egr <- get

        let beforeMemo    = egr ^. _memo
            beforeClasses = egr ^. _classes
            db            = eGraphToDatabase egr

        -- Read-only phase, invariants are preserved.
        -- Statistics are threaded from rule to rule (rather than combined
        -- with the left-biased IntMap union of 'mconcat') so that no rule's
        -- scheduler update is lost.
        -- ROMES:TODO parMap with chunks
        let (!matches, newStats) = matchAll db i stats (zip [1..] rewrites)

        -- Write-only phase, temporarily break invariants
        forM_ matches applyMatchesRhs

        -- Restore the invariants once per iteration
        rebuild

        (afterMemo, afterClasses) <- gets (\g -> (g ^. _memo, g ^. _classes))

        -- ROMES:TODO: Node limit...
        -- ROMES:TODO: Actual Timeout... not just iteration timeout
        -- ROMES:TODO Better saturation (see Runner)
        -- Apply rewrites until saturated or ROMES:TODO: timeout
        unless (G.sizeNM afterMemo == G.sizeNM beforeMemo
                  && IM.size afterClasses == IM.size beforeClasses)
            (runEqualitySaturation' (i + 1) newStats)

  -- Match every (rule id, rule) pair in list order, passing the statistics
  -- produced by one rule on to the next. Returns all matches (in rule order)
  -- and the statistics after every rule has been processed.
  matchAll :: Database l
           -> Int
           -> IM.IntMap (Stat schd)
           -> [(Int, Rewrite l)]
           -> ([(Rewrite l, Match)], IM.IntMap (Stat schd))
  matchAll _  _ stats [] = ([], stats)
  matchAll db i stats (r:rs) =
      let (ms,  stats')  = matchWithScheduler db i stats r
          (mss, stats'') = matchAll db i stats' rs
       in (ms ++ mss, stats'')

  -- Match a single rewrite rule against the database, consulting the
  -- scheduler to decide whether the rule is banned and updating its stats.
  matchWithScheduler :: Database l
                     -> Int
                     -> IM.IntMap (Stat schd)
                     -> (Int, Rewrite l)
                     -> ([(Rewrite l, Match)], IM.IntMap (Stat schd))
  matchWithScheduler db i stats = \case
      -- Conditional rewrite: match the underlying rule, then re-attach the
      -- condition to every match so it is checked when the match is applied.
      (rw_id, rw :| cnd) -> first (map (first (:| cnd))) $ matchWithScheduler db i stats (rw_id, rw)
      (rw_id, lhs := rhs) ->
          case IM.lookup rw_id stats of
            -- If it's banned until some iteration, don't match this rule
            -- against anything.
            Just s | isBanned @schd i s -> ([], stats)

            -- Otherwise, match and update stats
            x ->
                let matches' = ematch db lhs
                    -- Scheduler: record this rule's matches for this iteration
                    newStats = updateStats @schd i rw_id x stats matches'
                 in (map (lhs := rhs,) matches', newStats)

  -- Apply one match of a rewrite rule to the e-graph (write phase).
  applyMatchesRhs :: (Rewrite l, Match) -> EGraphM l ()
  applyMatchesRhs = \case
      (rw :| cond, m@(Match subst _)) -> do
          -- Apply the underlying rewrite only if the condition holds. The
          -- condition is evaluated against the e-graph as it is at this point
          -- of the write phase.
          egr <- get
          when (cond subst egr) $
              applyMatchesRhs (rw, m)

      (_ := VariablePattern v, Match subst eclass) ->
          -- rhs is a variable: merge the e-class where the lhs pattern was
          -- found (@eclass@) with the e-class the pattern variable matched
          -- (@lookup v subst@)
          case IM.lookup v subst of
            Nothing -> error "impossible: couldn't find v in subst"
            Just n  -> void (merge n eclass)

      (_ := NonVariablePattern rhs, Match subst eclass) -> do
          -- rhs is (at the top level) a non-variable pattern: substitute all
          -- pattern variables, add the resulting e-node (and e-class), then
          -- merge the e-class of the substituted rhs with the class that
          -- matched the left hand side
          eclass' <- reprPat subst rhs
          void (merge eclass eclass')

  -- Represent a pattern in the e-graph given a substitution for its pattern
  -- variables, returning the e-class id of the pattern's root node.
  reprPat :: Subst -> l (Pattern l) -> EGraphM l ClassId
  reprPat subst = add . Node <=< traverse \case
      VariablePattern v ->
          case IM.lookup v subst of
              Nothing -> error "impossible: couldn't find v in subst?"
              Just i  -> return i
      NonVariablePattern p -> reprPat subst p
{-# INLINEABLE runEqualitySaturation #-}