{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Equality.Matching
( ematch
, eGraphToDatabase
, Match(..)
, compileToQuery
, module Data.Equality.Matching.Pattern
)
where
import Data.Maybe (mapMaybe)
import Data.Foldable (toList)
import Data.Containers.ListUtils
import Control.Monad
import Control.Monad.Trans.State.Strict
import qualified Data.Map.Strict as M
import qualified Data.IntMap.Strict as IM
import qualified Data.IntSet as IS
import Data.Equality.Graph
import Data.Equality.Graph.Lens
import Data.Equality.Matching.Database
import Data.Equality.Matching.Pattern
-- | One result of matching a pattern against an e-graph database.
data Match = Match
    { matchSubst   :: !Subst
      -- ^ The substitution produced for this match: an 'Int'-keyed map from
      -- query variables to the e-class ids bound to them.
    , matchClassId :: {-# UNPACK #-} !ClassId
      -- ^ The e-class id bound to the pattern's root variable, i.e. the
      -- e-class in which the pattern was matched.
    }
-- | Match a 'Pattern' against a 'Database' (obtainable from an 'EGraph' with
-- 'eGraphToDatabase').
--
-- The pattern is compiled with 'compileToQuery' and run with 'genericJoin'.
-- Every non-empty substitution the join yields becomes one 'Match', paired
-- with the e-class bound to the pattern's root variable. Empty
-- substitutions are dropped.
ematch :: Language l
       => Database l
       -> Pattern l
       -> [Match]
ematch db patr = mapMaybe toMatch (genericJoin db query)
  where
    (query, rootVar) = compileToQuery patr

    -- Turn one substitution into a 'Match', skipping empty ones.
    toMatch :: Subst -> Maybe Match
    toMatch subst
      | IM.null subst = Nothing
      | otherwise =
          case IM.lookup rootVar subst of
            Just cls -> Just (Match subst cls)
            -- The root variable is always listed in the compiled query, so
            -- presumably every non-empty substitution binds it.
            Nothing  -> error "how is root not in map?"
-- | Build a matching 'Database' from an 'EGraph'.
--
-- Folds over the e-graph's '_memo' node map (e-node to e-class id). Each
-- e-node is recorded in the trie stored under its 'operator', along the path
-- "owning e-class id, then the node's children".
eGraphToDatabase :: Language l => EGraph l -> Database l
eGraphToDatabase egr =
    foldrWithKeyNM' insertNode (DB mempty) (egr ^. _memo)
  where
    -- Record one e-node, owned by the given e-class, in the trie for its
    -- operator, creating that trie if the operator is new.
    insertNode :: Language l => ENode l -> ClassId -> Database l -> Database l
    insertNode enode classid (DB tries) =
        let key  = operator enode
            path = classid : children enode
         in DB (M.alter (Just . insertPath path) key tries)

    -- Insert a path of class ids into an optional trie. A missing trie is
    -- treated as empty. At each level the head id is added to the level's
    -- 'IntSet' and the rest of the path goes into the sub-trie under that id.
    insertPath :: [ClassId] -> Maybe IntTrie -> IntTrie
    insertPath [] existing = maybe (MkIntTrie mempty mempty) id existing
    insertPath (x:xs) existing =
        case maybe (MkIntTrie mempty mempty) id existing of
          MkIntTrie keys sub ->
              MkIntTrie (IS.insert x keys) (IM.alter (Just . insertPath xs) x sub)
{-# INLINABLE eGraphToDatabase #-}
-- | Intermediate result of compiling one sub-pattern: the variable standing
-- for that sub-pattern's root, together with the atoms generated for it.
data AuxResult lang = (:~) {-# UNPACK #-} !Var [Atom lang]
-- | Compile a 'Pattern' into a 'Query' that can be run against a 'Database'.
--
-- Returns the query together with the variable that stands for the pattern's
-- root, i.e. the variable whose binding is the e-class a match was found in.
--
-- * A bare @'VariablePattern' x@ compiles to @'SelectAllQuery' x@ with root @x@.
-- * Otherwise every 'NonVariablePattern' node gets a fresh variable
--   (0, 1, 2, ... in pre-order, following the order 'traverse' visits the
--   children). Each such node contributes one 'Atom' relating its variable
--   to the variables of its direct children. The query's variable list is
--   the root followed by all pattern variables, without duplicates.
--
-- NOTE(review): fresh variables start at 0, so this assumes the 'Var's used
-- in 'VariablePattern's never coincide with them. Confirm against how
-- patterns are constructed.
compileToQuery :: (Traversable lang) => Pattern lang -> (Query lang, Var)
compileToQuery (VariablePattern x) = (SelectAllQuery x, x)
compileToQuery pa@(NonVariablePattern _) =
    let root :~ atoms = evalState (aux pa) 0
     in (Query (nubInt $ root : vars pa) atoms, root)
  where
    -- Compile one sub-pattern. The state is the next unused fresh variable.
    aux :: (Traversable lang) => Pattern lang -> State Int (AuxResult lang)
    aux (VariablePattern x) = return (x :~ [])
    aux (NonVariablePattern p) = do
        v <- get
        modify' (+1)
        subResults <- traverse aux p
        -- Take each child's variable directly from the traversed structure,
        -- so it keeps its own position. This avoids re-traversing with a
        -- counter and indexing a flattened list with the partial,
        -- O(arity)-per-lookup '!!'.
        let childVars  = fmap (\(b :~ _) -> b) subResults
            childAtoms = concatMap (\(_ :~ a) -> a) (toList subResults)
        return (v :~ (Atom (CVar v) (fmap CVar childVars) : childAtoms))

    -- Pattern variables occurring in a pattern, each listed once, in order
    -- of first occurrence.
    vars :: Foldable lang => Pattern lang -> [Var]
    vars (VariablePattern x) = [x]
    vars (NonVariablePattern p) = nubInt $ join $ map vars $ toList p
{-# INLINABLE compileToQuery #-}