{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Backend.SBV.Data.SMT.Solving
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Backend.SBV.Data.SMT.Solving
  ( ApproximationConfig (..),
    ExtraConfig (..),
    precise,
    approx,
    withTimeout,
    clearTimeout,
    withApprox,
    clearApprox,
    GrisetteSMTConfig (..),
    SolvingFailure (..),
    TermTy,
  )
where

import Control.DeepSeq
import Control.Exception
import Control.Monad.Except
import qualified Data.HashSet as S
import Data.Hashable
import Data.Kind
import Data.List (partition)
import Data.Maybe
import qualified Data.SBV as SBV
import Data.SBV.Control (Query)
import qualified Data.SBV.Control as SBVC
import GHC.TypeNats
import Grisette.Backend.SBV.Data.SMT.Lowering
import Grisette.Core.Data.BV
import Grisette.Core.Data.Class.Bool
import Grisette.Core.Data.Class.CEGISSolver
import Grisette.Core.Data.Class.Evaluate
import Grisette.Core.Data.Class.ExtractSymbolics
import Grisette.Core.Data.Class.GenSym
import Grisette.Core.Data.Class.ModelOps
import Grisette.Core.Data.Class.Solvable
import Grisette.Core.Data.Class.Solver
import Grisette.IR.SymPrim.Data.Prim.InternedTerm.InternedCtors
import Grisette.IR.SymPrim.Data.Prim.InternedTerm.Term
import Grisette.IR.SymPrim.Data.Prim.Model as PM
import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bool
import Grisette.IR.SymPrim.Data.SymPrim
import Grisette.IR.SymPrim.Data.TabularFun
import Language.Haskell.TH.Syntax (Lift)

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.IR.SymPrim
-- >>> import Grisette.Backend.SBV
-- >>> import Data.Proxy

-- | Selects the SBV type used for an unbounded integer: 'SBV.SInteger' when
-- the flag is 'True' (no approximation), otherwise a signed bit vector whose
-- width is the second argument.
type Aux :: Bool -> Nat -> Type
type family Aux o n where
  Aux 'True _ = SBV.SInteger
  Aux 'False width = SBV.SInt width

-- | Type-level test of whether a bit width is zero. A zero width is used to
-- mean "no approximation" for unbounded integers.
type IsZero :: Nat -> Bool
type family IsZero width where
  IsZero 0 = 'True
  IsZero _ = 'False

-- | Maps a Grisette primitive type to the SBV type it is lowered to, given the
-- integer approximation bit width (a width of @0@ lowers 'Integer' to an
-- unbounded SMT integer). Function types are lowered pointwise, and any type
-- not matched by an earlier equation is left unchanged.
type TermTy :: Nat -> Type -> Type
type family TermTy w t where
  TermTy _ Bool = SBV.SBool
  TermTy w Integer = Aux (IsZero w) w
  TermTy _ (IntN m) = SBV.SBV (SBV.IntN m)
  TermTy _ (WordN m) = SBV.SBV (SBV.WordN m)
  TermTy w (arg =-> res) = TermTy w arg -> TermTy w res
  TermTy w (arg --> res) = TermTy w arg -> TermTy w res
  TermTy _ other = other

-- | Configures how to approximate unbounded values.
--
-- For example, if we use @'Approx' ('Data.Proxy.Proxy' :: 'Data.Proxy.Proxy' 4)@
-- to approximate the following unbounded integer:
--
-- > (+ a 9)
--
-- We will get
--
-- > (bvadd a #x9)
--
-- Here the value 9 will be approximated to a 4-bit bit vector, and the
-- operation @bvadd@ will be used instead of @+@.
--
-- Note that this approximation may not be sound. See 'GrisetteSMTConfig' for
-- more details.
data ApproximationConfig (n :: Nat) where
  -- | No approximation: unbounded integers are lowered to SMT integers.
  NoApprox :: ApproximationConfig 0
  -- | Approximate unbounded integers with @n@-bit signed bit vectors, where
  -- @n@ is carried by the proxy argument.
  Approx ::
    (KnownNat n, IsZero n ~ 'False, SBV.BVIsNonZero n) =>
    p n ->
    ApproximationConfig n

-- | Grisette-specific solver settings that are not part of the SBV
-- 'SBV.SMTConfig'. The type index @i@ is the integer approximation bit width
-- (@0@ for precise reasoning).
data ExtraConfig (i :: Nat) = ExtraConfig
  { -- | Timeout for each solver call, passed unchanged to 'SBVC.timeout',
    -- which measures it in microseconds. CEGIS may call the solver multiple
    -- times and each call has its own timeout. 'Nothing' means no timeout.
    timeout :: Maybe Int,
    -- | Configures how to approximate unbounded integer values.
    integerApprox :: ApproximationConfig i
  }

-- | Extra configuration for precise reasoning: no timeout, and unbounded
-- integers are not approximated.
preciseExtraConfig :: ExtraConfig 0
preciseExtraConfig =
  ExtraConfig
    { timeout = Nothing,
      integerApprox = NoApprox
    }

-- | Extra configuration with no timeout that approximates unbounded integers
-- with bit vectors whose width is carried by the proxy argument.
approximateExtraConfig ::
  (KnownNat n, IsZero n ~ 'False, SBV.BVIsNonZero n) =>
  p n ->
  ExtraConfig n
approximateExtraConfig p =
  ExtraConfig
    { timeout = Nothing,
      integerApprox = Approx p
    }

-- | Solver configuration for the Grisette SBV backend.
-- A Grisette solver configuration consists of an SBV solver configuration and
-- extra settings: the per-call timeout and the reasoning precision.
--
-- Integers can be unbounded (mathematical integer) or bounded (machine
-- integer/bit vector). The two types of integers have their own use cases,
-- and should be used to model different systems.
-- However, the solvers are known to have bad performance on some unbounded
-- integer operations, for example, when reasoning about non-linear integer
-- arithmetic (e.g., multiplication or division),
-- the solver may not be able to get a result in a reasonable time.
-- In contrast, reasoning about bounded integers is usually more efficient.
--
-- To bridge the performance gap between the two types of integers, Grisette
-- allows modeling the system with unbounded integers and evaluating them with
-- infinite precision during symbolic evaluation, while translating them to
-- bit vectors for better performance when solving the queries.
--
-- For example, the Grisette term @5 * "a" :: 'SymInteger'@ should be translated
-- to the following SMT with the unbounded reasoning configuration (the term
-- is @t1@):
--
-- > (declare-fun a () Int)           ; declare symbolic constant a
-- > (define-fun c1 () Int 5)         ; define the concrete value 5
-- > (define-fun t1 () Int (* c1 a))  ; define the term
--
-- While with reasoning precision 4, it would be translated to the following
-- SMT (the term is @t1@):
--
-- > ; declare symbolic constant a, the type is a bit vector with bit width 4
-- > (declare-fun a () (_ BitVec 4))
-- > ; define the concrete value 5, translated to the bit vector #x5
-- > (define-fun c1 () (_ BitVec 4) #x5)
-- > ; define the term, using bit vector multiplication rather than integer
-- > ; multiplication
-- > (define-fun t1 () (_ BitVec 4) (bvmul c1 a))
--
-- This bounded translation can usually be solved faster than the unbounded
-- one, and should work well when no overflow is possible, in which case the
-- performance can be improved with almost no cost.
--
-- We must note that the bounded translation is an approximation and is __/not/__
-- __/sound/__. As the approximation happens only during the final translation,
-- the symbolic evaluation may aggressively optimize the term based on the
-- properties of mathematical integer arithmetic. This may cause the solver to
-- yield results that are incorrect under both unbounded and bounded semantics.
--
-- The following is an example that is correct under bounded semantics, but
-- incorrect under unbounded semantics:
--
-- >>> :set -XTypeApplications -XOverloadedStrings -XDataKinds
-- >>> let a = "a" :: SymInteger
-- >>> solve (precise z3) $ a >~ 7 &&~ a <~ 9
-- Right (Model {a -> 8 :: Integer})
-- >>> solve (approx (Proxy @4) z3) $ a >~ 7 &&~ a <~ 9
-- Left Unsat
--
-- This may be avoided by setting a large enough reasoning precision to prevent
-- overflows.
data GrisetteSMTConfig (i :: Nat) = GrisetteSMTConfig
  { -- | The underlying SBV solver configuration.
    sbvConfig :: SBV.SMTConfig,
    -- | Grisette-specific settings (timeout and integer approximation).
    extraConfig :: ExtraConfig i
  }

-- | A precise reasoning configuration with the given SBV solver configuration.
-- No timeout is set and unbounded integers are not approximated.
precise :: SBV.SMTConfig -> GrisetteSMTConfig 0
precise config = GrisetteSMTConfig config preciseExtraConfig

-- | An approximate reasoning configuration with the given SBV solver
-- configuration. Unbounded integers are approximated with @n@-bit signed bit
-- vectors, where @n@ is carried by the proxy argument. No timeout is set.
approx ::
  forall p n.
  (KnownNat n, IsZero n ~ 'False, SBV.BVIsNonZero n) =>
  p n ->
  SBV.SMTConfig ->
  GrisetteSMTConfig n
approx p config = GrisetteSMTConfig config (approximateExtraConfig p)

-- | Set the per-call timeout for the solver configuration.
--
-- The value is handed unchanged to 'SBVC.timeout', which interprets it in
-- microseconds. Every solver call made with this configuration gets its own
-- timeout. Other settings are left untouched.
withTimeout :: Int -> GrisetteSMTConfig i -> GrisetteSMTConfig i
withTimeout t config =
  config {extraConfig = (extraConfig config) {timeout = Just t}}

-- | Clear the timeout for the solver configuration, so that solver calls are
-- not time-limited. Other settings are left untouched.
clearTimeout :: GrisetteSMTConfig i -> GrisetteSMTConfig i
clearTimeout config =
  config {extraConfig = (extraConfig config) {timeout = Nothing}}

-- | Set the reasoning precision for the solver configuration: unbounded
-- integers will be approximated with @n@-bit bit vectors, where @n@ is
-- carried by the proxy argument. The timeout setting is preserved.
withApprox ::
  (KnownNat n, IsZero n ~ 'False, SBV.BVIsNonZero n) =>
  p n ->
  GrisetteSMTConfig i ->
  GrisetteSMTConfig n
withApprox p config =
  -- Type-changing record updates: 'integerApprox' is the only field of
  -- 'ExtraConfig' (and 'extraConfig' the only field of 'GrisetteSMTConfig')
  -- that mentions the precision index, so the index can change from i to n.
  config {extraConfig = (extraConfig config) {integerApprox = Approx p}}

-- | Clear the reasoning precision and perform precise reasoning with the
-- solver configuration. The timeout setting is preserved.
clearApprox :: GrisetteSMTConfig i -> GrisetteSMTConfig 0
clearApprox config =
  -- Type-changing record update from precision index i to 0.
  config {extraConfig = (extraConfig config) {integerApprox = NoApprox}}

-- | Reasons why a solver call did not produce (more) models.
data SolvingFailure
  = -- | Mirrors SBV's 'SBVC.DSat' check result, carrying its optional
    -- message string.
    DSat (Maybe String)
  | -- | The solver reported the formula unsatisfiable.
    Unsat
  | -- | The solver reported an unknown result.
    Unk
  | -- | Multi-model solving stopped because the requested number of models
    -- was collected (or a non-positive count was requested).
    ResultNumLimitReached
  | -- | An 'SBV.SBVException' was raised while solving.
    SolvingError SBV.SBVException
  deriving (Show)

-- | Convert a non-'SBVC.Sat' SBV check result into a 'SolvingFailure'.
--
-- Callers only use this after handling 'SBVC.Sat' themselves; passing
-- 'SBVC.Sat' is a programming error and raises an 'error'.
sbvCheckSatResult :: SBVC.CheckSatResult -> SolvingFailure
sbvCheckSatResult SBVC.Sat = error "Should not happen"
sbvCheckSatResult (SBVC.DSat msg) = DSat msg
sbvCheckSatResult SBVC.Unsat = Unsat
sbvCheckSatResult SBVC.Unk = Unk

-- | Wrap a query with 'SBVC.timeout' when the configuration has a timeout;
-- otherwise run the query unchanged.
applyTimeout :: GrisetteSMTConfig i -> Query a -> Query a
applyTimeout config q = case timeout (extraConfig config) of
  Nothing -> q
  Just t -> SBVC.timeout t q

-- | Solve a single boolean term with the given configuration.
--
-- Returns 'Right' the parsed model when the solver reports 'SBVC.Sat', and
-- 'Left' the corresponding 'SolvingFailure' otherwise. Any
-- 'SBV.SBVException' raised while solving is caught and returned as
-- @'Left' ('SolvingError' e)@.
solveTermWith ::
  forall integerBitWidth.
  GrisetteSMTConfig integerBitWidth ->
  Term Bool ->
  IO (Either SolvingFailure PM.Model)
solveTermWith config term =
  handle (return . Left . SolvingError) $
    SBV.runSMTWith (sbvConfig config) $ do
      -- Lower the Grisette term; 'symMap' relates Grisette symbols to the
      -- SBV variables and is needed to read the model back.
      (symMap, lowered) <- lowerSinglePrim config term
      SBVC.query $
        applyTimeout config $ do
          SBV.constrain lowered
          r <- SBVC.checkSat
          case r of
            SBVC.Sat -> do
              md <- SBVC.getModel
              return (Right $ parseModel config md symMap)
            _ -> return (Left $ sbvCheckSatResult r)

instance Solver (GrisetteSMTConfig n) SolvingFailure where
  -- | Solve a single formula; see 'solveTermWith'.
  solve :: GrisetteSMTConfig n -> SymBool -> IO (Either SolvingFailure PM.Model)
  solve config (SymBool t) = solveTermWith config t

  -- | Find up to @n@ distinct models of the formula within one solver session.
  --
  -- After each model, a blocking clause (some symbol of the formula must
  -- differ from its value in that model) is added, and the solver is asked
  -- again. The returned failure explains why enumeration stopped:
  -- 'ResultNumLimitReached' when @n@ models were found (or @n <= 0@),
  -- otherwise the check result that ended the search. If an
  -- 'SBV.SBVException' occurs, a warning is printed and @([], 'SolvingError' e)@
  -- is returned, discarding any models already found.
  solveMulti :: GrisetteSMTConfig n -> Int -> SymBool -> IO ([PM.Model], SolvingFailure)
  solveMulti config n s@(SymBool t)
    | n > 0 =
        handle
          ( \(x :: SBV.SBVException) -> do
              putStrLn "An SBV Exception occurred:"
              print x
              putStrLn $
                "Warning: Note that solveMulti does not fully support "
                  ++ "timeouts, and will return an empty list if the solver "
                  ++ "times out in any iteration."
              return ([], SolvingError x)
          )
          $ SBV.runSMTWith (sbvConfig config)
          $ do
            (newm, a) <- lowerSinglePrim config t
            SBVC.query $
              applyTimeout config $ do
                SBV.constrain a
                r <- SBVC.checkSat
                case r of
                  SBVC.Sat -> do
                    md <- SBVC.getModel
                    let model = parseModel config md newm
                    remainingModels n model newm
                  _ -> return ([], sbvCheckSatResult r)
    | otherwise = return ([], ResultNumLimitReached)
    where
      allSymbols :: SymbolSet
      allSymbols = extractSymbolics s

      -- Block the given model and ask for another one. The blocking clause
      -- is the disjunction of "symbol /= its value in md" over every symbol
      -- of the original formula, lowered incrementally using the cached map.
      next :: PM.Model -> SymBiMap -> Query (SymBiMap, Either SBVC.CheckSatResult PM.Model)
      next md origm = do
        let blockingTerm =
              S.foldl'
                ( \acc (SomeTypedSymbol _ v) ->
                    pevalOrTerm acc $
                      pevalNotTerm $
                        -- NOTE(review): assumes every symbol of the formula
                        -- has a value in the parsed model.
                        fromMaybe
                          (error "solveMulti: model has no value for a symbol of the formula")
                          (equation v md)
                )
                (conTerm False)
                (unSymbolSet allSymbols)
        (newm, lowered) <- lowerSinglePrimCached config blockingTerm origm
        SBV.constrain lowered
        r <- SBVC.checkSat
        case r of
          SBVC.Sat -> do
            md1 <- SBVC.getModel
            return (newm, Right $ parseModel config md1 newm)
          _ -> return (newm, Left r)

      -- Collect @n1@ models in total, starting from the already found @md@.
      remainingModels :: Int -> PM.Model -> SymBiMap -> Query ([PM.Model], SolvingFailure)
      remainingModels n1 md origm
        | n1 > 1 = do
            (newm, res) <- next md origm
            case res of
              Left failed -> return ([md], sbvCheckSatResult failed)
              Right mo -> do
                (rest, e) <- remainingModels (n1 - 1) mo newm
                return (md : rest, e)
        | otherwise = return ([md], ResultNumLimitReached)

  -- | Not implemented for this backend.
  solveAll :: GrisetteSMTConfig n -> SymBool -> IO [PM.Model]
  solveAll = error "solveAll: not implemented for GrisetteSMTConfig"

instance CEGISSolver (GrisetteSMTConfig n) SolvingFailure where
  cegisMultiInputs ::
    forall inputs spec.
    (ExtractSymbolics inputs, EvaluateSym inputs) =>
    GrisetteSMTConfig n ->
    [inputs] ->
    (inputs -> CEGISCondition) ->
    IO ([inputs], Either SolvingFailure PM.Model)
  cegisMultiInputs :: forall inputs spec.
(ExtractSymbolics inputs, EvaluateSym inputs) =>
GrisetteSMTConfig n
-> [inputs]
-> (inputs -> CEGISCondition)
-> IO ([inputs], Either SolvingFailure Model)
cegisMultiInputs GrisetteSMTConfig n
config [inputs]
inputs inputs -> CEGISCondition
func =
    case [inputs]
symInputs of
      [] -> do
        Either SolvingFailure Model
m <- forall config failure.
Solver config failure =>
config -> SymBool -> IO (Either failure Model)
solve GrisetteSMTConfig n
config ([inputs] -> SymBool
cexesAssertFun [inputs]
conInputs)
        forall (m :: * -> *) a. Monad m => a -> m a
return ([inputs]
conInputs, Either SolvingFailure Model
m)
      [inputs]
_ ->
        forall e a. Exception e => (e -> IO a) -> IO a -> IO a
handle
          ( \(SBVException
x :: SBV.SBVException) -> do
              forall a. Show a => a -> IO ()
print String
"An SBV Exception occurred:"
              forall a. Show a => a -> IO ()
print SBVException
x
              forall a. Show a => a -> IO ()
print forall a b. (a -> b) -> a -> b
$
                String
"Warning: Note that CEGIS procedures do not fully support "
                  forall a. [a] -> [a] -> [a]
++ String
"timeouts, and will return an empty counter example list if "
                  forall a. [a] -> [a] -> [a]
++ String
"the solver timeouts during guessing phase."
              forall (m :: * -> *) a. Monad m => a -> m a
return ([], forall a b. a -> Either a b
Left forall a b. (a -> b) -> a -> b
$ SBVException -> SolvingFailure
SolvingError SBVException
x)
          )
          forall a b. (a -> b) -> a -> b
$ SymBool
-> [inputs]
-> Model
-> [inputs]
-> SymBool
-> SymBool
-> [inputs]
-> IO ([inputs], Either SolvingFailure Model)
go1 ([inputs] -> SymBool
cexesAssertFun [inputs]
conInputs) [inputs]
conInputs (forall a. HasCallStack => String -> a
error String
"Should have at least one gen") [] (forall c t. Solvable c t => c -> t
con Bool
True) (forall c t. Solvable c t => c -> t
con Bool
True) [inputs]
symInputs
    where
      ([inputs]
conInputs, [inputs]
symInputs) = forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (forall symbolSet (typedSymbol :: * -> *).
SymbolSetOps symbolSet typedSymbol =>
symbolSet -> Bool
isEmptySet forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. ExtractSymbolics a => a -> SymbolSet
extractSymbolics) [inputs]
inputs
      go1 :: SymBool -> [inputs] -> PM.Model -> [inputs] -> SymBool -> SymBool -> [inputs] -> IO ([inputs], Either SolvingFailure PM.Model)
      go1 :: SymBool
-> [inputs]
-> Model
-> [inputs]
-> SymBool
-> SymBool
-> [inputs]
-> IO ([inputs], Either SolvingFailure Model)
go1 SymBool
cexFormula [inputs]
cexes Model
previousModel [inputs]
inputs SymBool
pre SymBool
post [inputs]
remainingSymInputs = do
        case [inputs]
remainingSymInputs of
          [] -> forall (m :: * -> *) a. Monad m => a -> m a
return ([inputs]
cexes, forall a b. b -> Either a b
Right Model
previousModel)
          inputs
newInput : [inputs]
vs -> do
            let CEGISCondition SymBool
nextPre SymBool
nextPost = inputs -> CEGISCondition
func inputs
newInput
            let finalPre :: SymBool
finalPre = SymBool
pre forall b. LogicalOp b => b -> b -> b
&&~ SymBool
nextPre
            let finalPost :: SymBool
finalPost = SymBool
post forall b. LogicalOp b => b -> b -> b
&&~ SymBool
nextPost
            ([inputs], Either SolvingFailure Model)
r <- SymBool
-> inputs
-> [inputs]
-> SymBool
-> SymBool
-> IO ([inputs], Either SolvingFailure Model)
go SymBool
cexFormula inputs
newInput (inputs
newInput forall a. a -> [a] -> [a]
: [inputs]
inputs) SymBool
finalPre SymBool
finalPost
            case ([inputs], Either SolvingFailure Model)
r of
              ([inputs]
newCexes, Left SolvingFailure
failure) -> forall (m :: * -> *) a. Monad m => a -> m a
return ([inputs]
cexes forall a. [a] -> [a] -> [a]
++ [inputs]
newCexes, forall a b. a -> Either a b
Left SolvingFailure
failure)
              ([inputs]
newCexes, Right Model
mo) -> do
                SymBool
-> [inputs]
-> Model
-> [inputs]
-> SymBool
-> SymBool
-> [inputs]
-> IO ([inputs], Either SolvingFailure Model)
go1
                  (SymBool
cexFormula forall b. LogicalOp b => b -> b -> b
&&~ [inputs] -> SymBool
cexesAssertFun [inputs]
newCexes)
                  ([inputs]
cexes forall a. [a] -> [a] -> [a]
++ [inputs]
newCexes)
                  Model
mo
                  (inputs
newInput forall a. a -> [a] -> [a]
: [inputs]
inputs)
                  SymBool
finalPre
                  SymBool
finalPost
                  [inputs]
vs
      cexAssertFun :: inputs -> SymBool
cexAssertFun inputs
input =
        let CEGISCondition SymBool
pre SymBool
post = inputs -> CEGISCondition
func inputs
input in SymBool
pre forall b. LogicalOp b => b -> b -> b
&&~ SymBool
post
      cexesAssertFun :: [inputs] -> SymBool
      cexesAssertFun :: [inputs] -> SymBool
cexesAssertFun = forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\SymBool
acc inputs
x -> SymBool
acc forall b. LogicalOp b => b -> b -> b
&&~ inputs -> SymBool
cexAssertFun inputs
x) (forall c t. Solvable c t => c -> t
con Bool
True)
      go ::
        SymBool ->
        inputs ->
        [inputs] ->
        SymBool ->
        SymBool ->
        IO ([inputs], Either SolvingFailure PM.Model)
      go :: SymBool
-> inputs
-> [inputs]
-> SymBool
-> SymBool
-> IO ([inputs], Either SolvingFailure Model)
go SymBool
cexFormula inputs
inputs [inputs]
allInputs SymBool
pre SymBool
post =
        forall a. SMTConfig -> Symbolic a -> IO a
SBV.runSMTWith (forall (i :: Nat). GrisetteSMTConfig i -> SMTConfig
sbvConfig GrisetteSMTConfig n
config) forall a b. (a -> b) -> a -> b
$ do
          let SymBool Term Bool
t = SymBool
phi forall b. LogicalOp b => b -> b -> b
&&~ SymBool
cexFormula
          (SymBiMap
newm, SBool
a) <- forall (integerBitWidth :: Nat) a (m :: * -> *).
(HasCallStack, SBVFreshMonad m) =>
GrisetteSMTConfig integerBitWidth
-> Term a -> m (SymBiMap, TermTy integerBitWidth a)
lowerSinglePrim GrisetteSMTConfig n
config Term Bool
t
          forall a. Query a -> Symbolic a
SBVC.query forall a b. (a -> b) -> a -> b
$
            forall (i :: Nat) a. GrisetteSMTConfig i -> Query a -> Query a
applyTimeout GrisetteSMTConfig n
config forall a b. (a -> b) -> a -> b
$
              forall a b. (a, b) -> b
snd forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> do
                forall (m :: * -> *). SolverContext m => SBool -> m ()
SBV.constrain SBool
a
                CheckSatResult
r <- Query CheckSatResult
SBVC.checkSat
                Either SolvingFailure Model
mr <- case CheckSatResult
r of
                  CheckSatResult
SBVC.Sat -> do
                    SMTModel
md <- Query SMTModel
SBVC.getModel
                    forall (m :: * -> *) a. Monad m => a -> m a
return forall a b. (a -> b) -> a -> b
$ forall a b. b -> Either a b
Right forall a b. (a -> b) -> a -> b
$ forall (integerBitWidth :: Nat).
GrisetteSMTConfig integerBitWidth -> SMTModel -> SymBiMap -> Model
parseModel GrisetteSMTConfig n
config SMTModel
md SymBiMap
newm
                  CheckSatResult
_ -> forall (m :: * -> *) a. Monad m => a -> m a
return forall a b. (a -> b) -> a -> b
$ forall a b. a -> Either a b
Left forall a b. (a -> b) -> a -> b
$ CheckSatResult -> SolvingFailure
sbvCheckSatResult CheckSatResult
r
                Either SolvingFailure Model
-> [inputs]
-> SymBiMap
-> QueryT IO (SymBiMap, ([inputs], Either SolvingFailure Model))
loop ((SymbolSet
forallSymbols forall model symbolSet (typedSymbol :: * -> *).
ModelOps model symbolSet typedSymbol =>
symbolSet -> model -> model
`exceptFor`) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Either SolvingFailure Model
mr) [] SymBiMap
newm
        where
          forallSymbols :: SymbolSet
          forallSymbols :: SymbolSet
forallSymbols = forall a. ExtractSymbolics a => a -> SymbolSet
extractSymbolics [inputs]
allInputs
          phi :: SymBool
phi = SymBool
pre forall b. LogicalOp b => b -> b -> b
&&~ SymBool
post
          negphi :: SymBool
negphi = SymBool
pre forall b. LogicalOp b => b -> b -> b
&&~ forall b. LogicalOp b => b -> b
nots SymBool
post
          check :: Model -> IO (Either SolvingFailure (inputs, PM.Model))
          check :: Model -> IO (Either SolvingFailure (inputs, Model))
check Model
candidate = do
            let evaluated :: SymBool
evaluated = forall a. EvaluateSym a => Bool -> Model -> a -> a
evaluateSym Bool
False Model
candidate SymBool
negphi
            Either SolvingFailure Model
r <- forall config failure.
Solver config failure =>
config -> SymBool -> IO (Either failure Model)
solve GrisetteSMTConfig n
config SymBool
evaluated
            forall (m :: * -> *) a. Monad m => a -> m a
return forall a b. (a -> b) -> a -> b
$ do
              Model
m <- Either SolvingFailure Model
r
              let newm :: Model
newm = forall model symbolSet (typedSymbol :: * -> *).
ModelOps model symbolSet typedSymbol =>
symbolSet -> model -> model
exact SymbolSet
forallSymbols Model
m
              forall (m :: * -> *) a. Monad m => a -> m a
return (forall a. EvaluateSym a => Bool -> Model -> a -> a
evaluateSym Bool
False Model
newm inputs
inputs, Model
newm)
          guess :: Model -> SymBiMap -> Query (SymBiMap, Either SolvingFailure PM.Model)
          guess :: Model -> SymBiMap -> Query (SymBiMap, Either SolvingFailure Model)
guess Model
candidate SymBiMap
origm = do
            let SymBool Term Bool
evaluated = forall a. EvaluateSym a => Bool -> Model -> a -> a
evaluateSym Bool
False Model
candidate SymBool
phi
            (SymBiMap
newm, SBool
lowered) <- forall (integerBitWidth :: Nat) a (m :: * -> *).
(HasCallStack, SBVFreshMonad m) =>
GrisetteSMTConfig integerBitWidth
-> Term a -> SymBiMap -> m (SymBiMap, TermTy integerBitWidth a)
lowerSinglePrimCached GrisetteSMTConfig n
config Term Bool
evaluated SymBiMap
origm
            forall (m :: * -> *). SolverContext m => SBool -> m ()
SBV.constrain SBool
lowered
            CheckSatResult
r <- Query CheckSatResult
SBVC.checkSat
            case CheckSatResult
r of
              CheckSatResult
SBVC.Sat -> do
                SMTModel
md <- Query SMTModel
SBVC.getModel
                let model :: Model
model = forall (integerBitWidth :: Nat).
GrisetteSMTConfig integerBitWidth -> SMTModel -> SymBiMap -> Model
parseModel GrisetteSMTConfig n
config SMTModel
md SymBiMap
newm
                forall (m :: * -> *) a. Monad m => a -> m a
return (SymBiMap
newm, forall a b. b -> Either a b
Right forall a b. (a -> b) -> a -> b
$ forall model symbolSet (typedSymbol :: * -> *).
ModelOps model symbolSet typedSymbol =>
symbolSet -> model -> model
exceptFor SymbolSet
forallSymbols Model
model)
              CheckSatResult
_ -> forall (m :: * -> *) a. Monad m => a -> m a
return (SymBiMap
newm, forall a b. a -> Either a b
Left forall a b. (a -> b) -> a -> b
$ CheckSatResult -> SolvingFailure
sbvCheckSatResult CheckSatResult
r)
          loop ::
            Either SolvingFailure PM.Model ->
            [inputs] ->
            SymBiMap ->
            Query (SymBiMap, ([inputs], Either SolvingFailure PM.Model))
          loop :: Either SolvingFailure Model
-> [inputs]
-> SymBiMap
-> QueryT IO (SymBiMap, ([inputs], Either SolvingFailure Model))
loop (Right Model
mo) [inputs]
cexes SymBiMap
origm = do
            Either SolvingFailure (inputs, Model)
r <- forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO forall a b. (a -> b) -> a -> b
$ Model -> IO (Either SolvingFailure (inputs, Model))
check Model
mo
            case Either SolvingFailure (inputs, Model)
r of
              Left SolvingFailure
Unsat -> forall (m :: * -> *) a. Monad m => a -> m a
return (SymBiMap
origm, ([inputs]
cexes, forall a b. b -> Either a b
Right Model
mo))
              Left SolvingFailure
v -> forall (m :: * -> *) a. Monad m => a -> m a
return (SymBiMap
origm, ([inputs]
cexes, forall a b. a -> Either a b
Left SolvingFailure
v))
              Right (inputs
cex, Model
cexm) -> do
                (SymBiMap
newm, Either SolvingFailure Model
res) <- Model -> SymBiMap -> Query (SymBiMap, Either SolvingFailure Model)
guess Model
cexm SymBiMap
origm
                Either SolvingFailure Model
-> [inputs]
-> SymBiMap
-> QueryT IO (SymBiMap, ([inputs], Either SolvingFailure Model))
loop Either SolvingFailure Model
res (inputs
cex forall a. a -> [a] -> [a]
: [inputs]
cexes) SymBiMap
newm
          loop (Left SolvingFailure
v) [inputs]
cexes SymBiMap
origm = forall (m :: * -> *) a. Monad m => a -> m a
return (SymBiMap
origm, ([inputs]
cexes, forall a b. a -> Either a b
Left SolvingFailure
v))

-- | A thin wrapper around 'Int'.
--
-- NOTE(review): the name suggests this is an internal marker/identifier used
-- by the CEGIS procedures in this module; confirm its role against its use
-- sites before relying on that reading.
--
-- 'Eq', 'Show', 'Ord' and 'Lift' are derived without an explicit strategy, so
-- GHC's default deriving strategy applies to them. 'Hashable' and 'NFData'
-- are reused directly from the underlying 'Int' via @newtype@ deriving.
newtype CegisInternal = CegisInternal Int
  deriving (Eq, Show, Ord, Lift)
  deriving newtype (Hashable, NFData)