{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}

-- | A module providing a backend that launches solvers as external processes.
module SMTLIB.Backends.Process
  ( Config (..),
    Handle (..),
    defaultConfig,
    new,
    close,
    with,
    toBackend,
    P.StdStream (..),
  )
where

import qualified Control.Exception as X
import Data.ByteString.Builder
  ( Builder,
    byteString,
    hPutBuilder,
    toLazyByteString,
  )
import qualified Data.ByteString.Char8 as BS
import GHC.IO.Exception (IOException (ioe_description))
import SMTLIB.Backends (Backend (..))
import qualified System.IO as IO
import qualified System.Process as P

-- | Configuration of the solver process launched by 'new'.
data Config = Config
  { exe :: String
  -- ^ The command to call to run the solver.
  , args :: [String]
  -- ^ Arguments to pass to the solver's command.
  , std_err :: P.StdStream
  -- ^ How to handle std_err of the solver.
  }

-- | By default, use Z3 as an external process and ignore its log messages.
defaultConfig :: Config
-- if you change this, make sure to also update the comment two lines above
-- as well as the one in @smtlib-backends-process/tests/Examples.hs@
defaultConfig =
  Config
    { exe = "z3",
      args = ["-in"],
      -- stderr goes to a pipe; nothing in this module reads from it
      std_err = P.CreatePipe
    }

-- | A running solver process together with the channels used to talk to it.
data Handle = Handle
  { process :: P.ProcessHandle
  -- ^ The process running the solver.
  , hIn :: IO.Handle
  -- ^ The input channel of the process.
  , hOut :: IO.Handle
  -- ^ The output channel of the process.
  , hMaybeErr :: Maybe IO.Handle
  -- ^ The error channel of the process, if one was created (see 'std_err').
  }

-- | Run a solver as a process.
--
-- The solver's input and output channels are pipes, switched to binary mode
-- with block buffering; its error channel is handled as requested by the
-- 'std_err' field of the configuration. If configuring the channels fails,
-- the freshly spawned process is cleaned up before the exception propagates.
new ::
  -- | The solver process' configuration.
  Config ->
  IO Handle
new Config {..} = decorateIOError "creating the solver process" $ do
  (Just hIn, Just hOut, hMaybeErr, process) <-
    P.createProcess
      (P.proc exe args)
        { P.std_in = P.CreatePipe,
          P.std_out = P.CreatePipe,
          P.std_err = std_err
        }
  -- the caller never gets a 'Handle' if this step fails, so it could not
  -- 'close' the process itself: release it here instead of leaking it
  mapM_ setupHandle [hIn, hOut]
    `X.onException` P.cleanupProcess (Just hIn, Just hOut, hMaybeErr, process)
  return $ Handle process hIn hOut hMaybeErr
  where
    -- binary mode with block buffering; 'write' flushes after each command
    setupHandle :: IO.Handle -> IO ()
    setupHandle h = do
      IO.hSetBinaryMode h True
      IO.hSetBuffering h $ IO.BlockBuffering Nothing

-- | Send a command to the process without reading its response.
--
-- The command is followed by a newline, and the input channel is flushed
-- afterwards so the solver receives it immediately.
write :: Handle -> Builder -> IO ()
write Handle {hIn = input} cmd = decorateIOError context $ do
  hPutBuilder input (cmd <> "\n")
  IO.hFlush input
  where
    context :: String
    context =
      "sending command " ++ show (toLazyByteString cmd) ++ " to the solver"

-- | Cleanup the process' resources, terminate it and wait for it to actually exit.
--
-- 'P.cleanupProcess' closes the channels and asks the process to terminate,
-- but does not itself block until the process has exited (it waits on a
-- separate thread), so the exit is awaited explicitly here.
close :: Handle -> IO ()
close Handle {..} = decorateIOError "closing the solver process" $ do
  P.cleanupProcess (Just hIn, Just hOut, hMaybeErr, process)
  -- the exit code is irrelevant: we asked the process to terminate
  _ <- P.waitForProcess process
  return ()

-- | Create a solver process, use it to make a computation and close it.
--
-- The process is closed with 'close' whether the computation returns
-- normally or throws.
with ::
  -- | The solver process' configuration.
  Config ->
  -- | The computation to run with the solver process
  (Handle -> IO a) ->
  IO a
with config useSolver = X.bracket acquire close useSolver
  where
    acquire = new config

infixr 5 :<

-- | Matches a non-empty strict 'BS.ByteString', binding its first character
-- and the remaining bytes. It does not match the empty string.
pattern (:<) :: Char -> BS.ByteString -> BS.ByteString
pattern hd :< tl <- (BS.uncons -> Just (hd, tl))

