-- | Generic tree functions.
--
-- @since 1.1.0.0
module AtCoder.Extra.Tree
  ( -- * Tree folding

    -- | These functions are built around the three type parameters: \(w\), \(f\) and \(a\).
    --
    -- - \(w\): Edge weight.
    -- - \(f\): Monoid action to a vertex value. These actions are created from vertex value \(a\)
    -- and edge information @(Int, w)@.
    -- - \(a\): Monoid values stored at vertices.
    fold,
    scan,
    foldReroot,
  )
where

import Data.Functor.Identity (runIdentity)
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)

-- Shared implementation of 'fold' and 'scan': a depth-first fold starting at @root@ that reports
-- the folded value of every visited vertex through the @memo@ callback.
--
-- For a vertex @v1@ reached from @parent@, the accumulator starts at @valAt v1@. Every neighbor
-- @(v2, w)@ of @v1@ whose vertex differs from @parent@ is folded recursively; its result is
-- converted into an action with @toF result (v1, w)@ and applied to the accumulator with @act@.
-- After all such neighbors are processed, @memo v1 result@ is run and @result@ is returned.
--
-- The root is visited with the sentinel parent @-1@, so none of its neighbors is filtered out as
-- long as vertex indices are non-negative.
{-# INLINE foldImpl #-}
foldImpl ::
  forall m w f a.
  (HasCallStack, Monad m, VU.Unbox w) =>
  -- | Graph as a function: neighbors of a vertex paired with edge weights.
  (Int -> VU.Vector (Int, w)) ->
  -- | @valAt@: Assignment of initial vertex values.
  (Int -> a) ->
  -- | @toF@: Converts a vertex value into an action onto a neighbor vertex.
  (a -> (Int, w) -> f) ->
  -- | @act@: Performs an action onto a vertex value.
  (f -> a -> a) ->
  -- | Root vertex.
  Int ->
  -- | @memo@: Callback run with each visited vertex and its folded value.
  (Int -> a -> m ()) ->
  -- | Folded value of the root vertex.
  m a
foldImpl tree valAt toF act root memo = inner (-1) root
  where
    inner :: Int -> Int -> m a
    inner !parent !v1 = do
      let !acc0 = valAt v1
      -- Drop the edge back to the parent so the traversal only goes downwards.
      let !v2s = VU.filter ((/= parent) . fst) $ tree v1
      -- Fold each child subtree, lift its value to an action along the edge, and apply it.
      !res <- VU.foldM' (\acc (!v2, !w) -> (`act` acc) . (`toF` (v1, w)) <$> inner v1 v2) acc0 v2s
      memo v1 res
      pure res

-- | \(O(n)\) Folds a tree from a root vertex, also known as tree DP.
--
-- Only the folded value of the root vertex is returned; intermediate per-vertex results are
-- discarded. Use 'scan' to retrieve them.
--
-- ==== __Example__
-- >>> import AtCoder.Extra.Graph qualified as Gr
-- >>> import AtCoder.Extra.Tree qualified as Tree
-- >>> import Data.Semigroup (Sum (..))
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> let gr = Gr.build @(Sum Int) 5 . Gr.swapDupe $ VU.fromList [(2, 1, Sum 1), (1, 0, Sum 1), (2, 3, Sum 1), (3, 4, Sum 1)]
-- >>> type W = Sum Int -- edge weight
-- >>> type F = Sum Int -- action type
-- >>> type X = Sum Int -- vertex value
-- >>> :{
--  let res = Tree.fold (gr `Gr.adjW`) valAt toF act 2
--        where
--          valAt :: Int -> X
--          valAt = const $ mempty @(Sum Int)
--          toF :: X -> (Int, W) -> F
--          toF x (!_i, !dx) = x + dx
--          act :: F -> X -> X
--          act dx x = dx + x
--   in getSum res
-- :}
-- 4
--
-- @since 1.1.0.0
{-# INLINE fold #-}
fold ::
  (HasCallStack, VU.Unbox w) =>
  -- | Graph as a function.
  (Int -> VU.Vector (Int, w)) ->
  -- | @valAt@: Assignment of initial vertex values.
  (Int -> a) ->
  -- | @toF@: Converts a vertex value into an action onto a neighbor vertex.
  (a -> (Int, w) -> f) ->
  -- | @act@: Performs an action onto a vertex value.
  (f -> a -> a) ->
  -- | Root vertex.
  Int ->
  -- | Tree folding result from the root vertex.
  a
fold tree valAt toF act root =
  -- Run the shared implementation in 'Identity' with a callback that records nothing.
  runIdentity $ foldImpl tree valAt toF act root (\_ _ -> pure ())

-- | \(O(n)\) Folds a tree from a root vertex, also known as tree DP. The calculation process on
-- every vertex is recorded and returned as a vector.
--
-- The result vector has length @n@ and is allocated without initialization; only the entries of
-- vertices visited by the fold from the root are written. Every vertex index reached through the
-- graph function must therefore be in \([0, n)\), because writes are not bounds-checked.
--
-- ==== __Example__
-- >>> import AtCoder.Extra.Graph qualified as Gr
-- >>> import AtCoder.Extra.Tree qualified as Tree
-- >>> import Data.Semigroup (Sum (..))
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> let n = 5
-- >>> let gr = Gr.build @(Sum Int) n . Gr.swapDupe $ VU.fromList [(2, 1, Sum 1), (1, 0, Sum 1), (2, 3, Sum 1), (3, 4, Sum 1)]
-- >>> type W = Sum Int -- edge weight
-- >>> type F = Sum Int -- action type
-- >>> type X = Sum Int -- vertex value
-- >>> :{
--  let res = Tree.scan n (gr `Gr.adjW`) valAt toF act 2
--        where
--          valAt :: Int -> X
--          valAt = const $ mempty @(Sum Int)
--          toF :: X -> (Int, W) -> F
--          toF x (!_i, !dx) = x + dx
--          act :: F -> X -> X
--          act dx x = dx + x
--   in VU.map getSum res
-- :}
-- [0,1,4,1,0]
--
-- @since 1.1.0.0
{-# INLINE scan #-}
scan ::
  (VU.Unbox w, VG.Vector v a) =>
  -- | The number of vertices.
  Int ->
  -- | Graph as a function.
  (Int -> VU.Vector (Int, w)) ->
  -- | @valAt@: Assignment of initial vertex values.
  (Int -> a) ->
  -- | @toF@: Converts a vertex value into an action onto a neighbor vertex.
  (a -> (Int, w) -> f) ->
  -- | @act@: Performs an action onto a vertex value.
  (f -> a -> a) ->
  -- | Root vertex.
  Int ->
  -- | Tree scanning result from a root vertex.
  v a
scan n tree acc0At toF act root = VG.create $ do
  dp <- VGM.unsafeNew n
  -- The root's folded value is also stored through the callback, so the returned value is
  -- only forced and then dropped.
  !_ <- foldImpl tree acc0At toF act root $ \v a -> do
    VGM.unsafeWrite dp v a
  pure dp

-- | \(O(n)\) Folds a tree from every vertex, using the rerooting technique.
--
-- Returns an empty vector when the number of vertices is not positive.
--
-- ==== __Example__
-- >>> import AtCoder.Extra.Graph qualified as Gr
-- >>> import AtCoder.Extra.Tree qualified as Tree
-- >>> import Data.Semigroup (Sum (..))
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> let n = 5
-- >>> let gr = Gr.build @(Sum Int) n . Gr.swapDupe $ VU.fromList [(2, 1, Sum 1), (1, 0, Sum 1), (2, 3, Sum 1), (3, 4, Sum 1)]
-- >>> type W = Sum Int -- edge weight
-- >>> type F = Sum Int -- action type
-- >>> type X = Sum Int -- vertex value
-- >>> :{
--  let res = Tree.foldReroot n (gr `Gr.adjW`) valAt toF act
--        where
--          valAt :: Int -> X
--          valAt = const $ mempty @(Sum Int)
--          toF :: X -> (Int, W) -> F
--          toF x (!_i, !dx) = x + dx
--          act :: F -> X -> X
--          act dx x = dx + x
--   in VU.map getSum res
-- :}
-- [4,4,4,4,4]
--
-- @since 1.1.0.0
{-# INLINE foldReroot #-}
foldReroot ::
  forall w f a.
  (HasCallStack, VU.Unbox w, VU.Unbox a, VU.Unbox f, Monoid f) =>
  -- | The number of vertices.
  Int ->
  -- | Graph as a function.
  (Int -> VU.Vector (Int, w)) ->
  -- | @valAt@: Assignment of initial vertex values.
  (Int -> a) ->
  -- | @toF@: Converts a vertex value into an action onto a neighbor vertex.
  (a -> (Int, w) -> f) ->
  -- | @act@: Performs an action onto a vertex value.
  (f -> a -> a) ->
  -- | Tree folding result from every vertex as a root.
  VU.Vector a
foldReroot n tree valAt toF act
  -- The traversal always starts at vertex 0 and writes without bounds checks, so an empty
  -- (or negative-sized) tree must be answered before touching any buffer.
  | n <= 0 = VU.empty
  | otherwise = VU.create $ do
      -- Ordinary tree DP rooted at @root0@: @treeDp VG.! v@ is the fold of the subtree of @v@.
      let !treeDp = scan n tree valAt toF act root0 :: VU.Vector a

      -- Calculate tree DP for every vertex as a root:
      !dp <- VUM.unsafeNew n
      let reroot parent parentF v1 = do
            -- TODO: when the operator is not commutative?
            let !children = VU.filter ((/= parent) . fst) $ tree v1
            -- Prefix and suffix products of the children's actions. Both have
            -- @length children + 1@ elements, with @f0@ at the open end.
            let !fL = VU.scanl' (\ !f (!v2, !w) -> (f <>) . (`toF` (v1, w)) $ treeDp VG.! v2) f0 children
            let !fR = VU.scanr' (\(!v2, !w) !f -> (<> f) . (`toF` (v1, w)) $ treeDp VG.! v2) f0 children

            -- save: the parent-side action combined with all children's actions.
            -- @fL@ is never empty because @scanl'@ always emits the seed.
            let !x1 = (parentF <> VU.last fL) `act` valAt v1
            VGM.unsafeWrite dp v1 x1

            VU.iforM_ children $ \i2 (!v2, !w) -> do
              -- composited operator excluding @v2@:
              let !f1 = parentF <> (fL VG.! i2) <> (fR VG.! (i2 + 1))
              -- Value of @v1@ as seen from @v2@, i.e. with the subtree of @v2@ removed.
              let !v1Acc = f1 `act` valAt v1
              let !f2 = toF v1Acc (v2, w)
              reroot v1 f2 v2

      reroot (-1 :: Int) f0 root0
      pure dp
  where
    !root0 = 0 :: Int
    !f0 = mempty @f