{-# LANGUAGE RecordWildCards #-}

-- | Immutable Compressed Sparse Row (CSR) graph representation. It is re-exported from the
-- @AtCoder.Extra.Graph@ module with additional functionality.
--
-- ==== __Example__
-- Create a `Csr` without edge weights using `build'` and retrieve the edges with `adj`:
--
-- >>> import AtCoder.Internal.Csr qualified as C
-- >>> let csr = build' 4 $ VU.fromList @(Int, Int) [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3)]
-- >>> csr `C.adj` 0
-- [1,2,3]
--
-- >>> csr `C.adj` 1
-- [2]
--
-- >>> csr `C.adj` 2
-- [3]
--
-- Create a `Csr` with edge weights using `build` and retrieve the edges with `adjW`:
--
-- >>> import AtCoder.Internal.Csr qualified as C
-- >>> let csr = build 4 $ VU.fromList @(Int, Int, Int) [(0, 1, 101), (0, 2, 102), (0, 3, 103), (1, 2, 112), (2, 3, 123)]
-- >>> csr `C.adjW` 0
-- [(1,101),(2,102),(3,103)]
--
-- >>> csr `C.adjW` 1
-- [(2,112)]
--
-- >>> csr `C.adjW` 2
-- [(3,123)]
--
-- @since 1.0.0.0
module AtCoder.Internal.Csr
  ( -- * Compressed sparse row
    Csr (..),

    -- * Constructor
    build,
    build',
    build1,

    -- * Accessors
    adj,
    adjW,
    eAdj,
  )
where

import Control.Monad.ST (runST)
import Data.Foldable (for_)
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)

