-- | Generic tree functions.
-- @since
module AtCoder.Extra.Tree
  ( -- * Tree folding

    -- | These functions are built around three type parameters: \(w\), \(f\) and \(a\).
    --
    -- - \(w\): Edge weight.
    -- - \(f\): Monoid action to a vertex value. These actions are created from a vertex value
    --   \(a\) and edge information @(Int, w)@.
    -- - \(a\): Values stored at vertices.

import Data.Functor.Identity (runIdentity)
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)

{-# INLINE foldImpl #-}
-- | Monadic tree folding from a root vertex.
--
-- NOTE(review): this presumably is the shared worker behind 'fold' and 'scan'. It takes the
-- same graph, @valAt@, @toF@, @act@ and root arguments, and 'runIdentity' is imported, plus a
-- monadic per-vertex callback. TODO confirm against the (not visible) body.
foldImpl ::
  forall m w f a.
  (HasCallStack, Monad m, VU.Unbox w) =>
  -- | Graph as a function.
  (Int -> VU.Vector (Int, w)) ->
  -- | @valAt@: Assignment of initial vertex values.
  (Int -> a) ->
  -- | @toF@: Converts a vertex value into an action onto a neighbor vertex.
  (a -> (Int, w) -> f) ->
  -- | @act@: Performs an action onto a vertex value.
  (f -> a -> a) ->
  -- | Root vertex.
  Int ->
  -- | Callback run in @m@ with a vertex index and a vertex value. Presumably it is invoked with
  -- each vertex's folded value, e.g. so that 'scan' can record it. TODO confirm.
  (Int -> a -> m ()) ->
  -- | Presumably the folded value of the root vertex, returned in @m@.
  m a
-- | \(O(n)\) Folds a tree from a root vertex, also known as tree DP. Only the value of the root
-- vertex is returned; use 'scan' to obtain the value of every vertex.
--
-- ==== __Example__
-- >>> import AtCoder.Extra.Graph qualified as Gr
-- >>> import AtCoder.Extra.Tree qualified as Tree
-- >>> import Data.Semigroup (Sum (..))
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> let gr = Gr.build @(Sum Int) 5 . Gr.swapDupe $ VU.fromList [(2, 1, Sum 1), (1, 0, Sum 1), (2, 3, Sum 1), (3, 4, Sum 1)]
-- >>> type W = Sum Int -- edge weight
-- >>> type F = Sum Int -- action type
-- >>> type X = Sum Int -- vertex value
-- >>> :{
--  let res = Tree.fold (gr `Gr.adjW`) valAt toF act 2
--        where
--          valAt :: Int -> X
--          valAt = const $ mempty @(Sum Int)
--          toF :: X -> (Int, W) -> F
--          toF x (!_i, !dx) = x + dx
--          act :: F -> X -> X
--          act dx x = dx + x
--   in getSum res
-- :}
-- 4
--
-- @since
{-# INLINE fold #-}
fold ::
  (HasCallStack, VU.Unbox w) =>
  -- | Graph as a function.
  (Int -> VU.Vector (Int, w)) ->
  -- | @valAt@: Assignment of initial vertex values.
  (Int -> a) ->
  -- | @toF@: Converts a vertex value into an action onto a neighbor vertex.
  (a -> (Int, w) -> f) ->
  -- | @act@: Performs an action onto a vertex value.
  (f -> a -> a) ->
  -- | Root vertex.
  Int ->
  -- | Tree folding result from the root vertex.
  a
-- | \(O(n)\) Folds a tree from a root vertex, also known as tree DP. Unlike 'fold', the folded
-- value of every vertex is recorded and returned as a vector indexed by vertex.
--
-- ==== __Example__
-- >>> import AtCoder.Extra.Graph qualified as Gr
-- >>> import AtCoder.Extra.Tree qualified as Tree
-- >>> import Data.Semigroup (Sum (..))
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> let n = 5
-- >>> let gr = Gr.build @(Sum Int) n . Gr.swapDupe $ VU.fromList [(2, 1, Sum 1), (1, 0, Sum 1), (2, 3, Sum 1), (3, 4, Sum 1)]
-- >>> type W = Sum Int -- edge weight
-- >>> type F = Sum Int -- action type
-- >>> type X = Sum Int -- vertex value
-- >>> :{
--  let res = Tree.scan n (gr `Gr.adjW`) valAt toF act 2
--        where
--          valAt :: Int -> X
--          valAt = const $ mempty @(Sum Int)
--          toF :: X -> (Int, W) -> F
--          toF x (!_i, !dx) = x + dx
--          act :: F -> X -> X
--          act dx x = dx + x
--   in VU.map getSum res
-- :}
-- [0,1,4,1,0]
--
-- @since
{-# INLINE scan #-}
scan ::
  (VU.Unbox w, VG.Vector v a) =>
  -- | The number of vertices.
  Int ->
  -- | Graph as a function.
  (Int -> VU.Vector (Int, w)) ->
  -- | @valAt@: Assignment of initial vertex values.
  (Int -> a) ->
  -- | @toF@: Converts a vertex value into an action onto a neighbor vertex.
  (a -> (Int, w) -> f) ->
  -- | @act@: Performs an action onto a vertex value.
  (f -> a -> a) ->
  -- | Root vertex.
  Int ->
  -- | Tree scanning result from a root vertex: the folded value of each vertex, indexed by
  -- vertex.
  v a
-- | \(O(n)\) Folds a tree from every vertex, using the rerooting technique. The @i@-th element
-- of the result is the folding result with vertex @i@ as the root.
--
-- ==== __Example__
-- >>> import AtCoder.Extra.Graph qualified as Gr
-- >>> import AtCoder.Extra.Tree qualified as Tree
-- >>> import Data.Semigroup (Sum (..))
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> let n = 5
-- >>> let gr = Gr.build @(Sum Int) n . Gr.swapDupe $ VU.fromList [(2, 1, Sum 1), (1, 0, Sum 1), (2, 3, Sum 1), (3, 4, Sum 1)]
-- >>> type W = Sum Int -- edge weight
-- >>> type F = Sum Int -- action type
-- >>> type X = Sum Int -- vertex value
-- >>> :{
--  let res = Tree.foldReroot n (gr `Gr.adjW`) valAt toF act
--        where
--          valAt :: Int -> X
--          valAt = const $ mempty @(Sum Int)
--          toF :: X -> (Int, W) -> F
--          toF x (!_i, !dx) = x + dx
--          act :: F -> X -> X
--          act dx x = dx + x
--   in VU.map getSum res
-- :}
-- [4,4,4,4,4]
--
-- @since
{-# INLINE foldReroot #-}
foldReroot ::
  forall w f a.
  (HasCallStack, VU.Unbox w, VU.Unbox a, VU.Unbox f, Monoid f) =>
  -- | The number of vertices.
  Int ->
  -- | Graph as a function.
  (Int -> VU.Vector (Int, w)) ->
  -- | @valAt@: Assignment of initial vertex values.
  (Int -> a) ->
  -- | @toF@: Converts a vertex value into an action onto a neighbor vertex.
  (a -> (Int, w) -> f) ->
  -- | @act@: Performs an action onto a vertex value.
  (f -> a -> a) ->
  -- | Tree folding result from every vertex as a root, indexed by root vertex.
  VU.Vector a
