{-# LANGUAGE RecordWildCards #-}

-- | A disjoint set union, also known as a Union-Find tree. It processes the following queries in
-- amortized \(O(\alpha(n))\) time.
--
-- - Edge addition (`merge`)
-- - Deciding whether two given vertices are in the same connected component (`same`)
--
-- Each connected component internally has a representative vertex (`leader`). When two connected
-- components are merged by edge addition (`merge`), one of the two representatives of these
-- connected components becomes the representative (`leader`) of the new connected component.
--
-- ==== __Example__
-- Create a `Dsu` with four vertices:
--
-- >>> import AtCoder.Dsu qualified as Dsu
-- >>> dsu <- Dsu.new 4   -- 0 1 2 3
-- >>> Dsu.nDsu dsu
-- 4
--
-- Merge some vertices into the same group:
--
-- >>> Dsu.merge dsu 0 1  -- 0=1 2 3
-- 0
--
-- >>> Dsu.merge_ dsu 1 2 -- 0=1=2 3
--
-- `leader` returns the internal representative vertex of the connected components:
--
-- >>> Dsu.leader dsu 2
-- 0
--
-- Retrieve group information:
--
-- >>> Dsu.same dsu 0 2
-- True
--
-- >>> Dsu.size dsu 0
-- 3
--
-- >>> Dsu.groups dsu
-- [[2,1,0],[3]]
--
-- @since 1.0.0.0
module AtCoder.Dsu
  ( -- * Disjoint set union
    Dsu (nDsu),

    -- * Constructor
    new,

    -- * Merging
    merge,
    merge_,

    -- * Leader
    leader,

    -- * Component information
    same,
    size,
    groups,
  )
where

import AtCoder.Internal.Assert qualified as ACIA
import Control.Monad (when)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Vector qualified as V
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)

