{-# LANGUAGE RecordWildCards #-}

-- | A compact bit vector, a collection of bits that can process @rank@ in \(O(1)\) and @select@ in
-- \(O(\log n)\).
--
-- @since 1.1.0.0
module AtCoder.Extra.WaveletMatrix.BitVector
  ( -- * Bit vector
    BitVector (..),

    -- * Constructor
    build,

    -- * (Internal) Word-based cumulative sum
    wordSize,
    csumInPlace,

    -- * Rank
    rank0,
    rank1,

    -- * Select
    select0,
    select1,
    selectKthIn0,
    selectKthIn1,
  )
where

import AtCoder.Extra.Bisect (bisectL)
import Control.Monad.Primitive (PrimMonad (PrimState))
import Data.Bit (Bit (..))
import Data.Bits (popCount)
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM

-- | A compact bit vector: the raw bits together with a per-block cumulative popcount table, which
-- lets `rank1` and `rank0` run in \(O(1)\).
--
-- When constructed with `build`, @csumBv@ has length \(\lfloor (n + 63) / 64 \rfloor + 1\), where
-- \(n\) is the length of @bitsBv@. Its \(j\)-th element is the number of \(1\) bits in the first
-- \(64 j\) bits of @bitsBv@.
--
-- @since 1.1.0.0
data BitVector = BitVector
  { -- | Packed bits.
    --
    -- @since 1.1.0.0
    bitsBv :: !(VU.Vector Bit),
    -- | Cumulative popcount over consecutive 64-bit blocks of `bitsBv`: element \(j\) is the
    -- number of \(1\) bits in the first \(j\) blocks, so element \(0\) is always \(0\).
    --
    -- @since 1.1.0.0
    csumBv :: !(VU.Vector Int)
    -- We could use Word32 for csumBv, as 2^32 is large enough.
  }
  deriving (Eq, Show)

-- | \(O(n)\) Creates a `BitVector` from raw bits, precomputing the cumulative popcount table that
-- `rank1` looks up.
--
-- @since 1.1.0.0
{-# INLINE build #-}
build :: VU.Vector Bit -> BitVector
build bits = BitVector bits table
  where
    -- Number of (possibly partial) 64-bit blocks covering the input.
    nBlocks = (VU.length bits + 63) `div` 64
    -- One entry per block plus the leading zero written by `csumInPlace`.
    table = VU.create $ do
      buf <- VUM.replicate (nBlocks + 1) (0 :: Int)
      _total <- csumInPlace buf bits
      pure buf

-- | Number of bits per block (\(64\), i.e. one 64-bit word) used by the internal cumulative sum
-- of the bit vector.
--
-- @since 1.1.0.0
{-# INLINE wordSize #-}
wordSize :: Int
wordSize =
  64

-- | \(O(n)\) Fills a cumulative popcount table for the given bits in place and returns the total
-- number of \(1\) bits covered by the table.
--
-- The bits are split into consecutive blocks of `wordSize` bits. After the call, index \(0\) of
-- the table holds \(0\) and index \(j + 1\) holds the number of \(1\) bits in blocks
-- \(0, \dots, j\).
--
-- The table must be non-empty. A table of length \(\lceil |\mathrm{bits}| / 64 \rceil + 1\) (as
-- allocated by `build`) covers every block exactly. A shorter table silently leaves the trailing
-- blocks out of both the table and the returned sum. Extra entries in a longer table are filled
-- with the total.
--
-- @since 1.1.0.0
{-# INLINE csumInPlace #-}
csumInPlace ::
  (PrimMonad m) =>
  -- | Cumulative sum table of length \(\lceil |\mathrm{bits}| / 64 \rceil + 1\); overwritten.
  VUM.MVector (PrimState m) Int ->
  -- | Bits
  VU.Vector Bit ->
  -- | Sum of the bits
  m Int
csumInPlace csum bits = do
  -- The leading zero makes entry j the count of ones in the first j blocks. This write is
  -- bounds-checked so that an empty table fails loudly instead of corrupting memory; the
  -- remaining writes stay in [1, length csum) by construction.
  VGM.write csum 0 (0 :: Int)

  -- Calculate popCount block by block. TODO: use `castToWords` for most elements
  VU.ifoldM'
    ( \ !acc i blockSum -> do
        let !acc' = acc + blockSum
        VGM.unsafeWrite csum (i + 1) acc'
        pure acc'
    )
    (0 :: Int)
    $ VU.unfoldrExactN
      (VGM.length csum - 1)
      (\rest -> (popCount (VU.take wordSize rest), VU.drop wordSize rest))
      bits

-- | \(O(1)\) Counts the number of \(0\) bits in the interval \([0, i)\).
--
-- For \(0 \le i \le n\), every position in \([0, i)\) holds either a \(0\) or a \(1\), so this is
-- \(i\) minus the number of \(1\) bits in the same interval.
--
-- @since 1.1.0.0
{-# INLINE rank0 #-}
rank0 :: BitVector -> Int -> Int
rank0 bv i = i - ones
  where
    ones = rank1 bv i

-- | \(O(1)\) Counts the number of \(1\) bits in the interval \([0, i)\).
--
-- The result is the precomputed count for the whole 64-bit blocks before @i@, plus a popcount of
-- the leading part of the next block. The table lookup is unchecked, so @i@ is expected to lie in
-- \([0, n]\).
--
-- @since 1.1.0.0
{-# INLINE rank1 #-}
rank1 :: BitVector -> Int -> Int
rank1 bv i = wholeBlocks + partialBlock
  where
    -- TODO: check bounds for i?
    (!blockIx, !offset) = i `divMod` wordSize
    -- Ones in blocks [0, blockIx), read from the cumulative table.
    wholeBlocks = VG.unsafeIndex (csumBv bv) blockIx
    -- Ones in the first `offset` bits of block `blockIx`.
    partialBlock = popCount $ VU.take offset $ VU.drop (blockIx * wordSize) $ bitsBv bv

-- | \(O(\log n)\) Returns the index of the \(k\)-th \(0\) bit (0-based) in the whole vector, or
-- `Nothing` if no such bit exists (including when \(k < 0\)).
--
-- @since 1.1.0.0
{-# INLINE select0 #-}
select0 :: BitVector -> Int -> Maybe Int
select0 bv k = selectKthIn0 bv 0 n k
  where
    n = VU.length (bitsBv bv)

-- | \(O(\log n)\) Returns the index of the \(k\)-th \(1\) bit (0-based) in the whole vector, or
-- `Nothing` if no such bit exists (including when \(k < 0\)).
--
-- @since 1.1.0.0
{-# INLINE select1 #-}
select1 :: BitVector -> Int -> Maybe Int
select1 bv k = selectKthIn1 bv 0 n k
  where
    n = VU.length (bitsBv bv)

-- | \(O(\log n)\) Returns the index of the \(k\)-th \(0\) bit (0-based) within \([l, r)\), or
-- `Nothing` when \(k\) is negative or the interval holds at most \(k\) zeros.
--
-- @since 1.1.0.0
{-# INLINE selectKthIn0 #-}
selectKthIn0 ::
  -- | A bit vector
  BitVector ->
  -- | \(l\)
  Int ->
  -- | \(r\)
  Int ->
  -- | \(k\)
  Int ->
  -- | The index of \(k\)-th \(0\) in \([l, r)\)
  Maybe Int
selectKthIn0 bv l r k
  | k >= 0 && k < total = bisectL l r $ \i -> zerosSince i <= k
  | otherwise = Nothing
  where
    -- Rank at the left end, shared by every probe of the bisection.
    base = rank0 bv l
    -- Number of 0 bits in [l, i). The predicate above is monotone in i. NOTE(review): this
    -- relies on `bisectL` returning the last index in [l, r) where it holds; confirm against
    -- AtCoder.Extra.Bisect.
    zerosSince i = rank0 bv i - base
    total = zerosSince r

-- | \(O(\log n)\) Returns the index of the \(k\)-th \(1\) bit (0-based) within \([l, r)\), or
-- `Nothing` when \(k\) is negative or the interval holds at most \(k\) ones.
--
-- @since 1.1.0.0
{-# INLINE selectKthIn1 #-}
selectKthIn1 ::
  -- | A bit vector
  BitVector ->
  -- | \(l\)
  Int ->
  -- | \(r\)
  Int ->
  -- | \(k\)
  Int ->
  -- | The index of \(k\)-th \(1\) in \([l, r)\)
  Maybe Int
selectKthIn1 bv l r k
  | k >= 0 && k < total = bisectL l r $ \i -> onesSince i <= k
  | otherwise = Nothing
  where
    -- Rank at the left end, shared by every probe of the bisection.
    base = rank1 bv l
    -- Number of 1 bits in [l, i). The predicate above is monotone in i. NOTE(review): this
    -- relies on `bisectL` returning the last index in [l, r) where it holds; confirm against
    -- AtCoder.Extra.Bisect.
    onesSince i = rank1 bv i - base
    total = onesSince r