{-# LANGUAGE RecordWildCards #-}

-- | Minimum binary heap. Mutable and fixed-sized.
--
-- <https://en.wikipedia.org/wiki/Binary_heap>
--
-- ==== __Example__
-- >>> import AtCoder.Internal.MinHeap qualified as MH
-- >>> heap <- MH.new @Int 4
-- >>> MH.capacity heap
-- 4
--
-- >>> MH.push heap 10
-- >>> MH.push heap 0
-- >>> MH.push heap 5
-- >>> MH.length heap -- [0, 5, 10]
-- 3
--
-- >>> MH.pop heap    -- [5, 10]
-- Just 0
--
-- >>> MH.peek heap   -- [5, 10]
-- Just 5
--
-- >>> MH.pop heap    -- [10]
-- Just 5
--
-- >>> MH.clear heap  -- []
-- >>> MH.null heap
-- True
--
-- @since 1.0.0.0
module AtCoder.Internal.MinHeap
  ( -- * Heap
    Heap,

    -- * Constructors
    new,

    -- * Metadata
    capacity,
    length,
    null,

    -- * Reset
    clear,

    -- * Push/pop/peek
    push,
    pop,
    pop_,
    peek,
  )
where

import Control.Monad (when)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)
import Prelude hiding (length, null)

-- | Minimum binary heap. Mutable and fixed-sized.
--
-- Elements are stored in a flat array, level by level, with zero-based indices:
--
-- @
--     0
--   1   2
--  3 4 5 6
-- @
--
-- The children of index \(i\) are \(2i + 1\) and \(2i + 2\); its parent is
-- \(\lfloor (i - 1) / 2 \rfloor\).
--
-- INVARIANT (min heap): every child value is greater than or equal to its parent's value.
--
-- @since 1.0.0.0
data Heap s a = Heap
  { -- | One-cell vector holding the current number of elements (the heap's length).
    sizeBH_ :: !(VUM.MVector s Int),
    -- | Element storage. Only the prefix @[0, length)@ is live; the vector's own length is the
    -- heap's capacity.
    dataBH :: !(VUM.MVector s a)
  }

-- | \(O(n)\) Creates an empty `Heap` that can hold at most \(n\) elements.
--
-- The storage is allocated without initialization; a slot is only read after `push` has written
-- it.
--
-- @since 1.0.0.0
{-# INLINE new #-}
new :: (VU.Unbox a, PrimMonad m) => Int -> m (Heap (PrimState m) a)
new n =
  -- Allocate the length cell (starting at zero) first, then the element storage.
  Heap <$> VUM.replicate 1 0 <*> VUM.unsafeNew n

-- | \(O(1)\) Returns the maximum number of elements the heap can hold, which is the length of its
-- backing storage as allocated by `new`.
--
-- @since 1.0.0.0
{-# INLINE capacity #-}
capacity :: (VU.Unbox a) => Heap s a -> Int
capacity heap = VUM.length (dataBH heap)

-- | \(O(1)\) Returns the number of elements currently stored in the heap.
--
-- @since 1.0.0.0
{-# INLINE length #-}
length :: (VU.Unbox a, PrimMonad m) => Heap (PrimState m) a -> m Int
length heap =
  -- The length cell always has exactly one slot, so the unchecked read at 0 is in bounds.
  VGM.unsafeRead (sizeBH_ heap) 0

-- | \(O(1)\) Returns `True` if the heap contains no elements.
--
-- @since 1.0.0.0
{-# INLINE null #-}
null :: (VU.Unbox a, PrimMonad m) => Heap (PrimState m) a -> m Bool
null heap = do
  size <- length heap
  pure (size == 0)

-- | \(O(1)\) Empties the heap by resetting its `length` to zero.
--
-- The capacity is unchanged and the stored values are left in place; later `push`es overwrite
-- them.
--
-- @since 1.0.0.0
{-# INLINE clear #-}
clear :: (VU.Unbox a, PrimMonad m) => Heap (PrimState m) a -> m ()
clear heap = VGM.unsafeWrite (sizeBH_ heap) 0 0

-- | \(O(\log n)\) Inserts an element into the heap.
--
-- The element is placed just after the current last element and then swapped upwards while it is
-- strictly smaller than its parent. The heap must not be full: the placement uses the
-- bounds-checked `VGM.write`, which fails on a full heap before the stored length is changed.
--
-- @since 1.0.0.0
{-# INLINE push #-}
push :: (HasCallStack, Ord a, VU.Unbox a, PrimMonad m) => Heap (PrimState m) a -> a -> m ()
push Heap {..} x = do
  len <- VGM.unsafeRead sizeBH_ 0
  VGM.write dataBH len x
  VGM.unsafeWrite sizeBH_ 0 (len + 1)
  bubbleUp len
  where
    -- Moves @x@, currently at index @i@, towards the root until its parent is not larger.
    bubbleUp i
      | i <= 0 = pure ()
      | otherwise = do
          let parent = (i - 1) `div` 2
          xParent <- VGM.read dataBH parent
          when (x < xParent) $ do
            VGM.swap dataBH parent i
            bubbleUp parent

-- | \(O(\log n)\) Removes the smallest element from the heap and returns it, or `Nothing` if the
-- heap is empty.
--
-- @since 1.0.0.0
{-# INLINE pop #-}
pop :: (HasCallStack, Ord a, VU.Unbox a, PrimMonad m) => Heap (PrimState m) a -> m (Maybe a)
pop heap@Heap {..} = do
  len <- length heap
  if len == 0
    then pure Nothing
    else do
      -- After shrinking, @n@ is both the new length and the index of the former last element.
      let n = len - 1
      VGM.unsafeWrite sizeBH_ 0 n
      root <- VGM.read dataBH 0
      -- Swap the root (the minimum) with the last element. The minimum is parked at index @n@, just
      -- outside the live range @[0, n)@, and the former last element now sits at the root.
      VGM.swap dataBH 0 n

      -- Moves the element at index @i@ downwards until neither live child is smaller than it.
      let siftDown i = do
            let il = 2 * i + 1
                ir = il + 1
            when (il < n) $ do
              x <- VGM.read dataBH i
              xl <- VGM.read dataBH il
              -- IMPORTANT: compare against the smaller child; on a tie, prefer the left one.
              (iChild, xChild) <-
                if ir < n
                  then do
                    xr <- VGM.read dataBH ir
                    pure $ if xl <= xr then (il, xl) else (ir, xr)
                  else pure (il, xl)
              -- Swap only while the child is strictly smaller; this restores the invariant.
              when (xChild < x) $ do
                VGM.swap dataBH i iChild
                siftDown iChild

      siftDown 0
      pure $ Just root

-- | \(O(\log n)\) Same as `pop`, but discards the removed value. Does nothing on an empty heap.
--
-- @since 1.0.0.0
{-# INLINE pop_ #-}
pop_ :: (HasCallStack, Ord a, VU.Unbox a, PrimMonad m) => Heap (PrimState m) a -> m ()
pop_ heap = () <$ pop heap

-- | \(O(1)\) Returns the smallest value in the heap without removing it, or `Nothing` if the heap
-- is empty.
--
-- @since 1.0.0.0
{-# INLINE peek #-}
peek :: (VU.Unbox a, PrimMonad m) => Heap (PrimState m) a -> m (Maybe a)
peek heap = do
  size <- length heap
  if size /= 0
    then -- By the heap invariant, the minimum is always at the root (index 0).
      Just <$> VGM.read (dataBH heap) 0
    else pure Nothing