{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_HADDOCK hide #-}

-- | Internal CSR for `AtCoder.MinCostFlow`.
--
-- @since 1.0.0.0
module AtCoder.Internal.McfCsr
  ( -- * Compressed sparse row
    Csr (..),

    -- * Constructor
    build,

    -- * Accessor
    adj,
  )
where

import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Base qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)

-- | CSR (compressed sparse row) representation of the residual graph used for min cost flow.
--
-- All the per-edge vectors (`toCsr`, `revCsr`, `capCsr`, `costCsr`) share the same indexing: the
-- edges that leave vertex @v@ occupy the index range @[startCsr[v], startCsr[v + 1])@.
--
-- @since 1.0.0.0
data Csr s cap cost = Csr
  { -- | Offsets into the per-edge vectors, one entry per vertex plus a final end offset.
    --
    -- @since 1.0.0.0
    startCsr :: !(VU.Vector Int),
    -- | Destination vertex of each edge.
    --
    -- @since 1.0.0.0
    toCsr :: !(VU.Vector Int),
    -- | Index (into the per-edge vectors) of the reverse edge of each edge.
    --
    -- @since 1.0.0.0
    revCsr :: !(VU.Vector Int),
    -- | Capacity stored for each edge. Mutable: this is the only field that is not frozen.
    --
    -- @since 1.0.0.0
    capCsr :: !(VUM.MVector s cap),
    -- | Cost of each edge; a backward edge carries the negated cost of its forward edge.
    --
    -- @since 1.0.0.0
    costCsr :: !(VU.Vector cost)
  }

-- | \(O(n + m)\) Creates `Csr` from the number of vertices @n@ and a vector of edges.
--
-- Each input edge is a tuple @(from, to, cap, flow, cost)@. For every input edge two CSR entries
-- are written:
--
-- * a forward edge @from -> to@ with capacity @cap - flow@ and cost @cost@, and
-- * a backward edge @to -> from@ with capacity @flow@ and cost @-cost@.
--
-- The two entries reference each other through `revCsr`.
--
-- The result is @(edgeIdx, csr)@, where @edgeIdx@ has one element per input edge (in input order)
-- holding the CSR index at which the forward edge of that input edge was stored.
--
-- Vertex indices are expected to lie in @[0, n)@.
--
-- @since 1.0.0.0
{-# INLINE build #-}
build ::
  (HasCallStack, Num cap, VU.Unbox cap, VU.Unbox cost, Num cost, PrimMonad m) =>
  Int ->
  VU.Vector (Int, Int, cap, cap, cost) ->
  m (VU.Vector Int, Csr (PrimState m) cap cost)
build n edges = do
  let m = VU.length edges
  -- create the offsets first (this is a different step from ac-library)
  let startCsr = VU.create $ do
        start <- VUM.replicate (n + 1) (0 :: Int)
        -- count degrees: every input edge contributes one entry to `from` (forward edge) and
        -- one entry to `to` (backward edge). Counts are stored shifted by one so that the
        -- prefix sum below turns them into start offsets.
        let (VU.V_5 _ froms tos _ _ _) = edges
        VU.forM_ (VU.zip froms tos) $ \(!from, !to) -> do
          VGM.modify start (+ 1) $ from + 1
          VGM.modify start (+ 1) $ to + 1
        -- sum up the degrees (in-place prefix sum: slot @i@ is final when it is read)
        VUM.iforM_ (VUM.init start) $ \i dx -> do
          VGM.modify start (+ dx) (i + 1)
        pure start

  -- each input edge yields a forward and a backward entry, hence @2 * m@ slots
  toVec <- VUM.unsafeNew $ 2 * m
  revVec <- VUM.unsafeNew $ 2 * m
  capCsr <- VUM.unsafeNew $ 2 * m
  costVec <- VUM.unsafeNew $ 2 * m

  -- build CSR
  -- `counter` holds, per vertex, the next free slot within that vertex's range
  counter <- VU.thaw startCsr
  edgeIdx <- VU.forM edges $ \(!from, !to, !cap, !flow, !cost) -> do
    -- reserve one slot at `from` and then one at `to`; the order matters for self-loops
    i1 <- VGM.read counter from
    VGM.modify counter (+ 1) from
    i2 <- VGM.read counter to
    VGM.modify counter (+ 1) to
    -- write forward edge
    VGM.write toVec i1 to
    VGM.write revVec i1 i2
    VGM.write capCsr i1 $! cap - flow
    VGM.write costVec i1 cost
    -- write backward edge
    VGM.write toVec i2 from
    VGM.write revVec i2 i1
    VGM.write capCsr i2 flow
    VGM.write costVec i2 (-cost)
    -- remember forward edge index
    pure i1

  -- the mutable buffers are not written again after this point, so freezing in place is fine
  toCsr <- VU.unsafeFreeze toVec
  revCsr <- VU.unsafeFreeze revVec
  costCsr <- VU.unsafeFreeze costVec
  pure (edgeIdx, Csr {..})

-- | \(O(1)\) Returns a vector of @(to, rev, cost)@ for the edges stored for vertex @v@, i.e. the
-- slice @[startCsr[v], startCsr[v + 1])@ of the zipped `toCsr`, `revCsr` and `costCsr` vectors.
--
-- The capacities are not part of the result, because `capCsr` is mutable.
--
-- @since 1.0.0.0
{-# INLINE adj #-}
adj :: (HasCallStack, Num cap, VU.Unbox cap, VU.Unbox cost) => Csr s cap cost -> Int -> VU.Vector (Int, Int, cost)
adj Csr {..} v = VU.slice offset len vec
  where
    -- first CSR index that belongs to vertex @v@
    offset = startCsr VG.! v
    -- number of edges stored for vertex @v@
    len = startCsr VG.! (v + 1) - offset
    vec = VU.zip3 toCsr revCsr costCsr