{-# LANGUAGE RecordWildCards #-}

-- | A Fenwick tree, also known as a binary indexed tree. Given an array of length \(n\), it processes
-- the following queries in \(O(\log n)\) time.
--
-- - Updating an element
-- - Calculating the sum of the elements of an interval
--
-- ==== __Example__
-- Create a `FenwickTree` with `new`:
--
-- >>> import AtCoder.FenwickTree qualified as FT
-- >>> ft <- FT.new @_ @Int 4 -- [0, 0, 0, 0]
-- >>> FT.nFt ft
-- 4
--
-- It can perform point `add` and range `sum` in \(O(\log n)\) time:
--
-- >>> FT.add ft 0 3          -- [3, 0, 0, 0]
-- >>> FT.sum ft 0 3
-- 3
--
-- >>> FT.add ft 2 3          -- [3, 0, 3, 0]
-- >>> FT.sum ft 0 3
-- 6
--
-- Create a `FenwickTree` with initial values using `build`:
--
-- >>> ft <- FT.build @_ @Int $ VU.fromList [3, 0, 3, 0]
-- >>> FT.add ft 1 2          -- [3, 2, 3, 0]
-- >>> FT.sum ft 0 3
-- 8
--
-- @since 1.0.0.0
module AtCoder.FenwickTree
  ( -- * Fenwick tree
    FenwickTree (nFt),

    -- * Constructors
    new,
    build,

    -- * Adding
    add,

    -- * Accessor
    sum,
    sumMaybe,
  )
where

import AtCoder.Internal.Assert qualified as ACIA
import Control.Monad (when)
import Control.Monad.Fix (fix)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Bits
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)
import Prelude hiding (sum)

