{-# OPTIONS_HADDOCK show-extensions #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.EUF.EUFSolver
-- Copyright   :  (c) Masahiro Sakai 2015
-- License     :  BSD-style
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  unstable
-- Portability :  non-portable
--
-----------------------------------------------------------------------------
module ToySolver.EUF.EUFSolver
  ( -- * The @Solver@ type
    Solver
  , newSolver

  -- * Problem description
  , FSym
  , Term (..)
  , ConstrID
  , VAFun (..)
  , newFSym
  , newFun
  , newConst
  , assertEqual
  , assertEqual'
  , assertNotEqual
  , assertNotEqual'

  -- * Query
  , check
  , areEqual

  -- * Explanation
  , explain

  -- * Model Construction
  , Entity
  , EntityTuple
  , Model (..)
  , getModel
  , eval
  , evalAp

  -- * Backtracking
  , pushBacktrackPoint
  , popBacktrackPoint

  -- * Low-level operations
  , termToFlatTerm
  , termToFSym
  , fsymToTerm
  , fsymToFlatTerm
  , flatTermToFSym
  ) where

import Control.Monad
import Control.Monad.Trans
import Control.Monad.Trans.Except
import Data.Either
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.IORef

import qualified ToySolver.Internal.Data.Vec as Vec
import ToySolver.EUF.CongruenceClosure (FSym, Term (..), ConstrID, VAFun (..))
import ToySolver.EUF.CongruenceClosure (Model (..), Entity, EntityTuple, eval, evalAp)
import qualified ToySolver.EUF.CongruenceClosure as CC

-- | State of the EUF solver: a congruence-closure solver that handles the
-- equalities, plus the bookkeeping for disequalities, conflict explanations
-- and backtracking.
data Solver
  = Solver
  { svCCSolver :: !CC.Solver
    -- ^ Underlying congruence-closure solver; equalities are merged into it.
  , svDisequalities :: IORef (Map (Term, Term) (Maybe ConstrID))
    -- ^ Asserted disequalities. Each key is the pair ordered so that its
    -- first component is the smaller term; each value is the optional
    -- constraint ID of the assertion.
  , svExplanation :: IORef IntSet
    -- ^ Constraint IDs explaining the conflict most recently found by 'check'.
  , svBacktrackPoints :: !(Vec.Vec (Map (Term, Term) ()))
    -- ^ One entry per pushed backtrack point. Entry @lv - 1@ records the
    -- disequalities first asserted at level @lv@, so that they can be
    -- dropped from 'svDisequalities' when that level is popped.
  }

-- | Create a new EUF solver with no asserted equalities or disequalities and
-- no backtrack points.
--
-- The stored conflict explanation starts out as the empty set, so
-- @'explain' solver Nothing@ is safe to call even before 'check' has
-- reported a conflict; with an 'undefined' placeholder it would crash.
newSolver :: IO Solver
newSolver = do
  cc <- CC.newSolver
  deqs <- newIORef Map.empty
  expl <- newIORef IntSet.empty
  bp <- Vec.new
  return Solver
    { svCCSolver = cc
    , svDisequalities = deqs
    , svExplanation = expl
    , svBacktrackPoints = bp
    }

-- | Create a new function symbol by delegating to 'CC.newFSym' on the
-- underlying congruence-closure solver.
newFSym :: Solver -> IO FSym
newFSym = CC.newFSym . svCCSolver

-- | Create a new constant term by delegating to 'CC.newConst' on the
-- underlying congruence-closure solver.
newConst :: Solver -> IO Term
newConst = CC.newConst . svCCSolver

-- | Create a new function (in the variadic 'CC.VAFun' form chosen by the
-- caller's type) by delegating to 'CC.newFun' on the underlying
-- congruence-closure solver.
newFun :: CC.VAFun a => Solver -> IO a
newFun = CC.newFun . svCCSolver

-- | Assert @lhs = rhs@ without attaching a constraint ID.
assertEqual :: Solver -> Term -> Term -> IO ()
assertEqual solver lhs rhs = assertEqual' solver lhs rhs Nothing

-- | Assert the equality of two terms, optionally tagged with a constraint ID
-- that can later appear in explanations. Delegates to 'CC.merge''.
assertEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertEqual' = CC.merge' . svCCSolver

-- | Assert @lhs /= rhs@ without attaching a constraint ID.
assertNotEqual :: Solver -> Term -> Term -> IO ()
assertNotEqual solver lhs rhs = assertNotEqual' solver lhs rhs Nothing

-- | Assert the disequality of two terms, optionally tagged with a constraint
-- ID.
--
-- The pair is normalised so that the smaller term comes first. If that pair
-- is already recorded, nothing happens. Otherwise both sides are given a
-- function symbol, the pair is stored in 'svDisequalities', and, when above
-- the root level, it is also noted in the current backtrack point so that
-- 'popBacktrackPoint' can retract it.
assertNotEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertNotEqual' solver t1 t2 cid = do
  let (lo, hi)
        | t1 < t2 = (t1, t2)
        | otherwise = (t2, t1)
      key = (lo, hi)
  known <- readIORef (svDisequalities solver)
  unless (Map.member key known) $ do
    -- It is important to name the term for model generation
    _ <- termToFSym solver lo
    _ <- termToFSym solver hi
    writeIORef (svDisequalities solver) $! Map.insert key cid known
    level <- getCurrentLevel solver
    when (level /= 0) $
      Vec.unsafeModify' (svBacktrackPoints solver) (level - 1) (Map.insert key ())

-- | Check whether the asserted equalities and disequalities are consistent.
--
-- Scans the asserted disequalities in key order and stops at the first pair
-- @(t1, t2)@ whose sides are congruent in the underlying congruence closure.
-- For that pair it stores an explanation in 'svExplanation' and returns
-- 'False'. The explanation is the constraint IDs from 'CC.explain', plus the
-- disequality's own ID if it has one. Returns 'True' when no such pair
-- exists.
check :: Solver -> IO Bool
check solver = do
  ds <- readIORef (svDisequalities solver)
  fmap isRight $ runExceptT $ forM_ (Map.toList ds) $ \((t1, t2), cid) -> do
    b <- lift $ CC.areCongruent (svCCSolver solver) t1 t2
    when b $ do
      ret <- lift $ CC.explain (svCCSolver solver) t1 t2
      -- CC.explain is expected to succeed for congruent terms; fail loudly
      -- with a descriptive message instead of an opaque MonadFail error.
      cs <- case ret of
        Nothing -> error "ToySolver.EUF.EUFSolver.check: should not happen"
        Just expl -> return expl
      lift $ writeIORef (svExplanation solver) $!
        case cid of
          Nothing -> cs
          Just c -> IntSet.insert c cs
      -- Abort the scan: a conflict has been found.
      throwE ()

-- | Query whether two terms are currently congruent (delegates to
-- 'CC.areCongruent').
areEqual :: Solver -> Term -> Term -> IO Bool
areEqual = CC.areCongruent . svCCSolver

-- | Retrieve an explanation as a set of constraint IDs.
--
-- * @Nothing@: return the explanation last stored in 'svExplanation' by
--   'check' when it found a conflict.
-- * @Just (t1, t2)@: return the result of 'CC.explain' for the two terms.
--   Calls 'error' if 'CC.explain' returns 'Nothing'; presumably callers only
--   ask about terms that are congruent (NOTE(review): confirm).
explain :: Solver -> Maybe (Term,Term) -> IO IntSet
explain solver query =
  case query of
    Nothing -> readIORef (svExplanation solver)
    Just (t1, t2) ->
      maybe (error "ToySolver.EUF.EUFSolver.explain: should not happen") return
        =<< CC.explain (svCCSolver solver) t1 t2

-- -------------------------------------------------------------------
-- Model construction
-- -------------------------------------------------------------------

-- | Build a model from the current state of the underlying congruence-closure
-- solver (delegates to 'CC.getModel').
getModel :: Solver -> IO Model
getModel solver = CC.getModel (svCCSolver solver)

-- -------------------------------------------------------------------
-- Backtracking
-- -------------------------------------------------------------------

-- | Backtracking depth: the number of backtrack points currently pushed
-- (0 at the root).
type Level = Int

-- | Current backtracking level, i.e. the number of entries in
-- 'svBacktrackPoints'.
getCurrentLevel :: Solver -> IO Level
getCurrentLevel = Vec.getSize . svBacktrackPoints

-- | Open a new backtrack level. The underlying congruence-closure solver gets
-- its own backtrack point, and an empty set of level-local disequalities is
-- pushed.
pushBacktrackPoint :: Solver -> IO ()
pushBacktrackPoint solver =
  CC.pushBacktrackPoint (svCCSolver solver)
    >> Vec.push (svBacktrackPoints solver) Map.empty

-- | Close the innermost backtrack level.
--
-- Pops the underlying congruence-closure solver's backtrack point, then
-- removes from 'svDisequalities' every disequality recorded as asserted at
-- this level. Calls 'error' when already at the root level.
popBacktrackPoint :: Solver -> IO ()
popBacktrackPoint solver = do
  level <- getCurrentLevel solver
  when (level == 0) $
    error "ToySolver.EUF.EUFSolver.popBacktrackPoint: root level"
  CC.popBacktrackPoint (svCCSolver solver)
  retracted <- Vec.unsafePop (svBacktrackPoints solver)
  modifyIORef' (svDisequalities solver) (`Map.difference` retracted)

-- | Convert a term to its flat form in the underlying congruence-closure
-- solver (delegates to 'CC.termToFlatTerm').
--
-- 'CC.FlatTerm' is referenced qualified because this module imports it only
-- through the @CC@ alias.
termToFlatTerm :: Solver -> Term -> IO CC.FlatTerm
termToFlatTerm = CC.termToFlatTerm . svCCSolver

-- | Obtain the function symbol naming a term (delegates to 'CC.termToFSym').
termToFSym :: Solver -> Term -> IO FSym
termToFSym = CC.termToFSym . svCCSolver

-- | Convert a function symbol back to a term (delegates to 'CC.fsymToTerm').
fsymToTerm :: Solver -> FSym -> IO Term
fsymToTerm = CC.fsymToTerm . svCCSolver

-- | Convert a function symbol to its flat form (delegates to
-- 'CC.fsymToFlatTerm').
fsymToFlatTerm :: Solver -> FSym -> IO CC.FlatTerm
fsymToFlatTerm = CC.fsymToFlatTerm . svCCSolver

-- | Obtain the function symbol naming a flat term (delegates to
-- 'CC.flatTermToFSym').
flatTermToFSym :: Solver -> CC.FlatTerm -> IO FSym
flatTermToFSym = CC.flatTermToFSym . svCCSolver