{-# OPTIONS_HADDOCK show-extensions #-}
module ToySolver.EUF.EUFSolver
(
Solver
, newSolver
, FSym
, Term (..)
, ConstrID
, VAFun (..)
, newFSym
, newFun
, newConst
, assertEqual
, assertEqual'
, assertNotEqual
, assertNotEqual'
, check
, areEqual
, explain
, Entity
, EntityTuple
, Model (..)
, getModel
, eval
, evalAp
, pushBacktrackPoint
, popBacktrackPoint
, termToFlatTerm
, termToFSym
, fsymToTerm
, fsymToFlatTerm
, flatTermToFSym
) where
import Control.Monad
import Control.Monad.Trans
import Control.Monad.Trans.Except
import Data.Either
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.IORef
import qualified ToySolver.Internal.Data.Vec as Vec
import ToySolver.EUF.CongruenceClosure (FSym, Term (..), ConstrID, VAFun (..))
import ToySolver.EUF.CongruenceClosure (Model (..), Entity, EntityTuple, eval, evalAp)
import qualified ToySolver.EUF.CongruenceClosure as CC
-- | An incremental solver for equality with uninterpreted functions (EUF).
--
-- Equalities are delegated to an underlying congruence-closure solver;
-- disequalities are stored here and tested against it by 'check'.
data Solver
  = Solver
  { svCCSolver :: !CC.Solver
    -- ^ Underlying congruence-closure solver receiving all equalities.
  , svDisequalities :: !(IORef (Map (Term, Term) (Maybe ConstrID)))
    -- ^ Asserted disequalities, keyed by the pair of terms in canonical
    --   order (smaller term first), with the optional constraint ID that
    --   was supplied when the disequality was asserted.
  , svExplanation :: !(IORef IntSet)
    -- ^ Constraint IDs explaining the conflict most recently found by
    --   'check'.
  , svBacktrackPoints :: !(Vec.Vec (Map (Term, Term) ()))
    -- ^ One entry per open backtrack point: the disequalities added while
    --   that point was the innermost one, so they can be removed on pop.
  }
-- | Create a fresh solver with no asserted constraints and no open
-- backtrack points (i.e. at level 0).
newSolver :: IO Solver
newSolver = do
  cc <- CC.newSolver
  deqs <- newIORef Map.empty
  -- Start with an empty explanation instead of 'undefined', so that
  -- @'explain' solver Nothing@ before any failed 'check' returns an empty
  -- set rather than crashing when the result is forced.
  expl <- newIORef IntSet.empty
  bp <- Vec.new
  return $
    Solver
      { svCCSolver = cc
      , svDisequalities = deqs
      , svExplanation = expl
      , svBacktrackPoints = bp
      }
-- | Allocate a fresh function symbol in the underlying congruence-closure
-- solver.
newFSym :: Solver -> IO FSym
newFSym = CC.newFSym . svCCSolver
-- | Create a fresh constant term by delegating to the underlying
-- congruence-closure solver.
newConst :: Solver -> IO Term
newConst = CC.newConst . svCCSolver
-- | Create a fresh (variadic) function by delegating to 'CC.newFun' on the
-- underlying congruence-closure solver.
newFun :: CC.VAFun a => Solver -> IO a
newFun = CC.newFun . svCCSolver
-- | Assert that two terms are equal, without attaching a constraint ID.
assertEqual :: Solver -> Term -> Term -> IO ()
assertEqual s lhs rhs = assertEqual' s lhs rhs Nothing
-- | Assert that two terms are equal, passing the optional constraint ID
-- through to 'CC.merge'' on the underlying congruence-closure solver.
assertEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertEqual' = CC.merge' . svCCSolver
-- | Assert that two terms are distinct, without attaching a constraint ID.
assertNotEqual :: Solver -> Term -> Term -> IO ()
assertNotEqual s lhs rhs = assertNotEqual' s lhs rhs Nothing
-- | Assert that two terms are distinct, with an optional constraint ID.
--
-- The pair is stored in canonical order (smaller term first). Asserting
-- the same disequality again, in either orientation, has no effect, so the
-- constraint ID given first is kept. When the solver is above level 0, the
-- pair is also recorded in the innermost backtrack point, so that
-- 'popBacktrackPoint' can retract it.
assertNotEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertNotEqual' solver t1 t2 cid = do
  known <- readIORef (svDisequalities solver)
  when (key `Map.notMember` known) $ do
    -- Register both sides with the congruence-closure solver; the results
    -- are discarded.
    mapM_ (termToFSym solver) [lhs, rhs]
    writeIORef (svDisequalities solver) $! Map.insert key cid known
    level <- getCurrentLevel solver
    when (level /= 0) $
      Vec.unsafeModify' (svBacktrackPoints solver) (level - 1) (Map.insert key ())
  where
    key@(lhs, rhs) = if t1 < t2 then (t1, t2) else (t2, t1)
-- | Check the asserted disequalities against the asserted equalities.
--
-- Returns 'False' as soon as some asserted disequality @t1 /= t2@ has
-- @t1@ and @t2@ congruent in the underlying solver. In that case it stores
-- an explanation: the congruence-closure explanation of @t1 = t2@ plus the
-- disequality's own constraint ID, if it has one. That explanation can
-- then be read back with @'explain' solver Nothing@. Returns 'True' when
-- no asserted disequality is violated.
check :: Solver -> IO Bool
check solver = do
  ds <- readIORef (svDisequalities solver)
  fmap isRight $ runExceptT $ forM_ (Map.toList ds) $ \((t1, t2), cid) -> do
    congruent <- lift $ CC.areCongruent cc t1 t2
    when congruent $ do
      ret <- lift $ CC.explain cc t1 t2
      cs <- case ret of
        Just cs -> return cs
        -- areCongruent just reported t1 ~ t2, so an explanation must exist.
        -- Fail with a descriptive message instead of an opaque
        -- MonadFail pattern-match error.
        Nothing -> error "ToySolver.EUF.EUFSolver.check: should not happen"
      lift $ writeIORef (svExplanation solver) $! maybe cs (\c -> IntSet.insert c cs) cid
      -- Abort the traversal at the first violated disequality.
      throwE ()
  where
    cc = svCCSolver solver
-- | Ask the underlying congruence-closure solver whether two terms are
-- currently congruent.
areEqual :: Solver -> Term -> Term -> IO Bool
areEqual = CC.areCongruent . svCCSolver
-- | Return an explanation as a set of constraint IDs.
--
-- * @'explain' solver Nothing@ returns the contents of the explanation slot,
--   which 'check' writes when it finds a violated disequality.
-- * @'explain' solver (Just (t1, t2))@ returns the congruence-closure
--   solver's explanation for @t1 = t2@. If that solver returns no
--   explanation (presumably because the terms are not congruent), this
--   calls 'error'.
explain :: Solver -> Maybe (Term, Term) -> IO IntSet
explain solver query =
  case query of
    Nothing -> readIORef (svExplanation solver)
    Just (t1, t2) ->
      CC.explain (svCCSolver solver) t1 t2
        >>= maybe (error "ToySolver.EUF.EUFSolver.explain: should not happen") return
-- | Build a model by delegating to 'CC.getModel' on the underlying
-- congruence-closure solver.
getModel :: Solver -> IO Model
getModel solver = CC.getModel (svCCSolver solver)
-- | Decision level: the number of currently open backtrack points
-- (as reported by 'getCurrentLevel'). Level 0 is the root level.
type Level = Int
-- | Current decision level: the number of backtrack points pushed and not
-- yet popped.
getCurrentLevel :: Solver -> IO Level
getCurrentLevel = Vec.getSize . svBacktrackPoints
-- | Open a new backtrack point in the congruence-closure solver, and start
-- an empty record of the disequalities added at the new level.
pushBacktrackPoint :: Solver -> IO ()
pushBacktrackPoint solver =
  CC.pushBacktrackPoint (svCCSolver solver)
    >> Vec.push (svBacktrackPoints solver) Map.empty
-- | Pop the innermost backtrack point.
--
-- This pops the congruence-closure solver's backtrack point and removes
-- the disequalities that were recorded for the popped level. It calls
-- 'error' at the root level, where there is no backtrack point to pop.
popBacktrackPoint :: Solver -> IO ()
popBacktrackPoint solver = do
  level <- getCurrentLevel solver
  when (level == 0) $
    error "ToySolver.EUF.EUFSolver.popBacktrackPoint: root level"
  CC.popBacktrackPoint (svCCSolver solver)
  added <- Vec.unsafePop (svBacktrackPoints solver)
  modifyIORef' (svDisequalities solver) (`Map.difference` added)
-- | Convert a term to a flat term by delegating to 'CC.termToFlatTerm' on
-- the underlying congruence-closure solver.
--
-- The type is written as 'CC.FlatTerm' because the unqualified imports do
-- not bring @FlatTerm@ into scope; it resolves through the qualified
-- @CC@ import.
termToFlatTerm :: Solver -> Term -> IO CC.FlatTerm
termToFlatTerm solver = CC.termToFlatTerm (svCCSolver solver)
-- | Map a term to a function symbol by delegating to 'CC.termToFSym' on
-- the underlying congruence-closure solver.
termToFSym :: Solver -> Term -> IO FSym
termToFSym solver = CC.termToFSym (svCCSolver solver)
-- | Map a function symbol back to a term by delegating to 'CC.fsymToTerm'
-- on the underlying congruence-closure solver.
fsymToTerm :: Solver -> FSym -> IO Term
fsymToTerm solver = CC.fsymToTerm (svCCSolver solver)
-- | Map a function symbol to a flat term by delegating to
-- 'CC.fsymToFlatTerm' on the underlying congruence-closure solver.
--
-- The type is written as 'CC.FlatTerm' because the unqualified imports do
-- not bring @FlatTerm@ into scope.
fsymToFlatTerm :: Solver -> FSym -> IO CC.FlatTerm
fsymToFlatTerm solver = CC.fsymToFlatTerm (svCCSolver solver)
-- | Map a flat term to a function symbol by delegating to
-- 'CC.flatTermToFSym' on the underlying congruence-closure solver.
--
-- The type is written as 'CC.FlatTerm' because the unqualified imports do
-- not bring @FlatTerm@ into scope.
flatTermToFSym :: Solver -> CC.FlatTerm -> IO FSym
flatTermToFSym solver = CC.flatTermToFSym (svCCSolver solver)