{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ViewPatterns #-}
module Cryptol.Symbolic where
import Control.Monad.IO.Class
import Control.Monad (replicateM, when, zipWithM, foldM)
import Control.Monad.Writer (WriterT, runWriterT, tell, lift)
import Data.List (intercalate, genericLength)
import Data.IORef(IORef)
import qualified Control.Exception as X
import qualified Data.SBV.Dynamic as SBV
import Data.SBV (Timing(SaveTiming))
import Data.SBV.Internals (showTDiff)
import qualified Cryptol.ModuleSystem as M hiding (getPrimMap)
import qualified Cryptol.ModuleSystem.Env as M
import qualified Cryptol.ModuleSystem.Base as M
import qualified Cryptol.ModuleSystem.Monad as M
import Cryptol.Symbolic.Prims
import Cryptol.Symbolic.Value
import qualified Cryptol.Eval as Eval
import qualified Cryptol.Eval.Monad as Eval
import qualified Cryptol.Eval.Type as Eval
import qualified Cryptol.Eval.Value as Eval
import Cryptol.Eval.Env (GenEvalEnv(..))
import Cryptol.TypeCheck.AST
import Cryptol.Utils.Ident (Ident)
import Cryptol.Utils.PP
import Cryptol.Utils.Panic(panic)
import Cryptol.Utils.Logger(logPutStrLn)
import Prelude ()
import Prelude.Compat
import Data.Time (NominalDiffTime)
-- | Evaluation environment instantiated at the symbolic types 'SBool' and
-- 'SWord' (presumably the symbolic bit and word representations from
-- "Cryptol.Symbolic.Value" — confirm).
type EvalEnv = GenEvalEnv SBool SWord
-- | Solver back-ends selectable by name.
--
-- Both @"offline"@ and @"any"@ map to SBV's default configuration;
-- 'satProve' handles @"any"@ specially by asking SBV for every available
-- solver instead of looking the name up here.
proverConfigs :: [(String, SBV.SMTConfig)]
proverConfigs =
  [ ("cvc4", SBV.cvc4)
  , ("yices", SBV.yices)
  , ("z3", SBV.z3)
  , ("boolector", SBV.boolector)
  , ("mathsat", SBV.mathSAT)
  , ("abc", SBV.abc)
  , ("offline", SBV.defaultSMTCfg)
  , ("any", SBV.defaultSMTCfg)
  ]
-- | The names accepted by 'lookupProver', in table order.
proverNames :: [String]
proverNames = fmap fst proverConfigs
-- | Find the SBV configuration registered under the given name in
-- 'proverConfigs'.  An unknown name is an internal error ('panic').
lookupProver :: String -> SBV.SMTConfig
lookupProver s =
  maybe (panic "Cryptol.Symbolic" [ "invalid prover: " ++ s ]) id
        (lookup s proverConfigs)
-- | One satisfying assignment: for each argument of the predicate, its
-- Cryptol type, an expression denoting the chosen value, and the value.
type SatResult = [(Type, Expr, Eval.Value)]

-- | How many models a @:sat@ query asks for.  'SomeSat' @n@ becomes SBV's
-- @allSatMaxModelCount@ when a specific prover is named; 'AllSat' sets no
-- limit.
data SatNum = AllSat | SomeSat Int
  deriving (Show)

-- | Whether to search for satisfying assignments or to prove validity.
data QueryType = SatQuery SatNum | ProveQuery
  deriving (Show)
-- | Everything needed to run a @:sat@ / @:prove@ query (see 'satProve' and
-- 'satProveOffline').
data ProverCommand = ProverCommand {
    pcQueryType :: QueryType
    -- ^ Sat or prove, and for sat how many models to request.
  , pcProverName :: String
    -- ^ A key of 'proverConfigs'; @"any"@ means every solver SBV finds.
  , pcVerbose :: Bool
    -- ^ Log progress messages and turn on SBV's own verbosity.
  , pcValidate :: Bool
    -- ^ Passed to SBV as @validateModel@.
  , pcProverStats :: !(IORef ProverStats)
    -- ^ Where SBV stores solver timing ('SaveTiming').
  , pcExtraDecls :: [DeclGroup]
    -- ^ Declarations evaluated in addition to those of the loaded modules.
  , pcSmtFile :: Maybe FilePath
    -- ^ SBV transcript file; only applied when a specific prover is named.
  , pcExpr :: Expr
    -- ^ The predicate to check.
  , pcSchema :: Schema
    -- ^ Type of 'pcExpr'; must be a monomorphic predicate (see
    -- 'predArgTypes').
  }
-- | Solver timing information, filled in by SBV via 'SaveTiming'.
type ProverStats = NominalDiffTime

-- | Interpreted outcome of a solver query.
data ProverResult = AllSatResult [SatResult]
                    -- ^ Satisfiable: one entry per model the solver returned.
                  | ThmResult [Type]
                    -- ^ The solver answered 'SBV.Unsatisfiable' (or returned no
                    -- results); carries the predicate's argument types.
                  | EmptyResult
                    -- ^ NOTE(review): not constructed by 'satProve'; presumably
                    -- used by callers outside this section — confirm.
                  | ProverError String
                    -- ^ Schema error, evaluation failure, or an unexpected
                    -- solver answer, rendered as text.
-- | The single solver answer inside an SBV sat result, as a list.
satSMTResults :: SBV.SatResult -> [SBV.SMTResult]
satSMTResults (SBV.SatResult r) = pure r

-- | All models of an SBV all-sat result; the three leading components of
-- the wrapped tuple are ignored.
allSatSMTResults :: SBV.AllSatResult -> [SBV.SMTResult]
allSatSMTResults res =
  case res of
    SBV.AllSatResult (_, _, _, models) -> models

-- | The single solver answer inside an SBV theorem result, as a list.
thmSMTResults :: SBV.ThmResult -> [SBV.SMTResult]
thmSMTResults (SBV.ThmResult r) = pure r
-- | A module command that reports @msg@ as a 'ProverError'.
--
-- The command succeeds (@Right@), returns the module environment it was
-- given unchanged, reports no solver, and yields an empty second component
-- (presumably the warning list — confirm against 'M.ModuleCmd').
proverError :: String -> M.ModuleCmd (Maybe SBV.Solver, ProverResult)
proverError msg (_evalOpts, env) = pure (Right (outcome, env), [])
  where
    outcome = (Nothing, ProverError msg)
-- | Run a @:sat@ or @:prove@ query.
--
-- The predicate 'pcExpr' (of type 'pcSchema') is applied to fresh solver
-- variables, one group per argument: existential for @:sat@, universal for
-- @:prove@.  The resulting formula goes to the solver(s) selected by
-- 'pcProverName'.
--
-- * A prove query, or a sat query with @SomeSat 1@, is sent to all selected
--   solvers through SBV's @proveWithAny@ / @satWithAny@.
-- * Any other sat query uses @allSatWith@ and needs exactly one solver.
--   With several solvers (@prover=any@), every solver is reported as a
--   'SBV.ProofError'.
--
-- Returns the name of the solver that answered (if any) and the
-- interpreted result.  A stack overflow during symbolic evaluation is
-- turned into a 'ProverError' by 'protectStack'.
satProve :: ProverCommand -> M.ModuleCmd (Maybe SBV.Solver, ProverResult)
satProve ProverCommand {..} =
  protectStack proverError $ \(evo, modEnv) ->
    M.runModuleM (evo, modEnv) $ do
      let (isSat, mSatNum) = case pcQueryType of
            ProveQuery -> (False, Nothing)
            SatQuery sn -> case sn of
              SomeSat n -> (True, Just n)
              AllSat -> (True, Nothing)
      let extDgs = allDeclGroups modEnv ++ pcExtraDecls

      -- A named prover gets the transcript file and model limit; "any"
      -- uses every solver SBV can find, with its default settings.
      provers <-
        case pcProverName of
          "any" -> M.io SBV.sbvAvailableSolvers
          _ -> return [ (lookupProver pcProverName)
                          { SBV.transcript = pcSmtFile
                          , SBV.allSatMaxModelCount = mSatNum
                          } ]

      -- Settings that must apply to every solver that is run.
      let provers' = [ p { SBV.timing = SaveTiming pcProverStats
                         , SBV.verbose = pcVerbose
                         , SBV.validateModel = pcValidate
                         }
                     | p <- provers ]

      -- Arguments are existential for :sat and universal for :prove.
      let tyFn = if isSat then existsFinType else forallFinType
      let lPutStrLn = M.withLogger logPutStrLn
      let doEval :: MonadIO m => Eval.Eval a -> m a
          doEval m = liftIO $ Eval.runEval evo m

      -- Run a query that needs exactly one solver (the allSat path).
      let runProver fn tag e =
            -- Use provers' (not provers) so the timing, verbosity and
            -- model-validation settings are honoured on this path too.
            case provers' of
              [prover] -> do
                when pcVerbose $
                  lPutStrLn $ "Trying proof with " ++
                    show (SBV.name (SBV.solver prover))
                res <- M.io (fn prover e)
                when pcVerbose $
                  lPutStrLn $ "Got result from " ++
                    show (SBV.name (SBV.solver prover))
                return (Just (SBV.name (SBV.solver prover)), tag res)
              _ ->
                return ( Nothing
                       , [ SBV.ProofError
                             prover
                             [":sat with option prover=any requires option satNum=1"]
                             Nothing
                         | prover <- provers' ]
                       )
          -- Run a query on all selected solvers; SBV reports which solver
          -- answered and the elapsed time.
          runProvers fn tag e = do
            when pcVerbose $
              lPutStrLn $ "Trying proof with " ++
                intercalate ", " (map (show . SBV.name . SBV.solver) provers')
            (firstProver, timeElapsed, res) <- M.io (fn provers' e)
            when pcVerbose $
              lPutStrLn $ "Got result from " ++ show firstProver ++
                ", time: " ++ showTDiff timeElapsed
            return (Just firstProver, tag res)

      let runFn = case pcQueryType of
            ProveQuery -> runProvers SBV.proveWithAny thmSMTResults
            SatQuery sn -> case sn of
              SomeSat 1 -> runProvers SBV.satWithAny satSMTResults
              _ -> runProver SBV.allSatWith allSatSMTResults

      -- How an argument constraint is combined with the predicate: it is
      -- conjoined for :sat and becomes a hypothesis (asm ==> b) for :prove.
      let addAsm = case pcQueryType of
            ProveQuery -> \x y -> SBV.svOr (SBV.svNot x) y
            SatQuery _ -> \x y -> SBV.svAnd x y

      case predArgTypes pcSchema of
        Left msg -> return (Nothing, ProverError msg)
        Right ts -> do
          when pcVerbose $ lPutStrLn "Simulating..."
          v <- doEval $ do
            env <- Eval.evalDecls extDgs mempty
            Eval.evalExpr env pcExpr
          prims <- M.getPrimMap
          runRes <- runFn $ do
            (args, asms) <- runWriterT (mapM tyFn ts)
            b <- doEval (fromVBit <$> foldM fromVFun v (map Eval.ready args))
            return (foldr addAsm b asms)
          let (firstProver, results) = runRes
              sattys = unFinType <$> ts
              -- Decode one satisfying model into (type, expression, value)
              -- triples, one per predicate argument.
              mkTevs result = do
                -- Bound lazily: the assignment is only demanded when
                -- parseValues has arguments to decode.
                let cvs = case SBV.getModelAssignment result of
                      Right (_, assignment) -> assignment
                      Left err -> panic "Cryptol.Symbolic.sat"
                                    [ "unable to read model assignment", err ]
                    (vs, _) = parseValues ts cvs
                satexprs <- doEval (zipWithM (Eval.toExpr prims) sattys vs)
                case zip3 sattys <$> sequence satexprs <*> pure vs of
                  Nothing ->
                    panic "Cryptol.Symbolic.sat"
                      [ "unable to make assignment into expression" ]
                  Just tevs -> return tevs
              -- Render unexpected solver answers for the error message.
              rshow rs
                | isSat = show (SBV.AllSatResult (False, False, False, rs))
                | r : _ <- rs = show (SBV.ThmResult r)
                | otherwise = panic "Cryptol.Symbolic.satProve"
                                [ "no prover results to report" ]
          esatexprs <- case results of
            -- allSat may return several models, as long as they are
            -- satisfiable.
            (SBV.Satisfiable {} : _) -> AllSatResult <$> mapM mkTevs results
            [SBV.Unsatisfiable {}] -> return $ ThmResult sattys
            [] -> return $ ThmResult sattys
            _ -> return $ ProverError (rshow results)
          return (firstProver, esatexprs)
-- | Produce the solver input for a @:sat@ or @:prove@ query without running
-- a solver.
--
-- Returns @Right smtlib@ with the benchmark text from
-- 'SBV.generateSMTBenchmark'.  Returns @Left msg@ when the schema is not a
-- monomorphic predicate over 'FinType's, or when symbolic evaluation
-- overflows the stack.  The module environment is passed through unchanged.
satProveOffline :: ProverCommand -> M.ModuleCmd (Either String String)
satProveOffline ProverCommand {..} =
  protectStack offlineError $ \(evOpts, modEnv) -> do
    let extDgs = allDeclGroups modEnv ++ pcExtraDecls
        finish r = return (Right (r, modEnv), [])
    case predArgTypes pcSchema of
      Left msg -> finish (Left msg)
      Right ts -> do
        when pcVerbose $
          logPutStrLn (Eval.evalLogger evOpts) "Simulating..."
        v <- liftIO $ Eval.runEval evOpts $ do
          env <- Eval.evalDecls extDgs mempty
          Eval.evalExpr env pcExpr
        smtlib <- SBV.generateSMTBenchmark isSat $ do
          (args, asms) <- runWriterT (mapM tyFn ts)
          b <- liftIO $ Eval.runEval evOpts
                 (fromVBit <$> foldM fromVFun v (map Eval.ready args))
          return (foldr addAsm b asms)
        finish (Right smtlib)
  where
    offlineError msg (_, env) = return (Right (Left msg, env), [])
    isSat = case pcQueryType of
      ProveQuery -> False
      SatQuery _ -> True
    -- Arguments are existential for :sat and universal for :prove.
    tyFn = if isSat then existsFinType else forallFinType
    -- Argument constraints are conjoined for :sat and become hypotheses
    -- (asm ==> b) for :prove.
    addAsm
      | isSat = SBV.svAnd
      | otherwise = \x y -> SBV.svOr (SBV.svNot x) y
-- | Run a module command.  If it throws 'X.StackOverflow', run the command
-- built by @mkErr@ instead, passing it a fixed message and the same input.
-- Any other exception propagates.
protectStack :: (String -> M.ModuleCmd a)
             -> M.ModuleCmd a
             -> M.ModuleCmd a
protectStack mkErr cmd modEnv =
  X.handleJust onlyStackOverflow (\() -> mkErr overflowMsg modEnv) (cmd modEnv)
  where
    onlyStackOverflow err =
      case err of
        X.StackOverflow -> Just ()
        _ -> Nothing
    overflowMsg = "Symbolic evaluation failed to terminate."
-- | Decode one value per 'FinType' from a flat list of SBV model values,
-- consuming them left to right.  Returns the decoded values and the
-- unconsumed remainder.
parseValues :: [FinType] -> [SBV.CV] -> ([Eval.Value], [SBV.CV])
parseValues [] cvs = ([], cvs)
parseValues (t : ts) cvs =
  let (v, afterFirst) = parseValue t cvs
      (vs, leftover) = parseValues ts afterFirst
  in (v : vs, leftover)
-- | Decode a single value of the given 'FinType' from the front of a list
-- of SBV model values; returns the value and the unconsumed remainder.
--
-- Words are read as one bounded value when SBV can parse one.  Otherwise
-- they are assembled bit by bit.  'FTIntMod' values are read as plain
-- integers.
parseValue :: FinType -> [SBV.CV] -> (Eval.Value, [SBV.CV])
parseValue fty cvs =
  case fty of
    FTBit -> case cvs of
      [] -> panic "Cryptol.Symbolic.parseValue" [ "empty FTBit" ]
      cv : rest -> (Eval.VBit (SBV.cvToBool cv), rest)
    FTInteger -> case SBV.genParse SBV.KUnbounded cvs of
      Just (x, rest) -> (Eval.VInteger x, rest)
      Nothing -> panic "Cryptol.Symbolic.parseValue" [ "no integer" ]
    FTIntMod _ -> parseValue FTInteger cvs
    FTSeq 0 FTBit -> (Eval.word 0 0, cvs)
    FTSeq n FTBit -> case SBV.genParse (SBV.KBounded False n) cvs of
      Just (x, rest) -> (Eval.word (toInteger n) x, rest)
      Nothing ->
        let (bits, rest) = parseValues (replicate n FTBit) cvs
        in ( VWord (genericLength bits) $ return $ Eval.WordVal $
               Eval.packWord (map fromVBit bits)
           , rest )
    FTSeq n t ->
      let (vs, rest) = parseValues (replicate n t) cvs
      in (Eval.VSeq (toInteger n) $ Eval.finiteSeqMap (map Eval.ready vs), rest)
    FTTuple ts ->
      let (vs, rest) = parseValues ts cvs
      in (Eval.VTuple (map Eval.ready vs), rest)
    FTRecord fs ->
      let (ns, ts) = unzip fs
          (vs, rest) = parseValues ts cvs
      in (Eval.VRecord (zip ns (map Eval.ready vs)), rest)
-- | Every declaration group of every loaded non-parameterized module, in
-- module order.
allDeclGroups :: M.ModuleEnv -> [DeclGroup]
allDeclGroups env =
  [ dg | m <- M.loadedNonParamModules env, dg <- mDecls m ]
-- | Cryptol types whose values are represented by a fixed number of
-- symbolic variables (see 'forallFinType' / 'existsFinType').  'finType'
-- converts from 'TValue' and 'unFinType' converts back to 'Type'.
data FinType
  = FTBit
  | FTInteger
  | FTIntMod Integer
    -- ^ Integers constrained to @[0, n)@ (see 'inBoundsIntMod').
  | FTSeq Int FinType
    -- ^ Sequence of a fixed length.
  | FTTuple [FinType]
  | FTRecord [(Ident, FinType)]
-- | Convert a length to 'Int' when it is non-negative and fits in an 'Int';
-- otherwise 'Nothing'.
numType :: Integer -> Maybe Int
numType n
  | n < 0 = Nothing
  | n > toInteger (maxBound :: Int) = Nothing
  | otherwise = Just (fromInteger n)
-- | Classify an evaluated type as a 'FinType'.  Returns 'Nothing' for types
-- without a fixed symbolic representation (e.g. functions), and for
-- sequences whose length does not fit in an 'Int' (see 'numType').
finType :: TValue -> Maybe FinType
finType Eval.TVBit = Just FTBit
finType Eval.TVInteger = Just FTInteger
finType (Eval.TVIntMod n) = Just (FTIntMod n)
finType (Eval.TVSeq n t) = FTSeq <$> numType n <*> finType t
finType (Eval.TVTuple ts) = FTTuple <$> traverse finType ts
finType (Eval.TVRec fields) = FTRecord <$> traverse (traverseSnd finType) fields
finType _ = Nothing
-- | Translate a 'FinType' back into the corresponding Cryptol 'Type'.
unFinType :: FinType -> Type
unFinType FTBit = tBit
unFinType FTInteger = tInteger
unFinType (FTIntMod n) = tIntMod (tNum n)
unFinType (FTSeq len elt) = tSeq (tNum len) (unFinType elt)
unFinType (FTTuple ftys) = tTuple (map unFinType ftys)
unFinType (FTRecord fs) =
  tRec [ (fst field, unFinType (snd field)) | field <- fs ]
-- | Argument types of a predicate schema of the form
-- @t1 -> ... -> tn -> Bit@.
--
-- The schema must have no type parameters and no constraints, and every
-- @ti@ must be a 'FinType'.  Otherwise a human-readable error is returned.
predArgTypes :: Schema -> Either String [FinType]
predArgTypes schema@(Forall ts ps ty)
  | not (null ts && null ps) =
      Left $ "Not a monomorphic type:\n" ++ show (pp schema)
  | Right tv <- Eval.evalType mempty ty
  , Just fts <- argTypes tv =
      Right fts
  | otherwise =
      Left $ "Not a valid predicate type:\n" ++ show (pp schema)
  where
    argTypes :: TValue -> Maybe [FinType]
    argTypes Eval.TVBit = Just []
    argTypes (Eval.TVFun dom cod) = (:) <$> finType dom <*> argTypes cod
    argTypes _ = Nothing
-- | Symbolic constraint @0 <= x && x < n@.
inBoundsIntMod :: Integer -> SInteger -> SBool
inBoundsIntMod n x = SBV.svAnd atLeastZero belowModulus
  where
    atLeastZero = SBV.svLessEq (Eval.integerLit 0) x
    belowModulus = SBV.svLessThan x (Eval.integerLit n)
-- | Build a symbolic value of the given type from fresh universally
-- quantified variables.  Range constraints required by 'FTIntMod' are
-- emitted through the writer instead of being asserted here.
forallFinType :: FinType -> WriterT [SBool] SBV.Symbolic Value
forallFinType FTBit = VBit <$> lift forallSBool_
forallFinType FTInteger = VInteger <$> lift forallSInteger_
forallFinType (FTIntMod n) = do
  x <- lift forallSInteger_
  tell [inBoundsIntMod n x]
  return (VInteger x)
forallFinType (FTSeq 0 FTBit) = return $ Eval.word 0 0
forallFinType (FTSeq n FTBit) = do
  w <- lift (forallBV_ n)
  return $ VWord (toInteger n) (return (Eval.WordVal w))
forallFinType (FTSeq n t) = do
  vs <- replicateM n (forallFinType t)
  return $ VSeq (toInteger n) $ Eval.finiteSeqMap (map Eval.ready vs)
forallFinType (FTTuple ts) =
  VTuple <$> traverse (fmap Eval.ready . forallFinType) ts
forallFinType (FTRecord fs) =
  VRecord <$> traverse (traverseSnd (fmap Eval.ready . forallFinType)) fs
-- | Build a symbolic value of the given type from fresh existentially
-- quantified variables.  Range constraints required by 'FTIntMod' are
-- emitted through the writer instead of being asserted here.
existsFinType :: FinType -> WriterT [SBool] SBV.Symbolic Value
existsFinType FTBit = VBit <$> lift existsSBool_
existsFinType FTInteger = VInteger <$> lift existsSInteger_
existsFinType (FTIntMod n) = do
  x <- lift existsSInteger_
  tell [inBoundsIntMod n x]
  return (VInteger x)
existsFinType (FTSeq 0 FTBit) = return $ Eval.word 0 0
existsFinType (FTSeq n FTBit) = do
  w <- lift (existsBV_ n)
  return $ VWord (toInteger n) (return (Eval.WordVal w))
existsFinType (FTSeq n t) = do
  vs <- replicateM n (existsFinType t)
  return $ VSeq (toInteger n) $ Eval.finiteSeqMap (map Eval.ready vs)
existsFinType (FTTuple ts) =
  VTuple <$> traverse (fmap Eval.ready . existsFinType) ts
existsFinType (FTRecord fs) =
  VRecord <$> traverse (traverseSnd (fmap Eval.ready . existsFinType)) fs