{-# LANGUAGE RecordWildCards #-}
module AtCoder.Dsu
(
Dsu (nDsu),
new,
merge,
merge_,
leader,
same,
size,
groups,
)
where
import AtCoder.Internal.Assert qualified as ACIA
import Control.Monad (when)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Vector qualified as V
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)
-- | Mutable disjoint set union (union-find) over the vertices @[0, nDsu)@.
data Dsu s = Dsu
  { nDsu :: {-# UNPACK #-} !Int
  -- ^ Number of vertices.
  , parentOrSizeDsu :: !(VUM.MVector s Int)
  -- ^ Per-vertex cell: a non-negative value is the index of the parent
  -- vertex; a negative value marks a root and stores minus the size of its
  -- component.
  }
-- | Creates a DSU over @n@ vertices with no edges, so that every vertex is
-- a singleton component.
--
-- Each cell is initialised to @-1@ (a root whose component has size 1).
-- Calls 'error' when @n@ is negative.
{-# INLINE new #-}
new :: (PrimMonad m) => Int -> m (Dsu (PrimState m))
new nDsu
  | nDsu < 0 = error $ "new: given negative size (`" ++ show nDsu ++ "`)"
  | otherwise = do
      parentOrSizeDsu <- VUM.replicate nDsu (-1)
      pure Dsu {..}
-- | Adds an edge between @a@ and @b@ and returns the leader of the resulting
-- component.
--
-- If the two vertices are already connected, nothing is written and their
-- common leader is returned. Otherwise the components are joined by size:
-- the root of the larger component becomes the root of the merged one.
-- On a tie, the leader of @a@ wins.
--
-- Both vertices are validated with 'ACIA.checkVertex' against @nDsu@.
-- NOTE(review): presumably that rejects indices outside @[0, nDsu)@;
-- confirm against "AtCoder.Internal.Assert".
{-# INLINE merge #-}
merge :: (HasCallStack, PrimMonad m) => Dsu (PrimState m) -> Int -> Int -> m Int
merge dsu@Dsu {..} a b = do
  let !_ = ACIA.checkVertex "AtCoder.Dsu.merge" a nDsu
  let !_ = ACIA.checkVertex "AtCoder.Dsu.merge" b nDsu
  x <- leader dsu a
  y <- leader dsu b
  if x == y
    then pure x
    else do
      -- Both are roots, so these are minus their component sizes.
      px <- VGM.read parentOrSizeDsu x
      py <- VGM.read parentOrSizeDsu y
      -- Pick which *index* becomes the root. Swapping the stored values
      -- instead would leave each tree attached to its old root and hang the
      -- larger tree under the smaller one, defeating union by size.
      let (root, child) = if -px < -py then (y, x) else (x, y)
      -- The new root stores minus the combined size; the child points at it.
      VGM.write parentOrSizeDsu root (px + py)
      VGM.write parentOrSizeDsu child root
      pure root
-- | Same as 'merge', but throws away the leader it returns.
{-# INLINE merge_ #-}
merge_ :: (PrimMonad m) => Dsu (PrimState m) -> Int -> Int -> m ()
merge_ dsu a b = () <$ merge dsu a b
-- | Returns whether @a@ and @b@ lie in the same component, i.e. whether
-- they have the same leader.
--
-- Both vertices are validated with 'ACIA.checkVertex' against @nDsu@
-- before any lookup. Looking up a leader compresses paths, so the parent
-- links may be rewritten even though the partition does not change.
{-# INLINE same #-}
same :: (HasCallStack, PrimMonad m) => Dsu (PrimState m) -> Int -> Int -> m Bool
same dsu@Dsu {..} a b = do
  let !_ = ACIA.checkVertex "AtCoder.Dsu.same" a nDsu
  let !_ = ACIA.checkVertex "AtCoder.Dsu.same" b nDsu
  -- The leader of a is resolved before the leader of b.
  (==) <$> leader dsu a <*> leader dsu b
-- | Returns the root (representative) of the component that contains @a@.
--
-- Vertex @a@ is validated with 'ACIA.checkVertex' against @nDsu@. Every
-- vertex on the path to the root is re-pointed directly at the root (path
-- compression), so later lookups traverse shorter paths.
{-# INLINE leader #-}
leader :: (HasCallStack, PrimMonad m) => Dsu (PrimState m) -> Int -> m Int
leader dsu@Dsu {..} a = do
  let !_ = ACIA.checkVertex "AtCoder.Dsu.leader" a nDsu
  cell <- VGM.read parentOrSizeDsu a
  if cell >= 0
    then do
      -- `cell` is the parent: resolve its root, then link `a` straight to it.
      root <- leader dsu cell
      VGM.write parentOrSizeDsu a root
      pure root
    else -- A negative cell marks `a` itself as a root.
      pure a
-- | Returns the number of vertices in the component that contains @a@.
--
-- The size is read from the root's cell, which stores it negated.
{-# INLINE size #-}
size :: (HasCallStack, PrimMonad m) => Dsu (PrimState m) -> Int -> m Int
size dsu@Dsu {..} a = do
  let !_ = ACIA.checkVertex "AtCoder.Dsu.size" a nDsu
  root <- leader dsu a
  negate <$> VGM.read parentOrSizeDsu root
-- | Divides the vertices into their connected components.
--
-- Returns one vector per component. Components are ordered by the index of
-- their leader, and the vertices inside each component are in ascending
-- order. Resolving leaders performs path compression, so the parent links
-- may be rewritten; the partition itself is unchanged.
{-# INLINE groups #-}
groups :: (PrimMonad m) => Dsu (PrimState m) -> m (V.Vector (VU.Vector Int))
groups dsu@Dsu {..} = do
  -- Resolve each vertex's leader once and count the members per leader.
  groupSize <- VUM.replicate nDsu (0 :: Int)
  leaders <- VU.generateM nDsu $ \i -> do
    li <- leader dsu i
    VGM.modify groupSize (+ 1) li
    pure li
  -- Allocate one exactly-sized buffer per vertex (only leaders get a
  -- non-empty one). The sizes are read straight from the mutable counters
  -- rather than through an 'unsafeFreeze'd alias, because 'groupSize' is
  -- mutated again below.
  result <- V.generateM nDsu $ \i -> VGM.read groupSize i >>= VUM.unsafeNew
  -- Reuse 'groupSize' as a per-component write cursor that counts up from 0,
  -- so the vertices of each component are stored in ascending order.
  VGM.set groupSize 0
  VU.iforM_ leaders $ \i li -> do
    pos <- VGM.read groupSize li
    VGM.write (result VG.! li) pos i
    VGM.write groupSize li (pos + 1)
  -- Each buffer slot has now been written exactly once and the buffers are
  -- no longer touched, so freezing them in place is safe.
  V.filter (not . VU.null) <$> V.mapM VU.unsafeFreeze result