{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-|
   Equality-matching, implemented using a relational database
   (defined in 'Data.Equality.Matching.Database') according to the paper
   \"Relational E-Matching\" https://arxiv.org/abs/2108.02290.
 -}
module Data.Equality.Matching
    ( ematch
    , eGraphToDatabase
    , Match(..)
    , compileToQuery

    , module Data.Equality.Matching.Pattern
    )
    where

import Data.Maybe (mapMaybe)
import Data.Foldable (toList)
import Data.Containers.ListUtils

import Control.Monad
import Control.Monad.Trans.State.Strict

import qualified Data.Map.Strict    as M
import qualified Data.IntMap.Strict as IM
import qualified Data.IntSet as IS

import Data.Equality.Graph
import Data.Equality.Graph.Lens
import Data.Equality.Matching.Database
import Data.Equality.Matching.Pattern

-- | The result of matching a pattern on an e-graph: the e-class in which the
-- pattern was matched, together with an e-class substitution for every
-- 'VariablePattern' in the pattern.
data Match = Match
  { matchSubst   :: !Subst
    -- ^ E-class substitution for the pattern's variables (as produced by
    -- 'ematch', it also binds the query's root variable).
  , matchClassId :: {-# UNPACK #-} !ClassId
    -- ^ The e-class in which the pattern was matched.
  }

-- TODO: Perhaps the e-graph could carry its database and update it whenever the e-graph is rebuilt

-- | Match a pattern against a 'Database', which can be obtained from an
-- 'EGraph' with 'eGraphToDatabase'.
--
-- Returns a list of matches: one 'Match' for each set of valid substitutions
-- for all variables, paired with the equivalence class in which the pattern
-- was matched.
--
-- 'ematch' takes a 'Database' instead of an 'EGraph' so that the 'Database'
-- can be constructed only once and shared across matching.
ematch :: Language l
       => Database l
       -> Pattern l
       -> [Match]
ematch db patr = mapMaybe toMatch (genericJoin db query)
  where
    -- Compile the pattern once; 'root' is the query variable whose value in
    -- each substitution is the e-class where the pattern matched.
    (query, root) = compileToQuery patr

    -- An empty substitution yields no match. Otherwise the root variable is
    -- expected to be bound (an invariant of the compiled query), and its
    -- e-class becomes the match's class.
    toMatch :: Subst -> Maybe Match
    toMatch subst
      | IM.null subst = Nothing
      | otherwise =
          case IM.lookup root subst of
            Just cls -> Just (Match subst cls)
            Nothing  -> error "how is root not in map?"

-- | Convert an e-graph into a database.
--
-- Every e-node @f(c1, ..., cn)@ in the e-graph's memo, together with the
-- e-class @i@ it is recorded in, is inserted as the tuple @(i, c1, ..., cn)@
-- into the relation (a trie of e-class ids) for operator @f@.
eGraphToDatabase :: Language l => EGraph l -> Database l
eGraphToDatabase egr = foldrWithKeyNM' addENodeToDB (DB M.empty) (egr ^. _memo)
  where
    -- Insert an e-node, given the e-class it belongs to, into the relation for
    -- its operator, creating that relation if it does not exist yet.
    addENodeToDB :: Language l => ENode l -> ClassId -> Database l -> Database l
    addENodeToDB enode classid (DB relations) =
      let tuple = classid : children enode
       in DB (M.alter (Just . populate tuple) (operator enode) relations)

    -- Insert a tuple of e-class ids into a trie, creating the trie if it is
    -- absent. At each level the head id is added to the level's key set and
    -- the remaining ids are inserted recursively into the sub-trie under it.
    populate :: [ClassId] -> Maybe IntTrie -> IntTrie
    populate ids mtrie = case (ids, mtrie) of
      ([], Just trie) -> trie
      ([], Nothing)   -> MkIntTrie IS.empty IM.empty
      (x : xs, Nothing) ->
        MkIntTrie (IS.singleton x) (IM.singleton x (populate xs Nothing))
      (x : xs, Just (MkIntTrie keys subtries)) ->
        MkIntTrie (IS.insert x keys) (IM.alter (Just . populate xs) x subtries)
{-# INLINABLE eGraphToDatabase #-}


-- * Database related internals

-- | Auxiliary result of the 'compileToQuery' algorithm: the variable bound to
-- the root of a (sub-)pattern, paired with the atoms generated for that
-- (sub-)pattern.
data AuxResult lang
  = {-# UNPACK #-} !Var :~ [Atom lang]

-- | Compiles a 'Pattern' to a 'Query' and returns the query root variable with
-- it. The root variable's substitutions are the e-classes where the pattern
-- matched.
--
-- A variable pattern compiles to a 'SelectAllQuery' on that variable, which is
-- also the root. A non-variable pattern is flattened into one 'Atom' per
-- non-variable sub-pattern. Each such sub-pattern gets a fresh query variable,
-- numbered from 0 in pre-order, and the root is the variable of the outermost
-- one. The query's variables are the root together with the pattern's own
-- variables, without duplicates.
--
-- NOTE(review): fresh variables are allocated from 0 upwards in the same 'Int'
-- space as the pattern's own variables; this presumably relies on pattern
-- variables never falling in that range — verify against
-- "Data.Equality.Matching.Pattern".
compileToQuery :: (Traversable lang) => Pattern lang -> (Query lang, Var)
compileToQuery (VariablePattern x) = (SelectAllQuery x, x)
compileToQuery pa@(NonVariablePattern _) =
  let root :~ atoms = evalState (aux pa) 0
   in (Query (nubInt $ root : vars pa) atoms, root)
  where
    -- Compile a (sub-)pattern, returning the variable bound to its root and
    -- the atoms it generates. The state is the next fresh variable.
    aux :: (Traversable lang) => Pattern lang -> State Int (AuxResult lang)
    -- A variable pattern is its own root and generates no atoms; this is the
    -- base case of the recursion, following the relational e-matching paper.
    aux (VariablePattern x) = return (x :~ [])
    aux (NonVariablePattern p) = do
      v <- get
      modify' (+1)
      subResults <- traverse aux p
      -- Replace each child by the variable bound to its root. 'fmap' keeps
      -- every variable at its child's position, so no partial indexing into a
      -- flattened list is needed. Then collect the atoms of all children in
      -- traversal order.
      let p'    = fmap (\(b :~ _) -> b) subResults
          atoms = foldMap (\(_ :~ a) -> a) subResults
      return (v :~ (Atom (CVar v) (fmap CVar p') : atoms))

    -- Distinct variables in a pattern.
    vars :: Foldable lang => Pattern lang -> [Var]
    vars (VariablePattern x) = [x]
    vars (NonVariablePattern p) = nubInt $ join $ map vars $ toList p
{-# INLINABLE compileToQuery #-}