-- | Make the solver process into an SMT-LIB backend.
--
-- Both functions of the backend send the command with 'write'. The one that
-- expects an answer (@backendSend@) then reads lines from the solver's output
-- channel until it has seen a complete response; the other one
-- (@backendSend_@) only writes.
toBackend :: Handle -> Backend
toBackend handle = Backend backendSend backendSend_
  where
    backendSend_ :: Builder -> IO ()
    backendSend_ = write handle

    backendSend cmd = do
      -- exceptions are decorated inside the body of 'write'
      write handle cmd
      decorateIOError "reading solver's response" $
        toLazyByteString <$> continueNextLine (scanParen 0) mempty

    -- scanParen reads lines from the handle's output channel until it has
    -- detected a complete s-expression, i.e. a well-parenthesized word that
    -- may contain strings, quoted symbols, and comments.
    -- If we detect a ')' at depth 0 that is not enclosed in a string, a quoted
    -- symbol or a comment, we give up and return immediately.
    -- See also the SMT-LIB standard v2.6:
    -- https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf#part.2
    --
    -- Arguments: current nesting depth, all text read so far, and the part of
    -- the current line that has not been scanned yet.
    scanParen :: Int -> Builder -> BS.ByteString -> IO Builder
    scanParen depth acc ('(' :< more) = scanParen (depth + 1) acc more
    scanParen depth acc ('"' :< more) = do
      (acc', more') <- string acc more
      scanParen depth acc' more'
    scanParen depth acc ('|' :< more) = do
      (acc', more') <- quotedSymbol acc more
      scanParen depth acc' more'
    -- a comment runs until the end of the line
    scanParen depth acc (';' :< _) = continueNextLine (scanParen depth) acc
    scanParen depth acc (')' :< more)
      | depth <= 1 = return acc
      | otherwise = scanParen (depth - 1) acc more
    scanParen depth acc (_ :< more) = scanParen depth acc more
    -- end of the current line: done at depth 0, otherwise keep reading
    scanParen 0 acc _ = return acc
    scanParen depth acc _ = continueNextLine (scanParen depth) acc

    -- Skip the rest of a string literal whose opening '"' was consumed.
    -- A doubled '""' is an escaped quote inside the literal. Returns the
    -- accumulator and what follows the closing '"' on its line.
    string :: Builder -> BS.ByteString -> IO (Builder, BS.ByteString)
    string acc ('"' :< '"' :< more) = string acc more
    string acc ('"' :< more) = return (acc, more)
    string acc (_ :< more) = string acc more
    -- end of the current line
    string acc _ = continueNextLine string acc

    -- Skip the rest of a quoted symbol whose opening '|' was consumed. It
    -- ends at the next '|' (a '"' has no special meaning inside it). Returns
    -- the accumulator and what follows the closing '|' on its line.
    quotedSymbol :: Builder -> BS.ByteString -> IO (Builder, BS.ByteString)
    quotedSymbol acc ('|' :< more) = return (acc, more)
    quotedSymbol acc (_ :< more) = quotedSymbol acc more
    -- end of the current line
    quotedSymbol acc _ = continueNextLine quotedSymbol acc

    -- Read the next line from the solver's output channel, append it to the
    -- accumulator and continue scanning that line with @f@. Read errors are
    -- rethrown with the text read so far appended to their description.
    -- Lines are appended without any separator (hGetLine drops the line
    -- terminator), so a multi-line response comes back with its lines joined.
    continueNextLine :: (Builder -> BS.ByteString -> IO a) -> Builder -> IO a
    continueNextLine f acc = do
      next <-
        BS.hGetLine (hOut handle) `X.catch` \ex ->
          -- the record update fixes the type of ex to IOException
          X.throwIO
            ( ex
                { ioe_description =
                    ioe_description ex
                      ++ ": "
                      ++ show (toLazyByteString acc)
                }
            )
      f (acc <> byteString next) next

-- | Run an action. If it throws an 'IOException', rethrow it with its
-- description prefixed by this package's name and the given context.
-- Exceptions of other types propagate unchanged.
decorateIOError :: String -> IO a -> IO a
decorateIOError contextDescription action =
  action `X.catch` (X.throwIO . decorate)
  where
    decorate :: IOException -> IOException
    decorate ex =
      ex
        { ioe_description =
            "[smtlib-backends-process] while "
              ++ contextDescription
              ++ ": "
              ++ ioe_description ex
        }