module AtCoder.Extra.Tree
(
fold,
scan,
foldReroot,
)
where
import Data.Functor.Identity (runIdentity)
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)
-- | Shared depth-first tree fold used by 'fold' and 'scan'.
--
-- Starting from the root (with @-1@ as the "no parent" marker), each vertex
-- @v1@ begins with the accumulator @valAt v1@. Every neighbor returned by
-- @tree v1@ whose vertex differs from the parent is visited recursively; the
-- child's result is converted with @toF@ (given @(v1, w)@, where @w@ is the
-- weight of the edge) and applied to the accumulator with @act@, in the order
-- the neighbors appear in @tree v1@. Once a vertex is finished, @memo@ is
-- called with the vertex and its result, so children are reported before
-- their parent.
--
-- The traversal is plain recursion, so its depth equals the depth of the tree.
{-# INLINE foldImpl #-}
foldImpl ::
  forall m w f a.
  (HasCallStack, Monad m, VU.Unbox w) =>
  -- | Adjacency function: neighbors of a vertex paired with edge weights.
  (Int -> VU.Vector (Int, w)) ->
  -- | Initial accumulator of a vertex.
  (Int -> a) ->
  -- | Converts a child's result into an operator, given @(parent, weight)@.
  (a -> (Int, w) -> f) ->
  -- | Applies an operator to the accumulator.
  (f -> a -> a) ->
  -- | Root vertex.
  Int ->
  -- | Callback invoked with each vertex and its folded result.
  (Int -> a -> m ()) ->
  -- | Folded result of the root vertex.
  m a
foldImpl tree valAt toF act root memo = inner (-1) root
  where
    inner :: Int -> Int -> m a
    inner !parent !v1 = do
      let !acc0 = valAt v1
      -- Drop the edge(s) leading back to the parent so the walk never turns
      -- around.
      let !v2s = VU.filter ((/= parent) . fst) $ tree v1
      !res <-
        VU.foldM'
          (\acc (!v2, !w) -> (`act` acc) . (`toF` (v1, w)) <$> inner v1 v2)
          acc0
          v2s
      memo v1 res
      pure res
-- | Folds a tree from the given root vertex and returns the root's result.
--
-- This is 'foldImpl' run in the 'Data.Functor.Identity.Identity' monad with a
-- callback that ignores the per-vertex results, so only the value accumulated
-- at the root is produced.
{-# INLINE fold #-}
fold ::
  (HasCallStack, VU.Unbox w) =>
  -- | Adjacency function: neighbors of a vertex paired with edge weights.
  (Int -> VU.Vector (Int, w)) ->
  -- | Initial accumulator of a vertex.
  (Int -> a) ->
  -- | Converts a child's result into an operator, given @(parent, weight)@.
  (a -> (Int, w) -> f) ->
  -- | Applies an operator to the accumulator.
  (f -> a -> a) ->
  -- | Root vertex.
  Int ->
  -- | Folded result of the root vertex.
  a
fold tree valAt toF act root =
  runIdentity $ foldImpl tree valAt toF act root (\_ _ -> pure ())
scan ::
(VU.Unbox w, VG.Vector v a) =>
Int ->
(Int -> VU.Vector (Int, w)) ->
(Int -> a) ->
(a -> (Int, w) -> f) ->
(f -> a -> a) ->
Int ->
v a
scan :: forall w (v :: * -> *) a f.
(Unbox w, Vector v a) =>
Int
-> (Int -> Vector (Int, w))
-> (Int -> a)
-> (a -> (Int, w) -> f)
-> (f -> a -> a)
-> Int
-> v a
scan Int
n Int -> Vector (Int, w)
tree Int -> a
acc0At a -> (Int, w) -> f
toF f -> a -> a
act Int
root = (forall s. ST s (Mutable v s a)) -> v a
forall (v :: * -> *) a.
Vector v a =>
(forall s. ST s (Mutable v s a)) -> v a
VG.create ((forall s. ST s (Mutable v s a)) -> v a)
-> (forall s. ST s (Mutable v s a)) -> v a
forall a b. (a -> b) -> a -> b
$ do
Mutable v s a
dp <- Int -> ST s (Mutable v (PrimState (ST s)) a)
forall (m :: * -> *) (v :: * -> * -> *) a.
(PrimMonad m, MVector v a) =>
Int -> m (v (PrimState m) a)
VGM.unsafeNew Int
n
!a
_ <- (Int -> Vector (Int, w))
-> (Int -> a)
-> (a -> (Int, w) -> f)
-> (f -> a -> a)
-> Int
-> (Int -> a -> ST s ())
-> ST s a
forall (m :: * -> *) w f a.
(HasCallStack, Monad m, Unbox w) =>
(Int -> Vector (Int, w))
-> (Int -> a)
-> (a -> (Int, w) -> f)
-> (f -> a -> a)
-> Int
-> (Int -> a -> m ())
-> m a
foldImpl Int -> Vector (Int, w)
tree Int -> a
acc0At a -> (Int, w) -> f
toF f -> a -> a
act Int
root ((Int -> a -> ST s ()) -> ST s a)
-> (Int -> a -> ST s ()) -> ST s a
forall a b. (a -> b) -> a -> b
$ \Int
v a
a -> do
Mutable v (PrimState (ST s)) a -> Int -> a -> ST s ()
forall (m :: * -> *) (v :: * -> * -> *) a.
(PrimMonad m, MVector v a) =>
v (PrimState m) a -> Int -> a -> m ()
VGM.unsafeWrite Mutable v s a
Mutable v (PrimState (ST s)) a
dp Int
v a
a
Mutable v s a -> ST s (Mutable v s a)
forall a. a -> ST s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Mutable v s a
dp
-- | Computes, for every vertex, the result of folding the tree with that
-- vertex as the root (rerooting).
--
-- First the per-vertex results for the fixed root @0@ are computed with
-- 'scan'. A second depth-first pass then walks from vertex @0@, carrying the
-- operator contributed by the parent side (@parentF@). At each vertex @v1@:
--
-- * the operators of all children are combined with '<>' into prefix (@fL@)
--   and suffix (@fR@) scans starting from 'mempty';
-- * the answer for @v1@ is @(parentF <> last fL) \`act\` valAt v1@;
-- * for the child at position @i2@, the operator of everything except that
--   child is @parentF <> fL ! i2 <> fR ! (i2 + 1)@, which is applied to
--   @valAt v1@, converted with @toF@ and passed down as the child's
--   @parentF@.
--
-- An empty vector is returned when @n <= 0@, so neither @tree@ nor @valAt@ is
-- called and nothing is written out of bounds for an empty tree.
--
-- NOTE(review): the two passes only agree if applying operators one by one
-- with @act@ equals applying their '<>'-product once, i.e. @act@ is presumably
-- expected to be a monoid action of @f@ on @a@ -- confirm against callers.
{-# INLINE foldReroot #-}
foldReroot ::
  forall w f a.
  (HasCallStack, VU.Unbox w, VU.Unbox a, VU.Unbox f, Monoid f) =>
  -- | Number of vertices.
  Int ->
  -- | Adjacency function: neighbors of a vertex paired with edge weights.
  (Int -> VU.Vector (Int, w)) ->
  -- | Initial accumulator of a vertex.
  (Int -> a) ->
  -- | Converts a vertex's result into an operator, given the target vertex
  -- and the edge weight.
  (a -> (Int, w) -> f) ->
  -- | Applies an operator to the accumulator.
  (f -> a -> a) ->
  -- | Result of folding the tree from each vertex, indexed by vertex.
  VU.Vector a
foldReroot n tree valAt toF act
  | n <= 0 = VU.empty
  | otherwise = VU.create $ do
      -- Results of the ordinary fold rooted at @root0@. Kept inside this
      -- branch so that it is never evaluated for an empty tree.
      let !treeDp = scan n tree valAt toF act root0 :: VU.Vector a
      !dp <- VUM.unsafeNew n
      let reroot parent parentF v1 = do
            let !children = VU.filter ((/= parent) . fst) $ tree v1
            -- fL ! i combines the operators of children [0, i);
            -- fR ! i combines the operators of children [i, end).
            let !fL =
                  VU.scanl'
                    (\ !f (!v2, !w) -> (f <>) . (`toF` (v1, w)) $ treeDp VG.! v2)
                    f0
                    children
            let !fR =
                  VU.scanr'
                    (\(!v2, !w) !f -> (<> f) . (`toF` (v1, w)) $ treeDp VG.! v2)
                    f0
                    children

            -- Answer for v1: the parent side plus all children. 'VU.last' is
            -- safe because a scan always has at least its seed element.
            let !x1 = (parentF <> VU.last fL) `act` valAt v1
            VGM.unsafeWrite dp v1 x1

            VU.iforM_ children $ \i2 (!v2, !w) -> do
              -- Combined operator of everything around v1 except child v2.
              let !f1 = parentF <> (fL VG.! i2) <> (fR VG.! (i2 + 1))
              let !v1Acc = f1 `act` valAt v1
              let !f2 = toF v1Acc (v2, w)
              reroot v1 f2 v2
      reroot (-1 :: Int) f0 root0
      pure dp
  where
    -- Deliberately lazy: strict bindings here would be forced before the
    -- guards above are evaluated.
    root0 = 0 :: Int
    f0 = mempty @f