{-# LANGUAGE RecordWildCards #-}

-- original implementation:
-- <https://miti-7.hatenablog.com/entry/2018/04/28/152259>

-- NOTE: We could integrate cumulative sum / fenwick tree / segment tree.
-- NOTE: `topK` and `intersects` are not implemented as they are slow.

-- | A static Wavelet Matrix without automatic index compression. Consider using
-- "AtCoder.Extra.WaveletMatrix" instead.
--
-- @since 1.1.0.0
module AtCoder.Extra.WaveletMatrix.Raw
  ( -- * RawWaveletMatrix
    RawWaveletMatrix (..),

    -- * Constructors
    build,

    -- * Access (indexing)
    access,

    -- * Rank
    rankLT,
    rank,
    rankBetween,

    -- * Select
    select,
    selectKth,
    selectIn,
    selectKthIn,

    -- * Quantile (value-ordered access)

    -- ** Safe (total)
    kthSmallestIn,
    ikthSmallestIn,
    kthLargestIn,
    ikthLargestIn,

    -- ** Unsafe
    unsafeKthSmallestIn,
    unsafeIKthSmallestIn,
    unsafeKthLargestIn,
    unsafeIKthLargestIn,

    -- * Lookup
    lookupLE,
    lookupLT,
    lookupGE,
    lookupGT,

    -- * Conversions
    assocsIn,
    assocsWith,
    descAssocsIn,
    descAssocsInWith,
  )
where

import AtCoder.Extra.WaveletMatrix.BitVector qualified as BV
import AtCoder.Internal.Assert qualified as ACIA
import AtCoder.Internal.Bit qualified as ACIB
import Control.Monad.ST (runST)
import Data.Bit (Bit (..))
import Data.Bits (bit, countTrailingZeros, setBit, testBit, (.|.))
import Data.Maybe
import Data.Vector qualified as V
import Data.Vector.Algorithms.Radix qualified as VAR
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)

