{-# LANGUAGE ForeignFunctionInterface #-}
module Picosat (
solve,
solveAll,
unsafeSolve,
unsafeSolveAll,
Picosat,
Solution(..),
evalScopedPicosat,
addBaseClauses,
withScopedClauses,
scopedAllSolutions,
scopedSolutionWithAssumptions
) where
import Control.Exception (bracket)
import Control.Monad
import Control.Monad.IO.Class
import Foreign.C.Types
import Foreign.Ptr
import System.IO.Unsafe (unsafePerformIO)

import Control.Monad.Trans.State.Strict
import qualified Data.Set as S
-- Raw bindings to the PicoSAT C API.
--
-- Every binding except 'picosat_sat' is presumably a short bookkeeping call
-- and uses the cheaper @unsafe@ calling convention. 'picosat_sat' performs
-- the actual search and may run for an arbitrarily long time. An @unsafe@
-- foreign call cannot be preempted, and it stalls garbage collection and all
-- other Haskell threads until it returns, so that one is imported @safe@.

-- | Allocate a new solver instance ('withPicosat' pairs it with
-- 'picosat_reset').
foreign import ccall unsafe "picosat_init" picosat_init
  :: IO Picosat

-- | Release a solver instance; the handle must not be used afterwards.
foreign import ccall unsafe "picosat_reset" picosat_reset
  :: Picosat -> IO ()

-- | Add one literal to the clause under construction; a literal of 0
-- terminates the clause (see 'addClause'). The result is ignored here.
foreign import ccall unsafe "picosat_add" picosat_add
  :: Picosat -> CInt -> IO CInt

-- | Number of variables the solver knows about; 'getSolution' reads
-- variables @1 .. n@.
foreign import ccall unsafe "picosat_variables" picosat_variables
  :: Picosat -> IO CInt

-- | Run the search. The second argument is passed as -1 by 'solution'
-- (presumably "no decision limit" — confirm against picosat.h). The result is
-- compared with 'unknown', 'satisfiable' and 'unsatisfiable'.
-- Imported @safe@ because the search can be long-running (see above).
foreign import ccall safe "picosat_sat" picosat_sat
  :: Picosat -> CInt -> IO CInt

-- | Value of a variable in the current model; 'getSolution' multiplies
-- the variable index by this result.
foreign import ccall unsafe "picosat_deref" picosat_deref
  :: Picosat -> CInt -> IO CInt

-- | Open a new clause context; the returned value is recorded by
-- 'withScope' as a context variable.
foreign import ccall unsafe "picosat_push" picosat_push
  :: Picosat -> IO CInt

-- | Close the innermost clause context; 'withScope' ignores the result.
foreign import ccall unsafe "picosat_pop" picosat_pop
  :: Picosat -> IO CInt

-- | Register an assumption literal before the next solve.
foreign import ccall unsafe "picosat_assume" picosat_assume
  :: Picosat -> CInt -> IO ()

-- | Opaque PicoSAT solver handle, treated as an untyped C pointer.
type Picosat = Ptr ()
-- | Run an action with a freshly initialised PicoSAT instance.
--
-- The instance is created with 'picosat_init' and released with
-- 'picosat_reset'. Using 'bracket' guarantees the release also happens when
-- the action throws. Previously an exception leaked the native solver.
-- The handle must not escape the action.
withPicosat :: (Picosat -> IO a) -> IO a
withPicosat = bracket picosat_init picosat_reset
-- | Result codes of 'picosat_sat' that 'solution' distinguishes:
-- 0 maps to 'Unknown', 10 to a model, 20 to 'Unsatisfiable'.
unknown, satisfiable, unsatisfiable :: CInt
(unknown, satisfiable, unsatisfiable) = (0, 10, 20)
-- | Outcome of one PicoSAT query.
--
-- Constructor order matters: the derived 'Ord' ranks every 'Solution'
-- below 'Unsatisfiable', which is below 'Unknown'.
data Solution
  = Solution [Int]  -- ^ A model as signed literals, one per reported variable.
  | Unsatisfiable   -- ^ 'picosat_sat' returned its unsatisfiable code (20).
  | Unknown         -- ^ 'picosat_sat' returned its unknown code (0).
  deriving (Show, Eq, Ord)
-- | Add one clause to the solver.
--
-- Each literal is handed to 'picosat_add', followed by a terminating 0.
-- The integer results of 'picosat_add' are discarded. An empty list adds
-- the empty clause.
--
-- Because 0 is the clause terminator, a 0 inside @cl@ would silently split
-- it into two separate clauses. Such input is rejected with an 'IOError'
-- before anything is sent to the solver.
addClause :: Picosat -> [Int] -> IO ()
addClause pico cl
  | 0 `elem` cl =
      ioError (userError "Picosat.addClause: literal 0 is not allowed in a clause")
  | otherwise = do
      mapM_ (void . picosat_add pico . fromIntegral) cl
      void (picosat_add pico 0)
-- | Add each clause in list order, one 'addClause' call per clause.
addClauses :: Picosat -> [[Int]] -> IO ()
addClauses pico clauses = forM_ clauses (addClause pico)
-- | Read the current model from the solver.
--
-- For every variable @v@ in @1 .. picosat_variables@, the entry is
-- @v * picosat_deref v@. The sign of each entry therefore follows
-- 'picosat_deref' (presumably +1 true / -1 false / 0 unknown — confirm
-- against picosat.h). No variable is filtered out here.
getSolution :: Picosat -> IO Solution
getSolution pico = do
  count <- picosat_variables pico
  let literal v = (v *) <$> picosat_deref pico v
  Solution . map fromIntegral <$> mapM literal [1 .. count]
-- | Run 'picosat_sat' (limit argument -1) and translate its result code.
--
-- The satisfiable code reads the model via 'getSolution'. The unknown and
-- unsatisfiable codes map to their constructors. Any other code is treated
-- as a library failure and raises 'error'.
solution :: Picosat -> IO Solution
solution pico = picosat_sat pico (-1) >>= classify
  where
    classify code
      | code == unknown = return Unknown
      | code == unsatisfiable = return Unsatisfiable
      | code == satisfiable = getSolution pico
      | otherwise = error "Picosat error."
-- | Solve a CNF formula once, using a private solver instance.
--
-- Each inner list is a clause of non-zero literals. The negation of
-- variable @v@ is written @-v@, which matches how 'scopedAllSolutions'
-- negates models. The result comes from 'solution'.
solve :: [[Int]] -> IO Solution
solve cnf = withPicosat $ \pico -> do
  addClauses pico cnf
  solution pico
-- | Enumerate the models of a CNF formula with a private solver instance.
--
-- The formula is installed as base clauses, and 'scopedAllSolutions' then
-- collects models until the solver stops producing one.
solveAll :: [[Int]] -> IO [Solution]
solveAll cnf = evalScopedPicosat (addBaseClauses cnf >> scopedAllSolutions)
-- | State of the scoped API.
--
-- It holds the solver handle plus every context variable that
-- 'withScope' has created. 'scopedSolution' strips those variables from
-- reported models.
--
-- Fields are strict so that repeated record updates inside 'StateT' do not
-- accumulate unevaluated thunks.
data PicosatScoped = PicosatScoped
  { psPicosat     :: !Picosat
  , psContextVars :: !(S.Set Int)
  }
-- | Monad of the scoped API: 'IO' threading a 'PicosatScoped' state.
type PS = StateT PicosatScoped IO
-- | Run a scoped computation against a fresh solver.
--
-- The computation starts with no recorded context variables, and only its
-- result is returned. 'withPicosat' resets the solver once the computation
-- finishes.
evalScopedPicosat :: PS a -> IO a
evalScopedPicosat action = withPicosat (evalStateT action . initial)
  where
    initial pico = PicosatScoped pico S.empty
-- | Add clauses through 'addClauses' without opening a new scope.
--
-- NOTE(review): this block does not push a context itself. If it is called
-- inside 'withScope', the clauses presumably belong to that open PicoSAT
-- context rather than being permanent. Confirm against PicoSAT's push/pop
-- semantics.
addBaseClauses :: [[Int]] -> PS ()
addBaseClauses clauses =
  gets psPicosat >>= \pico -> liftIO (addClauses pico clauses)
-- | Open a scope with 'withScope', add @clauses@ inside it, then run
-- @action@. The action's result is returned once the scope has been popped.
withScopedClauses :: [[Int]] -> PS a -> PS a
withScopedClauses clauses action = do
  pico <- gets psPicosat
  withScope (liftIO (addClauses pico clauses) >> action)
-- | Run an action inside a fresh PicoSAT context.
--
-- 1. Call 'picosat_push' and record its result as a context variable so
--    that 'scopedSolution' can hide it.
-- 2. Run the action.
-- 3. Call 'picosat_pop', ignoring its result.
--
-- The context variable stays recorded after the pop. Nothing here removes
-- it from 'psContextVars'.
--
-- NOTE(review): the pop is not exception-safe. If the action throws, the
-- context is left pushed.
withScope :: PS a -> PS a
withScope action = do
  pico <- gets psPicosat
  liftIO (picosat_push pico) >>= addContextVariable . fromIntegral
  action <* liftIO (picosat_pop pico)
-- | Record a PicoSAT context variable so that 'scopedSolution' can strip
-- it from reported models.
--
-- The update uses @modify'@ and forces the new set before storing it. With
-- plain @modify@, each scope would leave another unevaluated @S.insert@
-- thunk in the state until the set is next read.
addContextVariable :: Int -> PS ()
addContextVariable var = modify' $ \st ->
  let vars = S.insert var (psContextVars st)
  in vars `seq` st { psContextVars = vars }
-- | Solve with the current clauses (and any pending assumptions).
--
-- When a model is found, every literal whose variable is a recorded context
-- variable is dropped, so callers see only their own variables.
-- 'Unsatisfiable' and 'Unknown' are returned unchanged.
scopedSolution :: PS Solution
scopedSolution = do
  pico <- gets psPicosat
  result <- liftIO (solution pico)
  case result of
    Solution lits -> do
      hidden <- gets psContextVars
      return (Solution [l | l <- lits, abs l `S.notMember` hidden])
    other ->
      return other
-- | Enumerate models inside a fresh 'withScope' context.
--
-- Each time a model is found, a blocking clause (the negation of every
-- reported literal) is added, so the next solve must return a different
-- model. The blocking clauses are added inside the scope and are presumably
-- discarded when it is popped.
--
-- Models are returned in discovery order. The first non-'Solution' result
-- ends the enumeration and is not included. An 'Unknown' therefore looks
-- the same to the caller as a complete enumeration.
scopedAllSolutions :: PS [Solution]
scopedAllSolutions = withScope (go [])
  where
    go found = do
      pico <- gets psPicosat
      sol <- scopedSolution
      case sol of
        Solution lits -> do
          liftIO (addClause pico (map negate lits))
          go (sol : found)
        _ ->
          return (reverse found)
-- | Register each literal with 'picosat_assume', in order, then solve with
-- 'scopedSolution'.
--
-- NOTE(review): this block does not clear the assumptions afterwards.
-- PicoSAT presumably drops them after the next 'picosat_sat' call — confirm.
scopedSolutionWithAssumptions :: [Int] -> PS Solution
scopedSolutionWithAssumptions assumptions = do
  pico <- gets psPicosat
  liftIO $ forM_ assumptions $ picosat_assume pico . fromIntegral
  scopedSolution
{-# NOINLINE unsafeSolve #-}
-- | Pure wrapper around 'solve'.
--
-- 'unsafePerformIO' is used on the grounds that each call creates its own
-- solver inside 'withPicosat' and shares no state with other calls.
-- NOTE(review): this also assumes PicoSAT returns the same model for the
-- same formula every time. Confirm before relying on referential
-- transparency.
unsafeSolve :: [[Int]] -> Solution
unsafeSolve cnf = unsafePerformIO (solve cnf)
{-# NOINLINE unsafeSolveAll #-}
-- | Pure wrapper around 'solveAll'.
--
-- The justification mirrors 'unsafeSolve': a private solver per call, and
-- the list is fully built in 'IO' before the solver is reset.
-- NOTE(review): this assumes the enumeration order is deterministic for a
-- given formula — confirm.
unsafeSolveAll :: [[Int]] -> [Solution]
unsafeSolveAll cnf = unsafePerformIO (solveAll cnf)