-- | Compressed Sparse Row (CSR) representation of a graph with edge weights of type @w@.
-- Each edge is stored once, under its source vertex.
--
-- For a value produced by `build`, the outgoing edges of vertex @v@ occupy the index range
-- @[startCsr ! v, startCsr ! (v + 1))@ of both `adjCsr` and `wCsr`.
--
-- @since 1.0.0.0
data Csr w = Csr
  { nCsr :: {-# UNPACK #-} !Int
    -- ^ The number of vertices, \(n\).
    --
    -- @since 1.1.0.0
  , mCsr :: {-# UNPACK #-} !Int
    -- ^ The number of edges, \(m\).
    --
    -- @since 1.1.0.0
  , startCsr :: !(VU.Vector Int)
    -- ^ Starting index of each vertex's edges in `adjCsr` and `wCsr`. When built by `build`
    -- it has length \(n + 1\) and its last entry equals \(m\).
    --
    -- @since 1.1.0.0
  , adjCsr :: !(VU.Vector Int)
    -- ^ Destination vertex of every edge, grouped by source vertex.
    --
    -- @since 1.1.0.0
  , wCsr :: !(VU.Vector w)
    -- ^ Weight of every edge, stored in the same order as `adjCsr`.
    --
    -- @since 1.1.0.0
  }
  deriving
    ( -- | @since 1.0.0.0
      Eq
    , -- | @since 1.0.0.0
      Show
    )

-- | \(O(n + m)\) Creates a `Csr` from the number of vertices @n@ and a vector of weighted
-- edges @(from, to, w)@.
--
-- Edges are grouped by their source vertex. Within one source vertex they keep the order in
-- which they appear in the input. Every source vertex must lie in \([0, n)\). Destination
-- vertices are stored as given and are not validated.
--
-- @since 1.0.0.0
{-# INLINE build #-}
build :: (HasCallStack, VU.Unbox w) => Int -> VU.Vector (Int, Int, w) -> Csr w
build nCsr edges = runST $ do
  let mCsr = VU.length edges
      (!sources, !_, !_) = VU.unzip3 edges

  -- Degree count, shifted by one slot: afterwards start[v + 1] holds the out-degree of v.
  start <- VUM.replicate (nCsr + 1) (0 :: Int)
  VU.forM_ sources $ \src ->
    VGM.modify start (+ 1) (src + 1)

  -- In-place prefix sum: afterwards start[v] is the first slot owned by vertex v.
  for_ [1 .. nCsr] $ \v -> do
    acc <- VGM.read start (v - 1)
    VGM.modify start (+ acc) v

  -- Scatter each edge into the next free slot of its source vertex. cursor[v] starts at
  -- start[v] and advances by one per placed edge; every slot of the output buffers is
  -- written exactly once, so uninitialised allocation is safe.
  targets <- VUM.unsafeNew mCsr
  weights <- VUM.unsafeNew mCsr
  cursor <- VUM.unsafeNew nCsr
  VUM.unsafeCopy cursor (VUM.init start)
  VU.forM_ edges $ \(!src, !dst, !wt) -> do
    slot <- VGM.read cursor src
    VGM.write targets slot dst
    VGM.write weights slot wt
    VGM.write cursor src (slot + 1)

  startCsr <- VU.unsafeFreeze start
  adjCsr <- VU.unsafeFreeze targets
  wCsr <- VU.unsafeFreeze weights
  pure Csr {..}

-- | \(O(n + m)\) Creates a `Csr` from unweighted edges @(from, to)@. Every edge carries the
-- unit value @()@ as its weight.
--
-- @since 1.0.0.0
{-# INLINE build' #-}
build' :: (HasCallStack) => Int -> VU.Vector (Int, Int) -> Csr ()
build' n edges =
  let (!srcs, !dsts) = VU.unzip edges
      units = VU.replicate (VU.length edges) ()
   in build n $ VU.zip3 srcs dsts units

-- | \(O(n + m)\) Creates a `Csr` from unweighted edges @(from, to)@, giving every edge the
-- weight @1@.
--
-- @since 1.1.0.0
{-# INLINE build1 #-}
build1 :: (HasCallStack) => Int -> VU.Vector (Int, Int) -> Csr Int
build1 n edges =
  let (!srcs, !dsts) = VU.unzip edges
      ones = VU.replicate (VU.length edges) (1 :: Int)
   in build n $ VU.zip3 srcs dsts ones

-- | \(O(1)\) Returns the destination vertices of the edges leaving vertex @i@, as a slice of
-- `adjCsr` (no copying).
--
-- @since 1.0.0.0
{-# INLINE adj #-}
adj :: (HasCallStack) => Csr w -> Int -> VU.Vector Int
adj Csr {..} i = VU.slice lo (hi - lo) adjCsr
  where
    lo = startCsr VG.! i
    hi = startCsr VG.! (i + 1)

-- | \(O(1)\) Returns the edges leaving vertex @i@ as @(adjacentVertex, weight)@ pairs, built
-- by zipping the matching slices of `adjCsr` and `wCsr`.
--
-- @since 1.0.0.0
{-# INLINE adjW #-}
adjW :: (HasCallStack, VU.Unbox w) => Csr w -> Int -> VU.Vector (Int, w)
adjW Csr {..} i =
  let lo = startCsr VG.! i
      len = startCsr VG.! (i + 1) - lo
      vertices = VU.slice lo len adjCsr
      weights = VU.slice lo len wCsr
   in VU.zip vertices weights

-- | \(O(d)\), where \(d\) is the out-degree of vertex @i@. Returns the edges leaving vertex
-- @i@ as a vector of @(edgeId, adjacentVertex)@ pairs.
--
-- The edge ID is the edge's position in the CSR arrays (`adjCsr` / `wCsr`), so it can be used
-- to index them directly. It generally differs from the edge's index in the vector passed to
-- `build`, because `build` regroups edges by source vertex.
--
-- @since 1.0.0.0
{-# INLINE eAdj #-}
eAdj :: (HasCallStack) => Csr w -> Int -> VU.Vector (Int, Int)
eAdj Csr {..} i = VU.imap (\k v -> (il + k, v)) $ VU.slice il (ir - il) adjCsr
  where
    -- Vertex i owns CSR slots [il, ir); slot il + k is the k-th edge of the slice.
    il = startCsr VG.! i
    ir = startCsr VG.! (i + 1)