-- | A static Wavelet Matrix without automatic index compression.
--
-- @since 1.1.0.0
data RawWaveletMatrix = RawWaveletMatrix
  { -- | \(\lceil \log_2 N \rceil\), where \(N\) is the number of distinct values (the @nx@ given to
    -- `build`). This is also the number of rows in the bit matrix.
    --
    -- @since 1.1.0.0
    heightRwm :: {-# UNPACK #-} !Int,
    -- | The length of the original array.
    --
    -- @since 1.1.0.0
    lengthRwm :: {-# UNPACK #-} !Int,
    -- | The bit matrix. Row @iRow@ stores bit @(heightRwm - 1 - iRow)@ of every element, so the
    -- first row corresponds to the most significant bit.
    --
    -- @since 1.1.0.0
    bitsRwm :: !(V.Vector BV.BitVector),
    -- | The number of zero bits in each row of the bit matrix.
    --
    -- @since 1.1.0.0
    nZerosRwm :: !(VU.Vector Int)
  }
  deriving (Eq, Show)

-- | \(O(n \log n)\) Creates a `RawWaveletMatrix` from a vector \(a\).
--
-- Every element of \(a\) is expected to lie in \([0, n_x)\): only the lowest
-- \(\lceil \log_2 n_x \rceil\) bits of each element are inspected, and higher bits are ignored.
--
-- Throws an error if \(n_x\) is negative.
--
-- @since 1.1.0.0
{-# INLINE build #-}
build ::
  (HasCallStack) =>
  -- | The number of different values in the compressed vector.
  Int ->
  -- | A compressed vector
  VU.Vector Int ->
  -- | A wavelet matrix
  RawWaveletMatrix
build nx xs
  | nx < 0 = error "AtCoder.Extra.WaveletMatrix.Raw.build: given negative `n`"
  | otherwise = runST $ do
      -- TODO: less mutable variables
      -- All rows share one contiguous buffer for bits and one for cumulative sums.
      orgBits <- VUM.replicate (lengthRwm * heightRwm) $ Bit False
      orgCsum <- VUM.replicate (lenCSum * heightRwm) (0 :: Int)
      nZeros <- VUM.unsafeNew heightRwm

      -- Per-row views over the contiguous memory:
      let !bits = V.unfoldrExactN heightRwm (VUM.splitAt lengthRwm) orgBits
      let !csums = V.unfoldrExactN heightRwm (VUM.splitAt lenCSum) orgCsum

      -- A working copy of the input; it is stably re-sorted by the current bit after each row.
      vec <- VU.thaw xs
      V.izipWithM_
        ( \iRow bitVec csum -> do
            -- Row 0 handles the most significant bit.
            let !iBit = heightRwm - 1 - iRow

            -- Record the current bit of every element, in the current order.
            vec' <- VU.unsafeFreeze vec
            VU.iforM_ vec' $ \i x ->
              VGM.unsafeWrite bitVec i . Bit $ testBit x iBit

            -- Cumulative sums of the row, starting from zero.
            VGM.unsafeWrite csum 0 (0 :: Int)
            bitVec' <- VU.unsafeFreeze bitVec

            -- get popCount by word. TODO: use `castToWords` for most elements
            nOnes <- BV.csumInPlace csum bitVec'
            VGM.unsafeWrite nZeros iRow (lengthRwm - nOnes)

            -- Perform a stable sort by the bit (zeros first, then ones). The radix key does not
            -- depend on the pass index, so a single counting-sort pass is sufficient; a second
            -- pass would only repeat the same stable ordering.
            VAR.sortBy 1 2 (\_ x -> fromEnum (testBit x iBit)) vec
        )
        bits
        csums
      nZerosRwm <- VU.unsafeFreeze nZeros
      bits' <- V.unfoldrExactN heightRwm (VU.splitAt lengthRwm) <$> VU.unsafeFreeze orgBits
      csums' <- V.unfoldrExactN heightRwm (VU.splitAt lenCSum) <$> VU.unsafeFreeze orgCsum
      let !bitsRwm = V.zipWith BV.BitVector bits' csums'
      pure $ RawWaveletMatrix {..}
  where
    !lengthRwm = VG.length xs
    -- One slot per word, plus one for the leading zero.
    !lenCSum = (lengthRwm + BV.wordSize - 1) `div` BV.wordSize + 1
    !heightRwm = countTrailingZeros $ ACIB.bitCeil nx

-- | \(O(\log |S|)\) Returns \(a[i]\), or `Nothing` if \(i\) is out of bounds. Prefer indexing the
-- original array when it is still available.
--
-- @since 1.1.0.0
{-# INLINE access #-}
access :: RawWaveletMatrix -> Int -> Maybe Int
access RawWaveletMatrix {..} i0
  | not (ACIA.testIndex i0 lengthRwm) = Nothing
  | otherwise =
      let (!_, !value) = V.ifoldl' descend (i0, 0) bitsRwm
       in Just value
  where
    -- Follows the element from one row to the next. Whenever its bit in the current row is on, the
    -- row's bit is set in the accumulated value.
    descend :: (Int, Int) -> Int -> BV.BitVector -> (Int, Int)
    descend (!i, !acc) !iRow !row = case VG.unsafeIndex (BV.bitsBv row) i of
      Bit True ->
        let !next = BV.rank1 row i + VG.unsafeIndex nZerosRwm iRow
            !acc' = setBit acc (heightRwm - 1 - iRow)
         in (next, acc')
      Bit False ->
        let !next = BV.rank0 row i
         in (next, acc)

-- | \(O(\log |A|)\) Goes down the wavelet matrix for collecting the \(k\)-th (0-based) smallest
-- value in \([l, r)\).
--
-- Returns @(value, l', r', k')@:
--
-- * @value@ is the collected value.
-- * @[l', r')@ is the final range after the last row.
-- * @k'@ is the remaining offset of the target within that range.
--
-- The inputs are not validated; callers are expected to pass \(0 \le k < r - l\).
--
-- @since 1.1.0.0
{-# INLINE goDown #-}
goDown :: RawWaveletMatrix -> Int -> Int -> Int -> (Int, Int, Int, Int)
goDown RawWaveletMatrix {..} l_ r_ k_ = V.ifoldl' step (0 :: Int, l_, r_, k_) bitsRwm
  where
    -- It's binary search over the value range. In each row, we'll focus on either 0 bit values or
    -- 1 bit values in [l, r) and update the range to [l', r').
    step :: (Int, Int, Int, Int) -> Int -> BV.BitVector -> (Int, Int, Int, Int)
    step (!acc, !l, !r, !k) !iRow !bits
      -- `r0 - l0`, the number of zeros in [l, r), is strictly greater than k, so the target has a
      -- 0 bit in this row: go left.
      | k < r0 - l0 = (acc, l0, r0, k)
      -- Otherwise the target has a 1 bit in this row: go right.
      | otherwise =
          let !acc' = acc .|. bit (heightRwm - 1 - iRow)
              -- In the next row, every element with a 0 bit here precedes every 1 bit element.
              !nZeros = VG.unsafeIndex nZerosRwm iRow
              -- nZeros + (the number of ones in [0, l))
              !l' = l + nZeros - l0
              -- nZeros + (the number of ones in [0, r))
              !r' = r + nZeros - r0
              -- skip the `r0 - l0` elements that went left
              !k' = k - (r0 - l0)
           in (acc', l', r', k')
      where
        -- the number of zeros in [0, l) and [0, r), respectively
        !l0 = BV.rank0 bits l
        !r0 = BV.rank0 bits r

-- | \(O(\log |A|)\) Goes up the wavelet matrix for collecting the value \(x\). Starting from
-- position @i0@ after the last row, it walks the rows bottom-up and returns the corresponding index
-- in the original array, or `Nothing` if a select query fails on the way.
--
-- @since 1.1.0.0
{-# INLINE goUp #-}
goUp :: RawWaveletMatrix -> Int -> Int -> Maybe Int
goUp RawWaveletMatrix {..} i0 x = V.ifoldM' climb i0 rowsBottomUp
  where
    -- After reversing, index @iBit@ holds original row @heightRwm - 1 - iBit@, which is the row
    -- storing bit @iBit@.
    rowsBottomUp = V.reverse bitsRwm
    climb :: Int -> Int -> BV.BitVector -> Maybe Int
    climb !i !iBit !row
      | testBit x iBit = BV.select1 row (i - nZerosRwm VG.! (heightRwm - 1 - iBit))
      | otherwise = BV.select0 row i

-- | \(O(\log |S|)\) Returns the number of \(y\) in \([l, r) \times [0, y_0)\).
--
-- The index interval is clamped to \([0, n)\). An empty interval, including \(l \ge r\), yields
-- \(0\).
--
-- @since 1.1.0.0
{-# INLINE rankLT #-}
rankLT :: RawWaveletMatrix -> Int -> Int -> Int -> Int
rankLT RawWaveletMatrix {..} l_ r_ xr
  -- Checked first, so that an empty or reversed index interval never yields a negative count.
  | r'_ <= l'_ = 0
  | xr <= 0 = 0
  -- REMARK: This is required. The fold below cannot handle the case N = 2^i and xr = N.
  | xr >= bit heightRwm = r'_ - l'_
  | otherwise =
      let (!res, !_, !_) = V.ifoldl' step (0, l'_, r'_) bitsRwm
       in res
  where
    -- clamp
    l'_ :: Int
    l'_ = max 0 l_
    r'_ :: Int
    r'_ = min lengthRwm r_
    -- [l, r)
    step :: (Int, Int, Int) -> Int -> BV.BitVector -> (Int, Int, Int)
    step (!acc, !l, !r) !iRow !bits =
      let !b = testBit xr (heightRwm - 1 - iRow)
          !l0 = BV.rank0 bits l
          !r0 = BV.rank0 bits r
          !nZeros = VG.unsafeIndex nZerosRwm iRow
       in -- When xr has a 1 bit here, every element with a 0 bit in this row is smaller than xr:
          -- count them and continue with the 1 bit elements.
          if b
            then (acc + r0 - l0, l - l0 + nZeros, r - r0 + nZeros)
            else (acc, l0, r0)

-- | \(O(\log |S|)\) Returns the number of occurrences of \(y\) in \([l, r)\).
--
-- @since 1.1.0.0
{-# INLINE rank #-}
rank ::
  RawWaveletMatrix ->
  -- | \(l\)
  Int ->
  -- | \(r\)
  Int ->
  -- | \(y\)
  Int ->
  -- | The number of \(y\) in \([l, r)\).
  Int
rank wm l r y = rankBetween wm l r y yEnd
  where
    -- The half-open value interval [y, y + 1) contains exactly y.
    yEnd = y + 1

-- | \(O(\log |S|)\) Returns the number of \(y\) in \([l, r) \times [y_1, y_2)\). Returns \(0\)
-- when \(y_1 \ge y_2\), since the value interval is empty.
--
-- @since 1.1.0.0
{-# INLINE rankBetween #-}
rankBetween ::
  RawWaveletMatrix ->
  -- | \(l\)
  Int ->
  -- | \(r\)
  Int ->
  -- | \(y_1\)
  Int ->
  -- | \(y_2\)
  Int ->
  -- | The number of \(y\) in \([l, r) \times [y_1, y_2)\).
  Int
rankBetween wm l r lx rx
  -- An empty value interval contains nothing. Without this guard, the difference below would be
  -- negative when y1 > y2.
  | rx <= lx = 0
  | otherwise = rankLT wm l r rx - rankLT wm l r lx

-- | \(O(\log |S|)\) Returns the index of the first occurrence of \(y\) in the whole sequence, or
-- `Nothing` if \(y\) does not occur.
--
-- @since 1.1.0.0
{-# INLINE select #-}
select :: RawWaveletMatrix -> Int -> Maybe Int
select wm y = selectKth wm 0 y

-- | \(O(\log |S|)\) Returns the index of the \(k\)-th (0-based) occurrence of \(y\) in the whole
-- sequence, or `Nothing` if there is no such occurrence.
--
-- @since 1.1.0.0
{-# INLINE selectKth #-}
selectKth ::
  RawWaveletMatrix ->
  -- | \(k\)
  Int ->
  -- | \(y\)
  Int ->
  -- | The index of \(k\)-th \(y\)
  Maybe Int
selectKth wm@RawWaveletMatrix {lengthRwm = n} k y = selectKthIn wm 0 n k y

-- | \(O(\log |S|)\) Given an interval \([l, r)\), it returns the index of the first occurrence of
-- \(y\) in \([l, r)\), or `Nothing` if no such occurrence exists.
--
-- @since 1.1.0.0
{-# INLINE selectIn #-}
selectIn ::
  -- | A wavelet matrix
  RawWaveletMatrix ->
  -- | \(l\)
  Int ->
  -- | \(r\)
  Int ->
  -- | \(y\)
  Int ->
  -- | The index of the first \(y\) in \([l, r)\).
  Maybe Int
-- The first occurrence is the 0-th one. `selectKthIn` takes @l r k y@, so @k = 0@ goes after
-- @l@ and @r@ (not before them).
selectIn wm l r = selectKthIn wm l r 0

-- | \(O(\log |S|)\) Given an interval \([l, r)\), it returns the index of the \(k\)-th occurrence
-- (0-based) of \(y\) in \([l, r)\), or `Nothing` if no such occurrence exists.
--
-- The interval is clamped to \([0, n)\). Values outside the representable range \([0, 2^h)\),
-- where \(h\) is `heightRwm`, are never found.
--
-- @since 1.1.0.0
{-# INLINE selectKthIn #-}
selectKthIn ::
  RawWaveletMatrix ->
  -- | \(l\)
  Int ->
  -- | \(r\)
  Int ->
  -- | \(k\)
  Int ->
  -- | \(y\)
  Int ->
  -- | The index of the \(k\)-th \(y\) in \([l, r)\).
  Maybe Int
selectKthIn wm@RawWaveletMatrix {..} l_ r_ k x
  -- `x` is a value, so it is bounded by the value range, not by the array length. Only the lowest
  -- `heightRwm` bits are stored, so a larger `x` would alias a smaller value.
  | not (0 <= x && x < bit heightRwm && 0 <= k && k < lengthRwm) = Nothing
  | l'_ < r'_ = inner
  | otherwise = Nothing
  where
    -- clamp
    l'_ :: Int
    l'_ = max 0 l_
    r'_ :: Int
    r'_ = min lengthRwm r_
    inner :: Maybe Int
    inner
      -- fewer than k + 1 occurrences of x in [l, r)
      | rEnd <= lEnd + k = Nothing
      -- go up from the k-th occurrence in the last array
      | otherwise = goUp wm (lEnd + k) x
      where
        -- TODO: replace with goDown
        -- Go down. Gets the [l, r) range of @x@ in the last array.
        (!lEnd, !rEnd) =
          V.ifoldl'
            ( \(!l, !r) !iRow !bits ->
                let !l0 = BV.rank0 bits l
                    !r0 = BV.rank0 bits r
                    nZeros = nZerosRwm VG.! iRow
                 in if testBit x (heightRwm - 1 - iRow)
                      then (l + nZeros - l0, r + nZeros - r0)
                      else (l0, r0)
            )
            (l'_, r'_)
            bitsRwm

-- | \(O(\log |S|)\) Given an interval \([l, r)\), it returns the \(k\)-th (0-based) largest value,
-- or `Nothing` if \(k\) or the interval is out of bounds. Note that duplicated values are counted as
-- distinct occurrences.
--
-- @since 1.1.0.0
{-# INLINE kthLargestIn #-}
kthLargestIn ::
  -- | A wavelet matrix
  RawWaveletMatrix ->
  -- | \(l\)
  Int ->
  -- | \(r\)
  Int ->
  -- | \(k\)
  Int ->
  -- | \(k\)-th largest \(y\) in \([l, r)\)
  Maybe Int
kthLargestIn wm l r k
  | k < 0 || k >= r - l = Nothing
  | 0 <= l && l < r && r <= lengthRwm wm = Just $ unsafeKthLargestIn wm l r k
  | otherwise = Nothing

-- | \(O(\log |S|)\) Given an interval \([l, r)\), it returns both the index and the value of the
-- \(k\)-th (0-based) largest value, or `Nothing` if \(k\) or the interval is out of bounds. Note that
-- duplicated values are counted as distinct occurrences.
--
-- @since 1.1.0.0
{-# INLINE ikthLargestIn #-}
ikthLargestIn ::
  -- | A wavelet matrix
  RawWaveletMatrix ->
  -- | \(l\)
  Int ->
  -- | \(r\)
  Int ->
  -- | \(k\)
  Int ->
  -- | \((i, y)\) for \(k\)-th largest \(y\) in \([l, r)\)
  Maybe (Int, Int)
ikthLargestIn wm l r k
  | validRank && validInterval = Just $ unsafeIKthLargestIn wm l r k
  | otherwise = Nothing
  where
    -- k must address an element of [l, r)
    validRank = 0 <= k && k < r - l
    -- [l, r) must be a non-empty sub-interval of [0, n)
    validInterval = 0 <= l && l < r && r <= lengthRwm wm

-- | \(O(\log |S|)\) Given an interval \([l, r)\), it returns the \(k\)-th (0-based) smallest
-- value, or `Nothing` if \(k\) or the interval is out of bounds. Note that duplicated values are
-- counted as distinct occurrences.
--
-- @since 1.1.0.0
{-# INLINE kthSmallestIn #-}
kthSmallestIn ::
  -- | A wavelet matrix
  RawWaveletMatrix ->
  -- | \(l\)
  Int ->
  -- | \(r\)
  Int ->
  -- | \(k\)
  Int ->
  -- | \(k\)-th smallest \(y\) in \([l, r)\)
  Maybe Int
kthSmallestIn wm l r k
  | k < 0 || k >= r - l = Nothing
  | 0 <= l && l < r && r <= lengthRwm wm = Just $ unsafeKthSmallestIn wm l r k
  | otherwise = Nothing

-- | \(O(\log |S|)\) Given an interval \([l, r)\), returns both the index and the value of the
-- \(k\)-th (0-based) smallest value. Note that duplicated values are counted as distinct
-- occurrences.
--
-- Returns `Nothing` when \(k \notin [0, r - l)\) or when \([l, r)\) is not a non-empty
-- sub-interval of \([0, n)\) (the interval is not clamped).
--
-- @since 1.1.0.0
{-# INLINE ikthSmallestIn #-}
ikthSmallestIn ::
  -- | A wavelet matrix
  RawWaveletMatrix ->
  -- | \(l\)
  Int ->
  -- | \(r\)
  Int ->
  -- | \(k\)
  Int ->
  -- | \((i, y)\) for \(k\)-th smallest \(y\) in \([l, r)\)
  Maybe (Int, Int)
ikthSmallestIn wm l r k
  | k < 0 || k >= r - l = Nothing
  | 0 <= l && l < r && r <= lengthRwm wm = Just $ unsafeIKthSmallestIn wm l r k
  | otherwise = Nothing

-- | \(O(\log a)\) Returns the \(k\)-th (0-based) largest value in \([l, r)\), counting duplicated
-- values as distinct occurrences. The \(k\)-th largest of the \(r - l\) values is their
-- \((r - l - 1 - k)\)-th smallest, so this simply delegates to `unsafeKthSmallestIn`. Arguments
-- are not validated here; see `kthLargestIn` for the checked variant.
--
-- @since 1.1.0.0
{-# INLINE unsafeKthLargestIn #-}
unsafeKthLargestIn :: RawWaveletMatrix -> Int -> Int -> Int -> Int
unsafeKthLargestIn wm l r k = unsafeKthSmallestIn wm l r kFromBottom
  where
    -- rank of the same element when counted from the smallest side
    kFromBottom = r - l - 1 - k

-- | \(O(\log a)\) Returns the index and the value of the \(k\)-th (0-based) largest value in
-- \([l, r)\), counting duplicated values as distinct occurrences. Implemented as the
-- \((r - l - 1 - k)\)-th smallest via `unsafeIKthSmallestIn`. Arguments are not validated here;
-- see `ikthLargestIn` for the checked variant.
--
-- @since 1.1.0.0
{-# INLINE unsafeIKthLargestIn #-}
unsafeIKthLargestIn :: RawWaveletMatrix -> Int -> Int -> Int -> (Int, Int)
unsafeIKthLargestIn wm l r k = unsafeIKthSmallestIn wm l r kFromBottom
  where
    -- rank of the same element when counted from the smallest side
    kFromBottom = r - l - 1 - k

-- | \(O(\log a)\) Returns the \(k\)-th (0-based) smallest value in \([l, r)\), counting duplicated
-- values as distinct occurrences. Only the value component of `goDown`'s result is used.
-- Arguments are not validated here; see `kthSmallestIn` for the checked variant.
--
-- @since 1.1.0.0
{-# INLINE unsafeKthSmallestIn #-}
unsafeKthSmallestIn :: RawWaveletMatrix -> Int -> Int -> Int -> Int
unsafeKthSmallestIn wm l_ r_ k_ = case goDown wm l_ r_ k_ of
  (!y, !_, !_, !_) -> y

-- | \(O(\log a)\) Returns \((i, y)\): the \(k\)-th (0-based) smallest value \(y\) in \([l, r)\)
-- together with its index \(i\), counting duplicated values as distinct occurrences.
-- Arguments are not validated here; see `ikthSmallestIn` for the checked variant.
--
-- Throws an `error` (with a descriptive message) if the index cannot be recovered, which is
-- expected only for invalid arguments.
--
-- @since 1.1.0.0
{-# INLINE unsafeIKthSmallestIn #-}
unsafeIKthSmallestIn :: RawWaveletMatrix -> Int -> Int -> Int -> (Int, Int)
unsafeIKthSmallestIn wm l_ r_ k_ =
  let (!x, !l, !_, !k) = goDown wm l_ r_ k_
      -- `goDown` yields the value plus the final interval start `l` and residual rank `k`;
      -- `l + k` (presumably the answer's position in the last row) is traced back by `goUp`.
      !i' = case goUp wm (l + k) x of
        Just i -> i
        Nothing -> error "AtCoder.Extra.WaveletMatrix.Raw.unsafeIKthSmallestIn: failed to recover the index (invalid arguments?)"
   in (i', x)

-- | \(O(\log |S|)\) Looks up the maximum \(y\) in \([l, r) \times (-\infty, y_0]\).
--
-- The index interval is clamped to \([0, n)\); an empty (or inverted) interval after clamping
-- yields `Nothing`.
--
-- @since 1.1.0.0
{-# INLINE lookupLE #-}
lookupLE ::
  -- | A wavelet matrix
  RawWaveletMatrix ->
  -- | \(l\)
  Int ->
  -- | \(r\)
  Int ->
  -- | \(y_0\)
  Int ->
  -- | Maximum \(y\) in \([l, r) \times (-\infty, y_0]\)
  Maybe Int
lookupLE wm l r x
  -- `<=` (not `==`) so that intervals that became inverted by clamping (e.g. `l >= n`) are empty
  | r' <= l' = Nothing
  | nLE == 0 = Nothing
  | otherwise = Just $ unsafeKthSmallestIn wm l' r' (nLE - 1)
  where
    -- clamp
    l' = max 0 l
    r' = min (lengthRwm wm) r
    -- Number of values y <= x in the clamped interval. `x + 1` would wrap around to `minBound`
    -- when `x == maxBound`; in that case every value qualifies.
    nLE
      | x == maxBound = r' - l'
      | otherwise = rankBetween wm l' r' minBound (x + 1)

-- | \(O(\log |S|)\) Looks up the maximum \(y\) in \([l, r) \times (-\infty, y_0)\).
--
-- The index interval is clamped to \([0, n)\), as in `lookupLE`.
--
-- @since 1.1.0.0
{-# INLINE lookupLT #-}
lookupLT ::
  -- | A wavelet matrix
  RawWaveletMatrix ->
  -- | \(l\)
  Int ->
  -- | \(r\)
  Int ->
  -- | \(y_0\)
  Int ->
  -- | Maximum \(y\) in \([l, r) \times (-\infty, y_0)\)
  Maybe Int
lookupLT wm l r x
  -- nothing is strictly below `minBound`; also avoids `x - 1` wrapping around to `maxBound`
  | x == minBound = Nothing
  | otherwise = lookupLE wm l r (x - 1)

-- | \(O(\log |S|)\) Looks up the minimum \(y\) in \([l, r) \times [y_0, \infty)\).
--
-- The index interval is clamped to \([0, n)\); an empty (or inverted) interval after clamping
-- yields `Nothing`.
--
-- @since 1.1.0.0
{-# INLINE lookupGE #-}
lookupGE ::
  -- | A wavelet matrix
  RawWaveletMatrix ->
  -- | \(l\)
  Int ->
  -- | \(r\)
  Int ->
  -- | \(y_0\)
  Int ->
  -- | Minimum \(y\) in \([l, r) \times [y_0, \infty)\).
  Maybe Int
lookupGE wm l r x
  -- `<=` (not `==`) so that intervals that became inverted by clamping (e.g. `l >= n`) are empty
  | r' <= l' = Nothing
  | nLT >= r' - l' = Nothing
  | otherwise = Just $ unsafeKthSmallestIn wm l' r' nLT
  where
    -- clamp
    l' = max 0 l
    r' = min (lengthRwm wm) r
    -- number of values y < x in the clamped interval; the answer is the nLT-th (0-based) smallest
    nLT = rankBetween wm l' r' minBound x

-- | \(O(\log |S|)\) Looks up the minimum \(y\) in \([l, r) \times (y_0, \infty)\).
--
-- The index interval is clamped to \([0, n)\), as in `lookupGE`.
--
-- @since 1.1.0.0
{-# INLINE lookupGT #-}
lookupGT ::
  -- | A wavelet matrix
  RawWaveletMatrix ->
  -- | \(l\)
  Int ->
  -- | \(r\)
  Int ->
  -- | \(y_0\)
  Int ->
  -- | Minimum \(y\) in \([l, r) \times (y_0, \infty)\)
  Maybe Int
lookupGT wm l r x
  -- nothing is strictly above `maxBound`; also avoids `x + 1` wrapping around to `minBound`
  | x == maxBound = Nothing
  | otherwise = lookupGE wm l r (x + 1)

-- | \(O(\min(|S|, L) \log |S|)\) Collects \((y, \mathrm{rank}(y))\) over the index range
-- \([l, r)\), sorted by ascending \(y\), where \(\mathrm{rank}(y)\) is the number of occurrences
-- of \(y\) in the range. This is only fast when \(|S|\) is very small.
--
-- @since 1.1.0.0
{-# INLINE assocsIn #-}
assocsIn :: RawWaveletMatrix -> Int -> Int -> [(Int, Int)]
assocsIn wm lo hi = assocsWith wm lo hi id

-- | \(O(\min(|S|, L) \log |S|)\) Internal implementation of `assocsIn`. Collects
-- \((f(y), c_y)\) in ascending order of \(y\), where \(c_y\) is the number of occurrences of
-- \(y\) in \([l, r)\). The index interval is clamped to \([0, n)\).
--
-- @since 1.1.0.0
{-# INLINE assocsWith #-}
assocsWith :: RawWaveletMatrix -> Int -> Int -> (Int -> Int) -> [(Int, Int)]
assocsWith RawWaveletMatrix {..} l_ r_ f
  | l'_ < r'_ = inner (0 :: Int) (0 :: Int) l'_ r'_ []
  | otherwise = []
  where
    -- clamp
    l'_ = max 0 l_
    r'_ = min lengthRwm r_
    -- DFS over the interval [l, r) of row `iRow`. `acc` holds the bits of y fixed so far (MSB
    -- first). Results are consed onto `res`, so the subtree of bigger values (bit 1) is visited
    -- first and ends up at the tail: the final list is in ascending order of y.
    inner :: Int -> Int -> Int -> Int -> [(Int, Int)] -> [(Int, Int)]
    inner !acc iRow !l !r res
      | iRow >= heightRwm =
          -- leaf: every position in [l, r) holds the same value
          let !n = r - l
              !acc' = f acc
           in (acc', n) : res
      | otherwise =
          let !bits = bitsRwm VG.! iRow
              !l0 = BV.rank0 bits l
              !r0 = BV.rank0 bits r
              !nZeros = nZerosRwm VG.! iRow
              -- go right first (bigger values; they end up later in the list)
              !l' = l + nZeros - l0
              !r' = r + nZeros - r0
              !res'
                | l' < r' = inner (acc .|. bit (heightRwm - 1 - iRow)) (iRow + 1) l' r' res
                | otherwise = res
              -- then go left (smaller values; they are consed in front)
              !res''
                | l0 < r0 = inner acc (iRow + 1) l0 r0 res'
                | otherwise = res'
           in res''

-- | \(O(\min(|S|, L) \log |S|)\) Collects \((y, \mathrm{rank}(y))\) over the index range
-- \([l, r)\), sorted by descending \(y\), where \(\mathrm{rank}(y)\) is the number of
-- occurrences of \(y\) in the range. This is only fast when \(|S|\) is very small.
--
-- @since 1.1.0.0
{-# INLINE descAssocsIn #-}
descAssocsIn :: RawWaveletMatrix -> Int -> Int -> [(Int, Int)]
descAssocsIn wm lo hi = descAssocsInWith wm lo hi id

-- | \(O(\min(|S|, L) \log |S|)\) Internal implementation of `descAssocsIn`. Collects
-- \((f(y), c_y)\) in descending order of \(y\), where \(c_y\) is the number of occurrences of
-- \(y\) in \([l, r)\). The index interval is clamped to \([0, n)\).
--
-- @since 1.1.0.0
{-# INLINE descAssocsInWith #-}
descAssocsInWith :: RawWaveletMatrix -> Int -> Int -> (Int -> Int) -> [(Int, Int)]
descAssocsInWith RawWaveletMatrix {..} l_ r_ f
  | l'_ < r'_ = inner (0 :: Int) (0 :: Int) l'_ r'_ []
  | otherwise = []
  where
    -- clamp
    l'_ = max 0 l_
    r'_ = min lengthRwm r_
    -- DFS over the interval [l, r) of row `iRow`. `acc` holds the bits of y fixed so far (MSB
    -- first). Results are consed onto `res`, so the subtree of smaller values (bit 0) is visited
    -- first and ends up at the tail: the final list is in descending order of y.
    inner :: Int -> Int -> Int -> Int -> [(Int, Int)] -> [(Int, Int)]
    inner !acc iRow !l !r res
      | iRow >= heightRwm =
          -- leaf: every position in [l, r) holds the same value
          let !n = r - l
              !acc' = f acc
           in (acc', n) : res
      | otherwise =
          let !bits = bitsRwm VG.! iRow
              !l0 = BV.rank0 bits l
              !r0 = BV.rank0 bits r
              !nZeros = nZerosRwm VG.! iRow
              -- go left first (smaller values; they end up later in the list)
              !res'
                | l0 < r0 = inner acc (iRow + 1) l0 r0 res
                | otherwise = res
              -- then go right (bigger values; they are consed in front)
              !l' = l + nZeros - l0
              !r' = r + nZeros - r0
              !res''
                | l' < r' = inner (acc .|. bit (heightRwm - 1 - iRow)) (iRow + 1) l' r' res'
                | otherwise = res'
           in res''