{-# LANGUAGE RecordWildCards #-}
module AtCoder.TwoSat
(
TwoSat (nTs),
new,
addClause,
satisfiable,
answer,
unsafeAnswer,
)
where
import AtCoder.Internal.Assert qualified as ACIA
import AtCoder.Internal.Scc qualified as ACISCC
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Bit (Bit (..))
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)
-- | Mutable state of a 2-SAT solver over a fixed number of boolean
-- variables. Build it with 'new', add constraints with 'addClause', then
-- query 'satisfiable' and read the assignment with 'answer'.
data TwoSat s = TwoSat
  { -- | Number of boolean variables.
    nTs :: {-# UNPACK #-} !Int,
    -- | Per-variable assignment; entry @i@ is written by 'satisfiable'.
    answerTs :: !(VUM.MVector s Bit),
    -- | Implication graph (two vertices per variable) whose strongly
    -- connected components decide satisfiability.
    sccTs :: !(ACISCC.SccGraph s)
  }
-- | Creates a 2-SAT instance over @nTs@ boolean variables with no clauses.
--
-- The implication graph is allocated with @2 * nTs@ vertices (two per
-- variable). The answer buffer is zero-filled, so every variable starts as
-- 'False' instead of holding arbitrary memory. This keeps 'answer' and
-- 'unsafeAnswer' deterministic even when they are read before
-- 'satisfiable' has written an assignment.
{-# INLINE new #-}
new :: (PrimMonad m) => Int -> m (TwoSat (PrimState m))
new nTs = do
  -- 'VUM.replicate' rather than 'VUM.unsafeNew': the latter leaves the
  -- buffer uninitialised, and that garbage would leak through 'answer'.
  answerTs <- VUM.replicate nTs (Bit False)
  sccTs <- ACISCC.new (2 * nTs)
  pure TwoSat {..}
-- | Adds the clause @(x_i == f) || (x_j == g)@.
--
-- The clause is stored as two implication edges:
--
-- * @x_i /= f@ implies @x_j == g@
-- * @x_j /= g@ implies @x_i == f@
--
-- Both @i@ and @j@ are validated with 'ACIA.checkVertex' against the number
-- of variables before any edge is added.
{-# INLINE addClause #-}
addClause :: (HasCallStack, PrimMonad m) => TwoSat (PrimState m) -> Int -> Bool -> Int -> Bool -> m ()
addClause TwoSat {..} i f j g = do
  let !_ = ACIA.checkVertex "AtCoder.TwoSat.addClause" i nTs
  let !_ = ACIA.checkVertex "AtCoder.TwoSat.addClause" j nTs
  -- (x_i /= f) ==> (x_j == g)
  ACISCC.addEdge sccTs (vertex i (not f)) (vertex j g)
  -- (x_j /= g) ==> (x_i == f)
  ACISCC.addEdge sccTs (vertex j (not g)) (vertex i f)
  where
    -- Literal "x_v == b": vertex @2 * v + 1@ for True, @2 * v@ for False.
    vertex :: Int -> Bool -> Int
    vertex v b = 2 * v + fromEnum b
-- | Decides whether all clauses added so far can hold simultaneously.
--
-- The function first computes strongly connected component ids of the
-- implication graph, then scans the variables in increasing order:
--
-- * If the two vertices of variable @x@ (@2 * x@ and @2 * x + 1@) share a
--   component, it returns 'False' immediately.
-- * Otherwise it stores @x = (id (2 * x) < id (2 * x + 1))@ in the answer
--   buffer and moves to the next variable.
--
-- When the result is 'False', the answer entries of variables scanned before
-- the conflicting one have already been overwritten.
{-# INLINE satisfiable #-}
satisfiable :: (PrimMonad m) => TwoSat (PrimState m) -> m Bool
satisfiable TwoSat {..} = do
  (!_, !ids) <- ACISCC.sccIds sccTs
  let sccOf v = ids VG.! v
      loop x
        | x >= nTs = pure True
        | otherwise = do
            let idFalse = sccOf (2 * x)
                idTrue = sccOf (2 * x + 1)
            if idFalse == idTrue
              then pure False
              else do
                VGM.write answerTs x (Bit (idFalse < idTrue))
                loop (x + 1)
  loop 0
-- | Returns an immutable copy of the assignment buffer, where entry @i@ is
-- the value of variable @i@.
--
-- Read this after 'satisfiable' has returned 'True'. The result is a copy,
-- so later calls to 'satisfiable' do not change it.
{-# INLINE answer #-}
answer :: (PrimMonad m) => TwoSat (PrimState m) -> m (VU.Vector Bit)
answer ts = VU.freeze (answerTs ts)
-- | Like 'answer', but freezes the assignment buffer in place with
-- 'VU.unsafeFreeze' instead of copying it.
--
-- The returned vector shares storage with the solver, so a subsequent call
-- to 'satisfiable' on the same instance can change its contents.
{-# INLINE unsafeAnswer #-}
unsafeAnswer :: (PrimMonad m) => TwoSat (PrimState m) -> m (VU.Vector Bit)
unsafeAnswer ts = VU.unsafeFreeze (answerTs ts)