{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-|
   E-matching (equality matching), implemented using a relational database
   (defined in 'Data.Equality.Matching.Database') according to the paper
   \"Relational E-Matching\" https://arxiv.org/abs/2108.02290.
 -}
module Data.Equality.Matching
    ( ematch
    , eGraphToDatabase
    , Match(..)
    , compileToQuery

    , module Data.Equality.Matching.Pattern
    )
    where

import Data.Maybe (mapMaybe)
import Data.Foldable (toList)
import Data.Containers.ListUtils

import Control.Monad
import Control.Monad.Trans.State.Strict

import qualified Data.Map.Strict    as M
import qualified Data.IntMap.Strict as IM
import qualified Data.IntSet as IS

import Data.Equality.Graph
import Data.Equality.Matching.Database
import Data.Equality.Matching.Pattern

-- | The result of matching a pattern on an e-graph: the e-class in which the
-- pattern was matched, together with an e-class substitution for every
-- 'VariablePattern' occurring in the pattern.
data Match = Match
    { matchSubst   :: !Subst
      -- ^ Substitution mapping each pattern variable to an e-class
    , matchClassId :: {-# UNPACK #-} !ClassId
      -- ^ E-class in which the pattern was matched
    }

-- TODO: Perhaps the e-graph could carry its database and rebuild it whenever the e-graph is rebuilt

-- | Match a pattern against a 'Database', which can be obtained from an
-- 'EGraph' with 'eGraphToDatabase'.
--
-- Returns a list of matches: one 'Match' for each valid substitution of all
-- the pattern's variables, paired with the e-class in which the pattern was
-- matched.
--
-- 'ematch' takes a 'Database' instead of an 'EGraph' so that the 'Database'
-- can be constructed once and shared across several matching runs.
ematch :: Language l
       => Database l
       -> Pattern l
       -> [Match]
ematch db patr = mapMaybe toMatch (genericJoin db query)
  where
    -- The compiled query and the variable whose binding is the e-class in
    -- which the whole pattern matched.
    (query, root) = compileToQuery patr

    -- Turn one substitution produced by the join into a match.
    -- An empty substitution means there is no match; otherwise the root
    -- variable's binding is the e-class where we matched.
    toMatch :: Subst -> Maybe Match
    toMatch subst
      | IM.null subst = Nothing
      | otherwise     = case IM.lookup root subst of
          Just cls -> Just (Match subst cls)
          Nothing  -> error "how is root not in map?"

-- | Convert an e-graph into a database.
--
-- Every e-node @f(c1,...,cn)@ in the e-graph's memo, belonging to e-class
-- @c@, becomes the tuple @(c,c1,...,cn)@ in the relation of operator @f@.
-- Each relation is stored as a trie of e-class ids.
eGraphToDatabase :: Language l => EGraph l -> Database l
eGraphToDatabase egraph =
    foldrWithKeyNM' addENodeToDB (DB mempty) (memo egraph)
  where

    -- Add an e-node, given its e-class, to a database. The tuple
    -- (classid, child_1, ..., child_n) is inserted into the relation of the
    -- e-node's operator, and that relation is created if it does not exist yet.
    addENodeToDB :: Language l => ENode l -> ClassId -> Database l -> Database l
    addENodeToDB node cls (DB relations) =
        let tuple = cls : children node
         in DB (M.alter (Just . populate tuple) (operator node) relations)
    {-# SCC addENodeToDB #-}

    -- Insert a path of e-class ids into a trie, creating the trie when it is
    -- absent. The first id is added to this level's key set, and the remaining
    -- ids are inserted recursively into the sub-trie stored under that id.
    populate :: [ClassId] -> Maybe IntTrie -> IntTrie
    populate ids mtrie = case (ids, mtrie) of
        ([],     Nothing)   -> MkIntTrie mempty mempty
        ([],     Just trie) -> trie
        (x : xs, Nothing)   ->
            MkIntTrie (IS.singleton x) (IM.singleton x (populate xs Nothing))
        (x : xs, Just (MkIntTrie keys subtries)) ->
            MkIntTrie (IS.insert x keys) (IM.alter (Just . populate xs) x subtries)
    {-# SCC populate #-}
{-# SCC eGraphToDatabase #-}


-- * Database related internals

-- | Auxiliary result of the 'compileToQuery' algorithm.
--
-- @v :~ atoms@ pairs the query variable @v@ bound to a (sub-)pattern with the
-- atoms generated while compiling that (sub-)pattern.
data AuxResult lang
    = {-# UNPACK #-} !Var :~ [Atom lang]

-- | Compiles a 'Pattern' to a 'Query' and returns the query root variable with
-- it.
-- The root variable's substitutions are the e-classes where the pattern
-- matched.
--
-- A variable pattern compiles to a query that selects every e-class.
--
-- For any other pattern, each 'NonVariablePattern' node gets a fresh query
-- variable, numbered in pre-order and counting up from 0. Each such node
-- contributes one 'Atom' relating its variable to the variables of its
-- children. Fresh variables skip every number already used by a
-- 'VariablePattern' of the pattern, so a generated variable can never be
-- accidentally identified with one of the pattern's own variables.
compileToQuery :: (Traversable lang) => Pattern lang -> (Query lang, Var)
compileToQuery (VariablePattern x) = (SelectAllQuery x, x)
compileToQuery pa@(NonVariablePattern _) =

  let root :~ atoms = evalState (aux pa) 0
   in (Query (nubInt $ root : patVars) atoms, root)

    where

        -- Distinct variables occurring in the pattern.
        patVars :: [Var]
        patVars = vars pa

        -- The same variables as a set, used to reject clashing fresh numbers.
        usedVars :: IS.IntSet
        usedVars = IS.fromList patVars

        -- Next counter value that is not already a pattern variable.
        -- This always terminates, since 'usedVars' is finite.
        fresh :: State Int Var
        fresh = do
            v <- get
            modify' (+1)
            if v `IS.member` usedVars then fresh else return v

        aux :: (Traversable lang) => Pattern lang -> State Int (AuxResult lang)
        aux (VariablePattern x) = return (x :~ []) -- from definition in relational e-matching paper (needed for as base case for recursion)
        aux (NonVariablePattern p) = do
            v <- fresh
            subResults <- traverse aux p
            let -- Replace every sub-pattern by the variable bound to it.
                -- Mapping over the traversed structure keeps each child
                -- paired with its own variable by construction, with no
                -- (partial) list indexing.
                boundVars = fmap (\(b :~ _) -> b) subResults
                -- Atoms of all children, in the children's 'toList' order.
                subAtoms  = concatMap (\(_ :~ a) -> a) (toList subResults)
            return (v :~ (Atom (CVar v) (fmap CVar boundVars) : subAtoms))

        -- | Return distinct variables in a pattern
        vars :: Foldable lang => Pattern lang -> [Var]
        vars (VariablePattern x) = [x]
        vars (NonVariablePattern p) = nubInt $ concatMap vars (toList p)
{-# SCC compileToQuery #-}