{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-|
   Custom database implemented with trie-maps specialized to run conjunctive
   queries using a (worst-case optimal) generic join algorithm.

   Used in e-matching ('Data.Equality.Matching') as described by \"Relational
   E-Matching\" https://arxiv.org/abs/2108.02290.

   You probably don't need this module.
 -}
module Data.Equality.Matching.Database
  (
    genericJoin

  , Database(..)
  , Query(..)
  , IntTrie(..)
  , Subst
  , Var
  , Atom(..)
  , ClassIdOrVar(..)
  ) where

import Data.List (sortBy)
import Data.Function (on)
import Data.Maybe (mapMaybe)
import Control.Monad

import Data.Foldable as F (toList, foldl', length)
import qualified Data.Map.Strict    as M
import qualified Data.IntMap.Strict as IM
import qualified Data.IntSet as IS

import Data.Equality.Graph.Classes.Id
import Data.Equality.Graph.Nodes
import Data.Equality.Language

-- | A variable in a query is identified by an 'Int'.
-- This is much more efficient than using e.g. a 'String'.
--
-- As a consequence, patterns also use 'Int' to represent a variable, but we
-- can still have a 'Data.String.IsString' instance for variable patterns by
-- hashing the string into a number.
type Var = Int

-- | Mapping from 'Var' to 'ClassId'. Being an 'IM.IntMap', a 'Subst' holds at
-- most one binding per variable.
type Subst = IM.IntMap ClassId

-- | A value which is either a 'ClassId' or a 'Var'.
--
-- The derived 'Ord' instance follows constructor order, so every 'CClassId'
-- compares smaller than every 'CVar'.
data ClassIdOrVar = CClassId {-# UNPACK #-} !ClassId
                  | CVar     {-# UNPACK #-} !Var
    deriving (Show, Eq, Ord)

-- | An 'Atom' 𝑅ᵢ(𝑣, 𝑣1, ..., 𝑣𝑘) is defined by the relation 𝑅ᵢ and by the
-- class-ids or variables 𝑣, 𝑣1, ..., 𝑣𝑘. It represents one atom of the body
-- of a conjunctive query.
data Atom lang
    = Atom
        !ClassIdOrVar        -- ^ Represents 𝑣
        !(lang ClassIdOrVar) -- ^ Represents 𝑅ᵢ(𝑣1, ..., 𝑣𝑘). Note how 𝑣 isn't included since the arity of the constructor is 𝑘 instead of 𝑘+1.

-- | A conjunctive query to be run on the database.
--
-- * @Query vars atoms@ carries a list of variables and the body atoms of the
--   query. Within this module only @atoms@ is inspected; @vars@ is ignored
--   (presumably these are the head variables of the query -- TODO confirm
--   against the code that builds queries).
--
-- * @SelectAllQuery x@ matches anything: 'genericJoin' answers it by binding
--   @x@ to each key stored at the top level of every relation.
data Query lang
    = Query ![Var] ![Atom lang]
    | SelectAllQuery {-# UNPACK #-} !Var

-- | The relational representation of an e-graph, as described in section 3.1
-- of \"Relational E-Matching\".
--
-- Every e-node with symbol 𝑓 in the e-graph corresponds to a tuple in the relation 𝑅𝑓 in the database.
-- If 𝑓 has arity 𝑘, then 𝑅𝑓 will have arity 𝑘 + 1; its first attribute is the e-class id that contains the
-- corresponding e-node, and the remaining attributes are the 𝑘 children of the 𝑓 e-node.
--
-- For every existing symbol in the e-graph the 'Database' has a table.
--
-- Concretely, we map 'Operator's to 'IntTrie's: each operator has one table,
-- represented by an 'IntTrie'.
newtype Database lang
    = DB (M.Map (Operator lang) IntTrie)

-- | An integer triemap that keeps a cache of all keys at each level.
--
-- As described in the paper:
-- Generic join requires two important performance bounds to be met in order for its own run time
-- to meet the AGM bound. First, the intersection [...] must run in 𝑂 (min(|𝑅𝑗 .𝑥 |)) time. Second,
-- the residual relations should be computed in constant time, i.e., computing from the relation 𝑅(𝑥, 𝑦)
-- the relation 𝑅(𝑣𝑥 , 𝑦) for some 𝑣𝑥 ∈ 𝑅(𝑥, 𝑦).𝑥 must take constant time. Both of these can be solved by
-- using tries (sometimes called prefix or suffix trees) as an indexing data structure.
--
-- Both fields are strict.
data IntTrie = MkIntTrie
  { tkeys :: !IS.IntSet
    -- ^ Cached set of keys at this level (presumably exactly the keys of
    -- 'trie'; the code that builds tries lives outside this module -- TODO
    -- confirm).
  , trie :: !(IM.IntMap IntTrie)
    -- ^ Sub-trie reached by following each key at this level.
  }


-- TODO use this somehow?
-- queryHeadVars :: Foldable lang => Query lang -> [Var]
-- queryHeadVars (SelectAllQuery x) = [x]
-- queryHeadVars (Query qv _) = qv
-- {-# INLINE queryHeadVars #-}

-- | Run a conjunctive 'Query' on a 'Database'.
--
-- Produces the list of valid substitutions from query variables to the
-- query-matching class ids.
--
-- * 'SelectAllQuery': we want to match against anything, so every key cached
--   at the top level of every relation becomes its own singleton
--   substitution. The per-relation results are concatenated; nothing is
--   deduplicated across relations.
--
-- * 'Query': variables are bound one at a time, in the order chosen by
--   'orderedVarsInQuery'. For each variable we compute its domain from the
--   atoms it occurs in, substitute every candidate value into the atoms, and
--   recurse on the remaining variables.
genericJoin :: forall l. Language l => Database l -> Query l -> [Subst]
-- ROMES:TODO a less ad-hoc/specialized implementation of generic join...
-- ROMES:TODO query ordering is very important!
genericJoin (DB m) (SelectAllQuery x) =
  concatMap (map (IM.singleton x) . IS.toList . tkeys) (M.elems m)

-- ROMES:TODO: Start here. Map vars to indexes in an array and substitute in the resulting subst
genericJoin d q@(Query _ atoms) = genericJoin' atoms (orderedVarsInQuery q)
  where
    -- Bind the remaining variables, left to right, against the (partially
    -- substituted) atoms.
    genericJoin' :: [Atom l] -> [Var] -> [Subst]
    genericJoin' atoms' = \case

      -- No variables left: start from the empty substitution.
      -- NOTE(review): this yields one empty substitution *per atom*, so every
      -- complete solution is repeated once per atom in the final result (and
      -- a query without atoms yields nothing). Confirm whether callers rely on
      -- that before collapsing it to a single empty substitution.
      [] -> IM.empty <$ atoms'

      (!x) : xs -> do
        -- Pick a value for @x@ out of its domain
        x_in_D <- domainX x atoms'

        -- Each valid sub-query assumes the x -> x_in_D substitution
        y <- genericJoin' (substitute x x_in_D atoms') xs

        -- TODO: A bit contrived, perhaps better to avoid the map?
        return $! IM.insert x x_in_D y

    -- The values @x@ can take: intersect the candidates contributed by every
    -- atom that 'elemOfAtom' reports @x@ to be in.
    domainX :: Var -> [Atom l] -> [Int]
    domainX x = IS.toList . intersectAtoms x d . filter (x `elemOfAtom`)
    {-# INLINE domainX #-}

    -- | Substitute all occurrences of 'Var' with given 'ClassId' in all given atoms.
    substitute :: Functor lang => Var -> ClassId -> [Atom lang] -> [Atom lang]
    substitute r i = map (\(Atom x l) -> Atom (bind x) (fmap bind l))
      where
        target = CVar r
        bind v
          | v == target = CClassId i
          | otherwise   = v

{-# INLINABLE genericJoin #-}

-- | Decide whether a 'Var' counts as occurring in the given 'Atom'.
--
-- * If the atom's first component is a variable, the result is simply whether
--   that variable is @x@; the children are /not/ inspected in this case.
--
-- * Otherwise (the first component is a class id), the result is whether @x@
--   appears among the children.
--
-- NOTE(review): as a consequence a variable that only appears as a child is
-- not reported until the atom's first component has been replaced by a class
-- id. Callers in this module ('genericJoin', 'orderedVarsInQuery') depend on
-- this function for filtering and cost computation, so confirm the intent
-- before widening it to a plain \"occurs anywhere\" check.
elemOfAtom :: (Functor lang, Foldable lang) => Var -> Atom lang -> Bool
elemOfAtom !x (Atom (CVar headVar) _) = x == headVar
elemOfAtom !x (Atom _ children)       = CVar x `elem` children

-- ROMES:TODO: Batching? How? https://arxiv.org/pdf/2108.02290.pdf

-- | Extract a list of unique variables from a 'Query', ordered by prioritizing
-- variables that occur in many relations, and secondly by prioritizing
-- variables that occur in small relations.
--
-- We use these heuristics because the variables' ordering is significant in
-- the query run-time performance.
--
-- Variables are first deduplicated through an 'IS.IntSet' and only then
-- sorted by cost. ('IS.fromAscList' must not be applied to the cost-sorted
-- list: it requires its input to be ascending by key and does not check it.)
-- 'sortBy' is stable, so variables with equal cost stay in ascending
-- variable order.
--
-- This extraction could still be improved as some other strategies are
-- described in the paper (such as batching)
orderedVarsInQuery :: (Functor lang, Foldable lang) => Query lang -> [Var]
orderedVarsInQuery (SelectAllQuery x) = [x]
orderedVarsInQuery (Query _ atoms) =
    sortBy (compare `on` varCost) . IS.toList . IS.fromList $ mapMaybe toVar occurrences
  where

    -- Every component (first component and children) of every atom. The order
    -- is irrelevant because the variables go through an 'IS.IntSet' next.
    occurrences :: [ClassIdOrVar]
    occurrences = concatMap (\(Atom v l) -> v : toList l) atoms

    -- First, prioritize variables that occur in many relations; second,
    -- prioritize variables that occur in small relations.
    -- Each atom the variable is found in (as decided by 'elemOfAtom') lowers
    -- the cost by 100 and raises it by the atom's size; lower cost sorts first.
    varCost :: Var -> Int
    varCost v = foldl' (\acc a -> if v `elemOfAtom` a then acc - 100 + atomLength a else acc) 0 atoms
    {-# INLINE varCost #-}

    -- | Get the size of an atom: its first component plus its children
    atomLength :: Foldable lang => Atom lang -> Int
    atomLength (Atom _ l) = 1 + F.length l
    {-# INLINE atomLength #-}

    -- | Extract 'Var' from 'ClassIdOrVar'
    toVar :: ClassIdOrVar -> Maybe Var
    toVar (CVar v)     = Just v
    toVar (CClassId _) = Nothing
    {-# INLINE toVar #-}


-- ROMES:TODO Terrible name 'intersectAtoms'

-- | Given a database and a list of Atoms with an occurring var @x@, find
-- @D_x@, the domain of variable x, that is, the values x can take.
--
-- Each atom contributes the set of class ids that 'intersectInTrie' finds for
-- the variable in that atom's relation; the result is the intersection of
-- those sets over all atoms.
--
-- Calls 'error' when given an empty list of atoms, and when 'intersectInTrie'
-- reports an invalid substitution for one of the atoms.
--
-- Returns the class id set of classes forming the domain of var @x@
intersectAtoms :: forall l. Language l => Var -> Database l -> [Atom l] -> IS.IntSet
intersectAtoms !var (DB db) (a:atoms) = foldr (IS.intersection . domainIn) (domainIn a) atoms
  where
    -- Get the matching ids for a single atom
    domainIn :: Atom l -> IS.IntSet
    domainIn (Atom v l) = case M.lookup (Operator $ void l) db of

        -- If the needed relation doesn't exist at all, there are no matching
        -- class ids. When intersecting, nothing will be available -- as expected
        Nothing -> mempty

        -- If the needed relation does exist, look the atom up in its trie.
        -- The atom is flattened to its first component followed by its children.
        Just r  -> case intersectInTrie var mempty r (v : toList l) of
            Nothing -> error "intersectInTrie should return valid substitution for variable query"
            Just xs -> xs

intersectAtoms _ _ [] = error "can't intersect empty list of atoms?"

-- | Find the matching ids that a variable can take given a list of variables
-- and ids that must match the structure
--
-- Invalid substitutions are represented as Nothing
--
-- The intersection might be invalid while assuming values for variables. If
-- we're looking for the domain of some variables we should never get an
-- invalid substitution, but rather an empty set saying that the query
-- intersection is valid but empty.
--
-- If R_f(1,y,z), this function receives [1,y,z] :: [ClassIdOrVar] and
-- intersects the trie map of R_f with this prefix
--
-- TODO: write a note for this...
--
-- TODO: Really, a valid substitution is one which isn't empty...
intersectInTrie :: Var -- ^ The variable whose domain we are looking for
                -> IM.IntMap ClassId -- ^ A mapping from variables that have been substituted
                -> IntTrie  -- ^ The trie
                -> [ClassIdOrVar]  -- ^ The "query"
                -> Maybe IS.IntSet -- ^ The resulting domain for a valid substitution
intersectInTrie !var !substs (MkIntTrie trieKeys m) = \case

    -- The query is exhausted: the path exists, and there is nothing more to
    -- add to the domain from here.
    [] -> pure IS.empty

    -- Looking for a class-id, so if it exists in map the intersection is
    -- valid and we simply continue the search for the domain
    CClassId x : xs -> descend x xs

    -- Looking for a var. It might be one of the following:
    --
    --      (1) The variable whose domain we're looking for, and this is the
    --      first time we found it. In this case we'll assume all substitutions
    --      are valid, and try to get a valid substitution with that
    --      assumption. If the substitution is valid, the substitution is an
    --      element of the domain.
    --
    --      (2) The variable whose domain we're looking for, but we've already
    --      assumed a value for it in this branch, so we continue the recursion
    --      guaranteeing the assumption results in a valid substitution
    --
    --      (3) A bound variable, and this is the first time we find it. We
    --      assume its value for all branches and concatenate the result of all
    --      valid domain elements for each branch that resulted in a valid
    --      substitution
    --
    --      (4) A bound variable, but we've assumed a value for it, so we
    --      continue the recursion again to validate the assumption and
    --      possibly find the domain of the variable we're looking for ahead
    --
    CVar x : xs -> case IM.lookup x substs of
        -- (2) or (4), we simply continue
        Just varVal -> descend varVal xs
        -- (1) or (3). The set is built lazily under the 'Just', so callers
        -- that only check for validity never force it.
        Nothing -> pure $
          if x == var
            -- (1)
            then
              -- If this is the var we're looking for, and the remaining @xs@
              -- suffix only consists of variables other than the var we're
              -- looking for, we can simply return all keys at this level.
              -- This is quite important for performance!
              if all (isVarDifferentFrom x) xs
                then trieKeys
                -- Otherwise keep exactly the keys under which the rest of the
                -- query is still valid. Any 'Just' counts, even an empty one.
                else IM.foldrWithKey
                       (\k ls (!acc) -> case assume x k ls xs of
                           Nothing -> acc
                           Just _  -> k `IS.insert` acc)
                       mempty m
            -- (3) Try every key for @x@ and union whatever each valid branch
            -- finds further down.
            else IM.foldrWithKey
                   (\k ls (!acc) -> case assume x k ls xs of
                       Nothing -> acc
                       Just rs -> rs <> acc)
                   mempty m
  where

    -- Follow a concrete key one level down and continue with the rest of the
    -- query; 'Nothing' if the key is absent at this level.
    descend :: ClassId -> [ClassIdOrVar] -> Maybe IS.IntSet
    descend key rest = IM.lookup key m >>= \next -> intersectInTrie var substs next rest

    -- Continue in the given sub-trie while assuming that variable @v@ has
    -- value @k@.
    assume :: Var -> ClassId -> IntTrie -> [ClassIdOrVar] -> Maybe IS.IntSet
    assume v k next = intersectInTrie var (IM.insert v k substs) next

    -- | Returns True if given 'ClassIdOrVar' holds a 'Var' and is different from given 'Var'.
    isVarDifferentFrom :: Var -> ClassIdOrVar -> Bool
    isVarDifferentFrom _ (CClassId _) = False
    isVarDifferentFrom x (CVar     y) = x /= y
    {-# INLINE isVarDifferentFrom #-}