-- | A disjoint set union, also known as a Union-Find tree.
--
-- @since 1.0.0.0
data Dsu s = Dsu
  { -- | The number of nodes.
    --
    -- @since 1.0.0.0
    nDsu :: {-# UNPACK #-} !Int,
    -- | One slot per node. A root (leader) node holds the size of its component as a negative
    -- number; every other node holds the index of its parent node.
    parentOrSizeDsu :: !(VUM.MVector s Int)
  }

-- | Builds a `Dsu` representing an undirected graph with \(n\) vertices and no edges, so every
-- vertex is initially the leader of its own one-element component.
--
-- Calls `error` when \(n\) is negative.
--
-- ==== Constraints
-- - \(0 \le n\)
--
-- ==== Complexity
-- - \(O(n)\)
--
-- @since 1.0.0.0
{-# INLINE new #-}
new :: (PrimMonad m) => Int -> m (Dsu (PrimState m))
new n
  | n < 0 = error $ "new: given negative size (`" ++ show n ++ "`)"
  -- Each slot starts at -1: a root whose component has size 1.
  | otherwise = Dsu n <$> VUM.replicate n (-1)

-- | Adds an edge \((a, b)\). If the vertices \(a\) and \(b\) are in the same connected component, it
-- returns the representative (`leader`) of this connected component. Otherwise, it returns the
-- representative of the new connected component, which is the former representative of the larger
-- of the two components (ties are resolved in favor of the component of \(a\)).
--
-- ==== Constraints
-- - \(0 \leq a < n\)
-- - \(0 \leq b < n\)
--
-- ==== Complexity
-- - \(O(\alpha(n))\) amortized
--
-- @since 1.0.0.0
{-# INLINE merge #-}
merge :: (HasCallStack, PrimMonad m) => Dsu (PrimState m) -> Int -> Int -> m Int
merge dsu@Dsu {nDsu = n, parentOrSizeDsu = parents} a b = do
  let !_ = ACIA.checkVertex "AtCoder.Dsu.merge" a n
  let !_ = ACIA.checkVertex "AtCoder.Dsu.merge" b n
  x <- leader dsu a
  y <- leader dsu b
  -- Both are roots, so their slots hold negated component sizes.
  px <- VGM.read parents x
  py <- VGM.read parents y
  -- Union by size: the root of the larger component stays the root, and the smaller tree is hung
  -- beneath it. Swapping the stored sizes instead of the roots would always attach `y` under `x`,
  -- regardless of size. When `x == y` the comparison is false and `root` is `x`.
  let yIsLarger = -px < -py
      !root = if yIsLarger then y else x
      !child = if yIsLarger then x else y
  when (x /= y) $ do
    VGM.write parents child root
    -- The sum of two negated sizes is the negated size of the merged component.
    VGM.write parents root (px + py)
  pure root

-- | `merge` with the return value discarded.
--
-- Carries `HasCallStack`, like `merge`, so that an out-of-range vertex is reported with the
-- caller's location.
--
-- ==== Constraints
-- - \(0 \leq a < n\)
-- - \(0 \leq b < n\)
--
-- ==== Complexity
-- - \(O(\alpha(n))\) amortized
--
-- @since 1.0.0.0
{-# INLINE merge_ #-}
merge_ :: (HasCallStack, PrimMonad m) => Dsu (PrimState m) -> Int -> Int -> m ()
merge_ dsu a b = () <$ merge dsu a b

-- | Tests whether the vertices \(a\) and \(b\) belong to the same connected component, i.e. whether
-- they share a representative (`leader`).
--
-- ==== Constraints
-- - \(0 \leq a < n\)
-- - \(0 \leq b < n\)
--
-- ==== Complexity
-- - \(O(\alpha(n))\) amortized
--
-- @since 1.0.0.0
{-# INLINE same #-}
same :: (HasCallStack, PrimMonad m) => Dsu (PrimState m) -> Int -> Int -> m Bool
same dsu@Dsu {nDsu = n} a b = do
  let !_ = ACIA.checkVertex "AtCoder.Dsu.same" a n
  let !_ = ACIA.checkVertex "AtCoder.Dsu.same" b n
  -- The leader of `a` is looked up first, then the leader of `b`.
  (==) <$> leader dsu a <*> leader dsu b

-- | Finds the representative of the connected component that contains the vertex \(a\).
--
-- While walking up to the root, each visited vertex is re-pointed directly at the root (path
-- compression).
--
-- ==== Constraints
-- - \(0 \leq a \lt n\)
--
-- ==== Complexity
-- - \(O(\alpha(n))\) amortized
--
-- @since 1.0.0.0
{-# INLINE leader #-}
leader :: (HasCallStack, PrimMonad m) => Dsu (PrimState m) -> Int -> m Int
leader dsu@Dsu {nDsu = n, parentOrSizeDsu = parents} a = do
  let !_ = ACIA.checkVertex "AtCoder.Dsu.leader" a n
  slot <- VGM.read parents a
  if slot >= 0
    then do
      -- A non-negative slot is a parent index: resolve its root, then point `a` straight at it.
      root <- leader dsu slot
      VGM.write parents a root
      pure root
    else -- A negative slot is a (negated) size, so `a` is itself a root.
      pure a

-- | Reports the number of vertices in the connected component that contains the vertex \(a\).
--
-- ==== Constraints
-- -  \(0 \leq a < n\)
--
-- ==== Complexity
-- - \(O(\alpha(n))\)
--
-- @since 1.0.0.0
{-# INLINE size #-}
size :: (HasCallStack, PrimMonad m) => Dsu (PrimState m) -> Int -> m Int
size dsu@Dsu {nDsu = n, parentOrSizeDsu = parents} a = do
  let !_ = ACIA.checkVertex "AtCoder.Dsu.size" a n
  root <- leader dsu a
  -- The root's slot stores the component size negated.
  negate <$> VGM.read parents root

-- | Splits the graph into its connected components and returns them as a vector.
--
-- Each element of the result is the vector of the vertices of one connected component. The order
-- of the components and the order of the vertices inside a component are both undefined.
--
-- ==== Complexity
-- - \(O(n)\)
--
-- @since 1.0.0.0
{-# INLINE groups #-}
groups :: (PrimMonad m) => Dsu (PrimState m) -> m (V.Vector (VU.Vector Int))
groups dsu@Dsu {nDsu = n} = do
  -- Pass 1: record each vertex's leader and count the members per leader.
  remaining <- VUM.replicate n (0 :: Int)
  roots <- VU.generateM n $ \v -> do
    r <- leader dsu v
    VGM.modify remaining (+ 1) r
    pure r
  -- Allocate one bucket per vertex index, sized by the member count (zero for non-leaders).
  -- The frozen view is consumed here, before `remaining` is mutated again below.
  counts <- VU.unsafeFreeze remaining
  buckets <- V.mapM VUM.unsafeNew (VU.convert counts)
  -- Pass 2: fill every bucket from its last slot towards its first, using `remaining` as the
  -- per-leader count of slots that are still free.
  VU.iforM_ roots $ \v r -> do
    slot <- subtract 1 <$> VGM.read remaining r
    VGM.write (buckets VG.! r) slot v
    VGM.write remaining r slot
  -- Only leaders own a non-empty bucket; drop the rest.
  frozen <- V.mapM VU.unsafeFreeze buckets
  pure (V.filter (not . VU.null) frozen)