-- | A Fenwick tree over a mutable unboxed array of elements of type @a@. The type parameter @s@ is
-- the 'PrimState' of the monad in which the tree is manipulated.
--
-- @since 1.0.0.0
data FenwickTree s a = FenwickTree
  { -- | The length \(n\) of the array represented by the tree.
    --
    -- @since 1.0.0.0
    nFt :: {-# UNPACK #-} !Int,
    -- | The data storage. The tree node with 1-based index \(i\) is stored at slot \(i - 1\).
    dataFt :: !(VUM.MVector s a)
  }

-- | Allocates a tree representing the length-\(n\) array \([a_0, a_1, \cdots, a_{n-1}]\), with
-- every \(a_i\) set to \(0\). A negative \(n\) is rejected with an 'error'.
--
-- ==== Constraints
-- - \(0 \leq n\)
--
-- ==== Complexity
-- - \(O(n)\)
--
-- @since 1.0.0.0
{-# INLINE new #-}
new :: (HasCallStack, PrimMonad m, Num a, VU.Unbox a) => Int -> m (FenwickTree (PrimState m) a)
new n
  | n < 0 = error $ "AtCoder.FenwickTree.new: given negative size `" ++ show n ++ "`"
  | otherwise = FenwickTree n <$> VUM.replicate n 0

-- | Creates a `FenwickTree` whose array is initialized with the given vector. The tree has the same
-- length as the vector.
--
-- ==== Complexity
-- - \(O(n)\)
--
-- @since 1.0.0.0
{-# INLINE build #-}
build :: (PrimMonad m, Num a, VU.Unbox a) => VU.Vector a -> m (FenwickTree (PrimState m) a)
build xs = do
  let !nFt = VU.length xs
  -- Start from a copy of the input: node i (1-based, stored at slot i - 1) initially holds x_{i-1}.
  dataFt <- VU.thaw xs
  -- Visit nodes in increasing order and push each node's value into its parent i + lsb(i).
  -- Every child of a node has a smaller index, so a node is complete before it is pushed. This
  -- builds the tree in one linear pass, instead of the O(n log n) cost of n calls to `add`.
  let push !i = when (i <= nFt) $ do
        let !parent = i + (i .&. (-i))
        when (parent <= nFt) $ do
          v <- VGM.read dataFt (i - 1)
          VGM.modify dataFt (+ v) (parent - 1)
        push (i + 1)
  push 1
  pure FenwickTree {..}

-- | Adds \(x\) to the \(p\)-th value of the array.
--
-- ==== Constraints
-- - \(0 \leq p \lt n\)
--
-- ==== Complexity
-- - \(O(\log n)\)
--
-- @since 1.0.0.0
{-# INLINE add #-}
add :: (HasCallStack, PrimMonad m, Num a, VU.Unbox a) => FenwickTree (PrimState m) a -> Int -> a -> m ()
add FenwickTree {..} p0 x = do
  -- Validate p0 against nFt; the bang pattern forces the check before any write happens.
  let !_ = ACIA.checkIndex "AtCoder.FenwickTree.add" p0 nFt
  -- Starting at the 1-based position p0 + 1, update every node whose range covers it by
  -- repeatedly adding the lowest set bit of the current index.
  let loop !p = when (p <= nFt) $ do
        VGM.modify dataFt (+ x) (p - 1)
        loop (p + (p .&. (-p)))
  loop (p0 + 1)

-- | \(O(\log n)\) Returns the sum of the first @r@ elements, i.e. over the half-open interval
-- @[0, r)@. A non-positive @r@ yields @0@. This helper does not validate @r@ itself.
--
-- @since 1.0.0.0
{-# INLINE prefixSum #-}
prefixSum :: (PrimMonad m, Num a, VU.Unbox a) => FenwickTree (PrimState m) a -> Int -> m a
prefixSum FenwickTree {dataFt} = go 0
  where
    -- Add up node values while clearing the lowest set bit of the 1-based index each step.
    go !acc !i
      | i > 0 = do
          v <- VGM.read dataFt (i - 1)
          go (acc + v) (i .&. (i - 1))
      | otherwise = pure acc

-- | Returns the sum of the elements in the half-open interval \([l, r)\). An invalid interval is
-- reported through 'ACIA.errorInterval'.
--
-- ==== Constraints
-- - \(0 \leq l \leq r \leq n\)
--
-- ==== Complexity
-- - \(O(\log n)\)
--
-- @since 1.0.0.0
{-# INLINE sum #-}
sum :: (HasCallStack, PrimMonad m, Num a, VU.Unbox a) => FenwickTree (PrimState m) a -> Int -> Int -> m a
sum ft@FenwickTree {nFt} l r =
  if ACIA.testInterval l r nFt
    then unsafeSum ft l r
    else ACIA.errorInterval "AtCoder.FenwickTree.sum" l r nFt

-- | Total counterpart of `sum`: returns @Just@ the sum over the half-open interval \([l, r)\) when
-- the interval is valid for the tree, and `Nothing` otherwise.
--
-- ==== Complexity
-- - \(O(\log n)\)
--
-- @since 1.0.0.0
{-# INLINE sumMaybe #-}
sumMaybe :: (HasCallStack, PrimMonad m, Num a, VU.Unbox a) => FenwickTree (PrimState m) a -> Int -> Int -> m (Maybe a)
sumMaybe ft@FenwickTree {nFt} l r
  | ACIA.testInterval l r nFt = fmap Just (unsafeSum ft l r)
  | otherwise = pure Nothing

-- | Internal implementation of `sum`: the interval sum \([l, r)\) is computed as the difference of
-- two prefix sums. The interval is not validated here; callers check it first.
{-# INLINE unsafeSum #-}
unsafeSum :: (HasCallStack, PrimMonad m, Num a, VU.Unbox a) => FenwickTree (PrimState m) a -> Int -> Int -> m a
unsafeSum ft l r = do
  upToR <- prefixSum ft r
  upToL <- prefixSum ft l
  -- Force the difference so the caller receives an evaluated value rather than a thunk.
  let !diff = upToR - upToL
  pure diff