-- |
-- Module:     Data.Vector.Algorithms.Heapsort
-- Copyright:  (c) Sergey Vinokurov 2023
-- License:    Apache-2.0 (see LICENSE)
-- Maintainer: serg.foo@gmail.com

module Data.Vector.Algorithms.Heapsort
  ( heapSort
  ) where

import Control.Monad.Primitive
import Data.Bits
import Data.Vector.Generic.Mutable qualified as GM

{-# INLINABLE shiftDown #-}
-- | Sink the element at the given index of a 0-based binary max-heap until
-- neither of its children is strictly greater than it.
--
-- The children of index @p@ live at @2 * p + 1@ and @2 * p + 2@; the heap
-- occupies the whole vector, i.e. indices @[0, 'GM.length' v)@.
--
-- The index must be non-negative: elements are accessed with
-- 'GM.unsafeRead' / 'GM.unsafeWrite', which perform no bounds checks.
-- An index whose first child lies outside of the vector is left untouched.
shiftDown :: (PrimMonad m, Ord a, GM.MVector v a) => v (PrimState m) a -> Int -> m ()
shiftDown !v = go
  where
    -- One past the last valid index; computed once for the whole descent.
    !end = GM.length v

    go !p
      | c1 < end
      = do
        let !c2 = c1 + 1
        c1Val <- GM.unsafeRead v c1
        -- Select the larger child. The second child exists only when
        -- @c2 < end@; on equal children the second one is selected.
        (maxIdx, maxVal) <-
          if c2 < end
          then do
            c2Val <- GM.unsafeRead v c2
            pure $ if c1Val > c2Val then (c1, c1Val) else (c2, c2Val)
          else pure (c1, c1Val)
        pVal <- GM.unsafeRead v p
        if maxVal > pVal
        then do
          -- Exchange parent with the larger child and keep sinking from
          -- the child's position.
          GM.unsafeWrite v p maxVal
          GM.unsafeWrite v maxIdx pVal
          go maxIdx
        else
          -- Parent is not smaller than either child: nothing left to do.
          pure ()
      | otherwise
      -- No children within the vector, so @p@ is a leaf.
      = pure ()
      where
        -- Index of the first (left) child of @p@.
        !c1 = p * 2 + 1

{-# INLINABLE heapify #-}
-- | Rearrange the whole vector in place into a binary max-heap.
--
-- Calls 'shiftDown' for every index from @'GM.length' v \`div\` 2@ down to
-- @0@. Indices above that starting point have no children inside the
-- vector, so they never need sinking (the starting index itself has no
-- children either, so its 'shiftDown' call does nothing).
heapify :: (PrimMonad m, Ord a, GM.MVector v a) => v (PrimState m) a -> m ()
heapify !v =
  go (GM.length v `unsafeShiftR` 1)
  where
    -- Counts down to the root; the root itself is processed last.
    go 0 = shiftDown v 0
    go n = shiftDown v n *> go (n - 1)

{-# INLINABLE heapSort #-}
-- | O(N * log(N)) regular heapsort. It uses a 2-way heap, whereas the
-- heapsort in the @vector-algorithms@ package uses a 4-way heap.
-- Can be used as a standalone sort, but its main purpose is to serve as
-- the fallback sort for quicksort.
--
-- Sorts the vector in place, in ascending order with respect to the 'Ord'
-- instance of the elements.
--
-- Depending on the GHC version it may be a good candidate for a
-- SPECIALIZE pragma.
heapSort :: (PrimMonad m, Ord a, GM.MVector v a) => v (PrimState m) a -> m ()
heapSort !v = do
  heapify v
  go (GM.length v)
  where
    -- @n@ is the size of the not-yet-sorted heap prefix; everything at
    -- index @n@ and above is already in its final position.
    go 0 = pure ()
    go n = do
      let !k = n - 1
      -- Move the current maximum (the heap root) to the end of the prefix.
      GM.unsafeSwap v 0 k
      -- Restore the heap property on the prefix shrunk by one element.
      shiftDown (GM.unsafeSlice 0 k v) 0
      go k