{-# LANGUAGE RecordWildCards #-}

-- original implementation:
-- <https://qiita.com/drken/items/cce6fc5c579051e64fab>

-- | A potentialized disjoint set union on a [group](https://en.wikipedia.org/wiki/Group_(mathematics\))
-- under a differential constraint system. Each vertex \(v\) is assigned a potential value \(p(v)\),
-- where representatives (`leader`) of each group have a potential of `mempty`, and other vertices have
-- potentials relative to their representative.
--
-- The group type is represented as a `Monoid` with an inverse operator, passed to `new`. This
-- approach avoids defining a separate typeclass for groups.
--
-- ==== Invariant
-- New monoids always come from the left: @new <> old@. The order is important for non-commutative
-- monoid implementations.
--
-- @since 1.1.0.0
module AtCoder.Extra.Pdsu
  ( -- * Pdsu
    Pdsu (nPdsu),

    -- * Constructors
    new,

    -- * Inspection
    leader,
    pot,
    diff,
    unsafeDiff,
    same,
    canMerge,

    -- * Merging
    merge,
    merge_,

    -- * Group information
    size,
    groups,

    -- * Reset
    clear,
  )
where

import AtCoder.Internal.Assert qualified as ACIA
import Control.Monad
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Vector qualified as V
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)

-- | A potentialized disjoint set union on a [group](https://en.wikipedia.org/wiki/Group_(mathematics\))
-- under a differential constraint system. Each vertex \(v\) is assigned a potential value \(p(v)\)
-- relative to the representative of its group.
--
-- ==== __Example__
-- Create a `Pdsu` with four vertices with potential type @Sum Int@. Use `negate` as the inverse
-- operator:
--
-- >>> import AtCoder.Extra.Pdsu qualified as Pdsu
-- >>> import Data.Semigroup (Sum (..))
-- >>> dsu <- Pdsu.new @_ @(Sum Int) 4 negate
--
-- The API is similar to @Dsu@, but with differential potential values:
--
-- >>> Pdsu.merge dsu 1 0 (Sum 1)  -- p(1) - p(0) := Sum 1
-- True
--
-- >>> Pdsu.merge_ dsu 2 0 (Sum 2) -- p(2) - p(0) := Sum 2
-- >>> Pdsu.leader dsu 0
-- 0
--
-- Potential values can be retrieved with `pot`:
--
-- >>> Pdsu.pot dsu 0
-- Sum {getSum = 0}
--
-- >>> Pdsu.pot dsu 1
-- Sum {getSum = 1}
--
-- >>> Pdsu.pot dsu 2
-- Sum {getSum = 2}
--
-- Difference of potentials in the same group can be retrieved with `diff`:
--
-- >>> Pdsu.diff dsu 2 1
-- Just (Sum {getSum = 1})
--
-- >>> Pdsu.diff dsu 2 3
-- Nothing
--
-- Retrieve group information with `groups`
--
-- >>> Pdsu.groups dsu
-- [[2,1,0],[3]]
--
-- @since 1.1.0.0
data Pdsu s a = Pdsu
  { -- | The number of vertices.
    nPdsu :: {-# UNPACK #-} !Int,
    -- | For a non-root vertex: the (non-negative) index of its parent. For a root: the negated
    -- size of its group, so the stored value is always negative.
    parentOrSizePdsu :: !(VUM.MVector s Int),
    -- | Differential potential of each vertex relative to its parent. Roots hold `mempty`.
    potentialPdsu :: !(VUM.MVector s a),
    -- | The inverse operator of the group, as passed to `new`.
    invertPdsu :: !(a -> a)
  }

-- | \(O(n)\) Creates a new DSU under a differential constraint system. Every vertex starts out as
-- the root of its own singleton group with potential `mempty`.
--
-- @since 1.1.0.0
{-# INLINE new #-}
new ::
  forall m a.
  (PrimMonad m, Monoid a, VU.Unbox a) =>
  -- | The number of vertices
  Int ->
  -- | The inverse operator of the monoid
  (a -> a) ->
  -- | A DSU
  m (Pdsu (PrimState m) a)
new n f = do
  -- Roots store their negated group size, so @-1@ means "a root of a group of size 1".
  parentOrSize <- VUM.replicate n (-1)
  -- Roots carry the identity potential.
  potential <- VUM.replicate n (mempty :: a)
  pure
    Pdsu
      { nPdsu = n,
        parentOrSizePdsu = parentOrSize,
        potentialPdsu = potential,
        invertPdsu = f
      }

-- | \(O(\alpha(n))\) Returns the representative (root) of the connected component that contains
-- the vertex.
--
-- Every vertex visited on the way up is re-attached directly under the root (path compression),
-- and its stored potential is rewritten to be relative to that root.
--
-- @since 1.1.0.0
{-# INLINE leader #-}
leader :: (HasCallStack, PrimMonad m, Semigroup a, VU.Unbox a) => Pdsu (PrimState m) a -> Int -> m Int
leader Pdsu {..} v0 = findRoot v0
  where
    !_ = ACIA.checkIndex "AtCoder.Extra.Pdsu.leader" v0 nPdsu
    -- Walks up to the root, then compresses the path while unwinding, so the vertices nearer to
    -- the root are fixed up before the ones below them.
    findRoot v = do
      parent <- VGM.read parentOrSizePdsu v
      if parent < 0
        then pure v -- a negative entry marks a root (it stores the negated group size)
        else do
          !root <- findRoot parent
          unless (parent == root) $ do
            -- `parent` now hangs directly under `root`: this is its potential relative to `root`.
            !parentPot <- VGM.read potentialPdsu parent
            VGM.write parentOrSizePdsu v root
            -- INVARIANT: new monoids come from the left, so
            -- p(v -> root) = p(v -> parent) <> p(parent -> root).
            VGM.modify potentialPdsu (<> parentPot) v
          pure root

-- | \(O(\alpha(n))\) Returns \(p(v)\), the potential value of vertex \(v\) relative to the
-- representative of its group.
--
-- @since 1.1.0.0
{-# INLINE pot #-}
pot :: (HasCallStack, PrimMonad m, Semigroup a, VU.Unbox a) => Pdsu (PrimState m) a -> Int -> m a
pot dsu@Pdsu {..} v1 =
  -- `leader` compresses the path, after which the stored potential of `v1` is relative to its
  -- root (or is `mempty` when `v1` is the root itself).
  leader dsu v1 *> VGM.read potentialPdsu v1
  where
    !_ = ACIA.checkIndex "AtCoder.Extra.Pdsu.pot" v1 nPdsu

-- | \(O(\alpha(n))\) Returns whether the vertices \(a\) and \(b\) are in the same connected
-- component. Both vertices have their paths compressed as a side effect.
--
-- @since 1.1.0.0
{-# INLINE same #-}
same :: (HasCallStack, PrimMonad m, Semigroup a, VU.Unbox a) => Pdsu (PrimState m) a -> Int -> Int -> m Bool
same !dsu !v1 !v2 = do
  root1 <- leader dsu v1
  root2 <- leader dsu v2
  pure (root1 == root2)

-- NOTE: `unsafeDiff` (below) is the variant of `diff` that skips the connectivity check.

-- | \(O(\alpha(n))\) Returns the potential of \(v_1\) relative to \(v_2\),
-- \(p(v_1) \cdot p^{-1}(v_2)\), wrapped in `Just` when the two vertices belong to the same group.
-- Returns `Nothing` when they are not connected.
--
-- @since 1.1.0.0
{-# INLINE diff #-}
diff :: (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) => Pdsu (PrimState m) a -> Int -> Int -> m (Maybe a)
diff !dsu !v1 !v2 = do
  connected <- same dsu v1 v2
  case connected of
    False -> pure Nothing
    True -> Just <$> unsafeDiff dsu v1 v2

-- | \(O(\alpha(n))\) Returns the potential of \(v_1\) relative to \(v_2\),
-- \(p(v_1) \cdot p^{-1}(v_2)\). The result is only meaningful when the two vertices belong to the
-- same group; connectivity is not checked (use `diff` for a checked version).
--
-- @since 1.1.0.0
{-# INLINE unsafeDiff #-}
unsafeDiff :: (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) => Pdsu (PrimState m) a -> Int -> Int -> m a
unsafeDiff !dsu !v1 !v2 = relative <$> pot dsu v1 <*> pot dsu v2
  where
    -- p(v1) . p(v2)^{-1}
    relative p1 p2 = p1 <> invertPdsu dsu p2

-- | \(O(\alpha(n))\) Merges \(v_1\) to \(v_2\) with differential (relative) potential
-- \(\mathrm{dp}\): \(p(v_1) := \mathrm{dp} \cdot p(v_2)\). Returns `True` if they're newly merged.
--
-- If \(v_1\) and \(v_2\) already belong to the same group, nothing is changed and `False` is
-- returned. The existing potentials are not compared against \(\mathrm{dp}\); use `canMerge` for
-- that.
--
-- The root of the smaller group is attached under the root of the larger group (union by size).
--
-- @since 1.1.0.0
{-# INLINE merge #-}
merge :: (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) => Pdsu (PrimState m) a -> Int -> Int -> a -> m Bool
merge dsu@Pdsu {..} v1 v2 !dp = do
  !r1 <- leader dsu v1
  !r2 <- leader dsu v2
  if r1 == r2
    then pure False
    else do
      -- Roots store their group size negated.
      !size1 <- negate <$> VGM.read parentOrSizePdsu r1
      !size2 <- negate <$> VGM.read parentOrSizePdsu r2
      let !total = size1 + size2
      -- Union by size: hang the smaller tree under the larger one. On ties, `r1` goes under `r2`.
      -- When swapping roles, p(v1) = dp <> p(v2) is equivalent to p(v2) = dp^{-1} <> p(v1).
      if size1 <= size2
        then attach v1 r1 v2 r2 dp total
        else attach v2 r2 v1 r1 (invertPdsu dp) total
      pure True
  where
    !_ = ACIA.checkIndex "AtCoder.Extra.Pdsu.merge" v1 nPdsu
    !_ = ACIA.checkIndex "AtCoder.Extra.Pdsu.merge" v2 nPdsu
    -- Hangs root `rc` (the root of `vc`) under root `rp` (the root of `vp`) so that
    -- p(vc) = dpc <> p(vp) holds afterwards. It also stores the combined size `total` on `rp`.
    -- `leader` has just compressed both paths, so the stored potentials of `vc` / `vp` are
    -- relative to `rc` / `rp` (or are `mempty` when the vertex is the root itself).
    attach vc rc vp rp !dpc !total = do
      !pc <- VGM.read potentialPdsu vc
      !pp <- VGM.read potentialPdsu vp
      -- After the merge, p'(vc) = pc <> p'(rc) must equal dpc <> pp, hence
      --     p'(rc) = pc^{-1} <> dpc <> pp
      let !prc = invertPdsu pc <> dpc <> pp
      -- Record the new size on the surviving root (negated), then link `rc` under it.
      VGM.write parentOrSizePdsu rp (negate total)
      VGM.write parentOrSizePdsu rc rp
      VGM.write potentialPdsu rc prc

-- | \(O(\alpha(n))\) Same as `merge`, but the `Bool` result is discarded.
--
-- @since 1.1.0.0
{-# INLINE merge_ #-}
merge_ :: (HasCallStack, PrimMonad m, Monoid a, Ord a, VU.Unbox a) => Pdsu (PrimState m) a -> Int -> Int -> a -> m ()
merge_ !dsu !v1 !v2 !dp = void $ merge dsu v1 v2 dp

-- | \(O(\alpha(n))\) Returns `True` if the two vertices belong to different groups, or if they
-- belong to the same group and already satisfy \(p(v_1) = dp \cdot p(v_2)\). This is a
-- convenience helper; it does not modify any potentials.
--
-- @since 1.1.0.0
{-# INLINE canMerge #-}
canMerge :: (HasCallStack, PrimMonad m, Semigroup a, Eq a, VU.Unbox a) => Pdsu (PrimState m) a -> Int -> Int -> a -> m Bool
canMerge !dsu !v1 !v2 !dp = do
  connected <- same dsu v1 v2
  if connected
    then do
      -- `same` has just compressed both paths, so the stored potentials are relative to the
      -- (shared) root.
      let pots = potentialPdsu dsu
      !p1 <- VGM.read pots v1
      !p2 <- VGM.read pots v2
      pure (p1 == (dp <> p2))
    else pure True

-- | \(O(\alpha(n))\) Returns the number of vertices in the group that contains the given vertex.
--
-- @since 1.1.0.0
{-# INLINE size #-}
size :: (HasCallStack, PrimMonad m, Semigroup a, VU.Unbox a) => Pdsu (PrimState m) a -> Int -> m Int
size !dsu !v = do
  root <- leader dsu v
  -- A root's entry holds the negated size of its group.
  negatedSize <- VGM.read (parentOrSizePdsu dsu) root
  pure (negate negatedSize)

-- | \(O(n)\) Divides the graph into connected components and returns them.
--
-- Groups are ordered by the index of their representative. Within each group, vertices are listed
-- in descending order.
--
-- @since 1.1.0.0
{-# INLINE groups #-}
groups :: (PrimMonad m, Semigroup a, VU.Unbox a) => Pdsu (PrimState m) a -> m (V.Vector (VU.Vector Int))
groups dsu@Pdsu {..} = do
  -- counts[r] = number of vertices whose representative is r (0 for non-roots).
  counts <- VUM.replicate nPdsu (0 :: Int)
  roots <- VU.generateM nPdsu $ \v -> do
    r <- leader dsu v
    VGM.modify counts (+ 1) r
    pure r
  -- One bucket per vertex; only the buckets of representatives end up non-empty.
  buckets <- V.generateM nPdsu $ \r -> VUM.unsafeNew =<< VGM.read counts r
  -- Fill each bucket from its back; `counts` doubles as the per-bucket write cursor.
  VU.iforM_ roots $ \v r -> do
    cursor <- subtract 1 <$> VGM.read counts r
    VGM.write (buckets V.! r) cursor v
    VGM.write counts r cursor
  V.filter (not . VU.null) <$> V.mapM VU.unsafeFreeze buckets

-- | \(O(n)\) Resets the `Pdsu` to the state produced by `new`. Every vertex becomes the root of
-- its own singleton group with potential `mempty`.
--
-- @since 1.1.0.0
{-# INLINE clear #-}
clear :: forall m a. (PrimMonad m, Monoid a, VU.Unbox a) => Pdsu (PrimState m) a -> m ()
clear Pdsu {..} = do
  -- Roots store their negated size: -1 means "root of a group of size 1".
  VGM.set parentOrSizePdsu (-1)
  -- Roots carry the identity potential.
  VGM.set potentialPdsu (mempty :: a)