module Data.Vector.Algorithms.Heapsort
( heapSort
) where
import Control.Monad.Primitive
import Data.Bits
import Data.Vector.Generic.Mutable qualified as GM
{-# INLINABLE shiftDown #-}
-- | Sift the element at the given index down a binary max-heap stored in
-- the whole of @v@. Node @i@ has children at @2*i + 1@ and @2*i + 2@.
--
-- At each step the larger child is chosen; when the two children compare
-- equal, the right child is taken. If that child is strictly greater than
-- the current element, the two are swapped and sifting continues from the
-- child's position. Otherwise the loop stops. It also stops when the node
-- has no child inside @v@.
--
-- Indexing is unchecked. The start index must be non-negative. An index
-- whose left child would fall outside @v@ is a no-op.
shiftDown :: (PrimMonad m, Ord a, GM.MVector v a) => v (PrimState m) a -> Int -> m ()
shiftDown !v = sift
  where
    !len = GM.length v

    sift !parent
      -- No left child inside the heap means no children at all: done.
      | left >= len = pure ()
      | otherwise = do
          leftVal <- GM.unsafeRead v left
          (bigIdx, bigVal) <- pickChild leftVal
          parentVal <- GM.unsafeRead v parent
          if not (bigVal > parentVal)
            then pure ()
            else do
              GM.unsafeWrite v parent bigVal
              GM.unsafeWrite v bigIdx parentVal
              sift bigIdx
      where
        !left = 2 * parent + 1
        !right = left + 1

        -- Index and value of the larger child. A right child is only
        -- consulted when it lies inside the heap; ties go to the right.
        pickChild leftVal
          | right >= len = pure (left, leftVal)
          | otherwise = do
              rightVal <- GM.unsafeRead v right
              pure $
                if leftVal > rightVal
                  then (left, leftVal)
                  else (right, rightVal)
{-# INLINABLE heapify #-}
-- | Rearrange @v@ in place into a binary max-heap, as expected by
-- 'shiftDown'.
--
-- Sifts down every index from @length v `unsafeShiftR` 1@ down to @0@,
-- in that order. Indices above that starting point have no children.
-- The starting index itself is also a leaf, so its sift is a no-op.
heapify :: (PrimMonad m, Ord a, GM.MVector v a) => v (PrimState m) a -> m ()
heapify !v = loop (GM.length v `unsafeShiftR` 1)
  where
    -- Count down from the start index; stop once index 0 has been sifted.
    loop !i
      | i < 0 = pure ()
      | otherwise = do
          shiftDown v i
          loop (i - 1)
{-# INLINABLE heapSort #-}
-- | Sort a mutable vector in place into ascending order, using heapsort.
--
-- First the vector is turned into a max-heap with 'heapify'. Then, for a
-- heap prefix of length @n@, the root (the maximum) is swapped into slot
-- @n - 1@. The heap property is restored with 'shiftDown' on the prefix
-- of length @n - 1@. This repeats until the prefix is empty.
heapSort :: (PrimMonad m, Ord a, GM.MVector v a) => v (PrimState m) a -> m ()
heapSort !v = heapify v >> extract (GM.length v)
  where
    -- 'extract n': the first n slots form a heap, and the rest is sorted.
    extract 0 = pure ()
    extract n = do
      let !lastIx = n - 1
      GM.unsafeSwap v 0 lastIx
      shiftDown (GM.unsafeSlice 0 lastIx v) 0
      extract lastIx