{-# LANGUAGE RecordWildCards #-}

-- | Solves 2-SAT.
--
-- For variables \(x_0, x_1, \cdots, x_{N - 1}\) and clauses of the form
--
-- - \((x_i = f) \lor (x_j = g)\),
--
-- it decides whether there is a truth assignment that satisfies all clauses.
--
-- ==== __Example__
-- >>> import AtCoder.TwoSat qualified as TS
-- >>> import Data.Bit (Bit(..))
-- >>> ts <- TS.new 1
-- >>> TS.addClause ts 0 False 0 False -- x_0 == False || x_0 == False
-- >>> TS.satisfiable ts
-- True
--
-- >>> TS.answer ts
-- [0]
--
-- @since 1.0.0.0
module AtCoder.TwoSat
  ( -- * TwoSat
    TwoSat (nTs),
    -- * Constructor
    new,
    -- * Clause building
    addClause,
    -- * Solvers
    satisfiable,
    answer,
    unsafeAnswer,
  )
where

import AtCoder.Internal.Assert qualified as ACIA
import AtCoder.Internal.Scc qualified as ACISCC
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Bit (Bit (..))
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)

-- | 2-SAT state.
--
-- @since 1.0.0.0
data TwoSat s = TwoSat
  { -- | The number of variables \(n\) of the 2-SAT (variables are indexed \(0, 1, \cdots, n - 1\)).
    --
    -- @since 1.0.0.0
    nTs :: {-# UNPACK #-} !Int,
    -- Truth assignment buffer of length \(n\), written by `satisfiable` and read by `answer` /
    -- `unsafeAnswer`.
    answerTs :: !(VUM.MVector s Bit),
    -- Implication graph on \(2n\) vertices: vertex \(2i + 1\) stands for \(x_i = \mathrm{True}\)
    -- and vertex \(2i\) for \(x_i = \mathrm{False}\) (this is the encoding used by `addClause`).
    sccTs :: !(ACISCC.SccGraph s)
  }

-- | Creates a 2-SAT instance with \(n\) variables and no clauses.
--
-- ==== Constraints
-- - \(0 \leq n\)
--
-- ==== Complexity
-- - \(O(n)\)
--
-- @since 1.0.0.0
{-# INLINE new #-}
new :: (PrimMonad m) => Int -> m (TwoSat (PrimState m))
new n = do
  -- The contents are left uninitialized; `satisfiable` overwrites them.
  assignment <- VUM.unsafeNew n
  -- Two vertices per variable, one for each truth value.
  graph <- ACISCC.new (n + n)
  pure
    TwoSat
      { nTs = n,
        answerTs = assignment,
        sccTs = graph
      }

-- | Adds a clause \((x_i = f) \lor (x_j = g)\).
--
-- ==== Constraints
-- - \(0 \leq i \lt n\)
-- - \(0 \leq j \lt n\)
--
-- ==== Complexity
-- - \(O(1)\) amortized.
--
-- @since 1.0.0.0
{-# INLINE addClause #-}
addClause :: (HasCallStack, PrimMonad m) => TwoSat (PrimState m) -> Int -> Bool -> Int -> Bool -> m ()
addClause ts i f j g = do
  let !_ = ACIA.checkVertex "AtCoder.TwoSat.addClause" i (nTs ts)
  let !_ = ACIA.checkVertex "AtCoder.TwoSat.addClause" j (nTs ts)
  -- Vertex @2 * v + fromEnum b@ represents the literal \(x_v = b\).
  let lit v b = 2 * v + fromEnum b
      graph = sccTs ts
  -- The clause is stored as the two implications
  -- \(\lnot (x_i = f) \to (x_j = g)\) and \(\lnot (x_j = g) \to (x_i = f)\).
  ACISCC.addEdge graph (lit i (not f)) (lit j g)
  ACISCC.addEdge graph (lit j (not g)) (lit i f)

-- | Returns `True` if some truth assignment satisfies every clause added so far, and `False`
-- otherwise. When it returns `True`, a satisfying assignment is stored and can be read with
-- `answer` or `unsafeAnswer`.
--
-- ==== Constraints
-- - You may call it multiple times.
--
-- ==== Complexity
-- - \(O(n + m)\), where \(m\) is the number of added clauses.
--
-- @since 1.0.0.0
{-# INLINE satisfiable #-}
satisfiable :: (PrimMonad m) => TwoSat (PrimState m) -> m Bool
satisfiable ts = do
  (!_, !ids) <- ACISCC.sccIds (sccTs ts)
  let n = nTs ts
      -- Visits the variables in increasing order and stops at the first one whose two literals
      -- ("x_v = False" at vertex 2v, "x_v = True" at vertex 2v + 1) got the same component id.
      go v
        | v >= n = pure True
        | falseId == trueId = pure False
        | otherwise = do
            -- x_v is set to True exactly when the "False" literal has the smaller component id.
            VGM.write (answerTs ts) v $ Bit (falseId < trueId)
            go (v + 1)
        where
          falseId = ids VG.! (2 * v)
          trueId = ids VG.! (2 * v + 1)
  go 0

-- | Returns a copy of the truth assignment computed by the last call of `satisfiable`. If
-- `satisfiable` has not been called yet, or its last call returned `False`, the returned vector
-- still has length \(n\) but its elements are unspecified.
--
-- ==== Complexity
-- - \(O(n)\)
--
-- @since 1.0.0.0
{-# INLINE answer #-}
answer :: (PrimMonad m) => TwoSat (PrimState m) -> m (VU.Vector Bit)
answer ts = VU.freeze $ answerTs ts

-- | `answer` without making a copy. The returned vector shares its storage with the `TwoSat`, so
-- a later call of `satisfiable` may change its contents.
--
-- ==== Complexity
-- - \(O(1)\)
--
-- @since 1.0.0.0
{-# INLINE unsafeAnswer #-}
unsafeAnswer :: (PrimMonad m) => TwoSat (PrimState m) -> m (VU.Vector Bit)
unsafeAnswer ts = VU.unsafeFreeze $ answerTs ts