{-# OPTIONS_GHC -Wno-unused-imports #-}
module Data.Vector.Algorithms.Quicksort.Parameterised
( sortInplaceFM
, module E
) where
import Prelude hiding (last, pi)
import Control.Monad
import Control.Monad.Primitive
import Data.Bits
import Data.Vector.Generic.Mutable qualified as GM
import Data.Vector.Algorithms.FixedSort
import Data.Vector.Algorithms.Heapsort
import Data.Vector.Algorithms.Quicksort.Fork2 as E
import Data.Vector.Algorithms.Quicksort.Median as E
import Control.Monad.ST
{-# INLINABLE sortInplaceFM #-}
-- | Sort a mutable vector in place.
--
-- The first argument is the 'Fork2' strategy: 'startWork' is called on it
-- once to obtain a token, 'fork2' is used to recurse into the two
-- partitions, and 'endWork' is called with the token whenever a
-- sub-vector has been completely sorted.
--
-- The second argument is the 'Median' strategy handed to 'selectMedian'
-- to choose a pivot for every partitioning step.
--
-- Sub-vectors shorter than 17 elements are finished with 'bitonicSort';
-- when the recursion gets too deep relative to the input length the
-- current sub-vector is finished with 'heapSort' instead.
sortInplaceFM
  :: forall p med x m a v.
     (Fork2 p x m, Median med a m (PrimState m), PrimMonad m, Ord a, GM.MVector v a)
  => p
  -> med
  -> v (PrimState m) a
  -> m ()
sortInplaceFM !p !med !vector = do
  !releaseToken <- startWork p
  () <- qsortLoop 0 releaseToken vector
  pure ()
  where
    -- Length of the whole input; used by the heapsort fallback check.
    !cutoffLen = GM.length vector

    -- Floor of the binary logarithm of the input length.
    !logLen = binlog2 (GM.length vector)

    -- Recursion depth at which we unconditionally switch to heapsort.
    !threshold = 2 * logLen

    qsortLoop :: Int -> x -> v (PrimState m) a -> m ()
    qsortLoop !depth !token !v
      -- Small inputs: finish with the fixed-size sort.
      | len < 17
      = bitonicSort len v *> endWork p token

      -- Too deep, or the sub-vector is still too long for its depth
      -- (len * 2^depthDiff exceeds the original length): fall back to
      -- heapsort for this sub-vector.
      | depth == threshold
        || (depthDiff > 0 && len `unsafeShiftL` depthDiff > cutoffLen)
      = heapSort v *> endWork p token

      | otherwise = do
          let !last = len - 1
              -- Pivot candidates come from everything but the last slot.
              v'    = GM.unsafeSlice 0 last v
          res <- selectMedian med v'

          -- Only the 'ExistingValue' result is handled here, exactly as
          -- in the original code.
          (!pi', !pv) <- case res of
            ExistingValue pv pi -> do
              -- Move the chosen pivot into the last slot.
              when (pi /= last) $ do
                GM.unsafeWrite v pi =<< GM.unsafeRead v last
                GM.unsafeWrite v last pv
              (!xi, !pi') <- partitionTwoWaysPivotAtEnd pv (last - 1) v
              -- Put the pivot at the split point and the element that
              -- was found there into the last slot.
              GM.unsafeWrite v pi' pv
              GM.unsafeWrite v last xi
              pure (pi' + 1, pv)

          -- Skip over elements equal to the pivot so that they are not
          -- included in the right partition.
          !pi'' <- skipEq pv pi' v

          let !left   = GM.unsafeSlice 0 pi' v
              !right  = GM.unsafeSlice pi'' (len - pi'') v
              !depth' = depth + 1
          fork2
            p
            token
            depth
            (qsortLoop depth')
            (qsortLoop depth')
            left
            right
      where
        !len       = GM.length v
        -- How far the current depth exceeds log2 of the input length.
        !depthDiff = depth - logLen
{-# INLINE partitionTwoWaysPivotAtEnd #-}
-- | Two-way partition of the indices @[0 .. lastIdx]@ of the vector
-- around the pivot value @pv@.
--
-- A left cursor advances over elements that are @< pv@ and a right
-- cursor retreats over elements that are @>= pv@; while the left cursor
-- is strictly below the right one, the two elements they stopped at are
-- swapped and scanning resumes one step inwards on both sides.
--
-- Returns @(xi, i')@: the index @i'@ where the left cursor finally
-- stopped and the element @xi@ that was read at that index.
--
-- NOTE(review): the left scan may read index @lastIdx + 1@, so the
-- vector is assumed to have at least @lastIdx + 2@ elements (the caller
-- in this module keeps the pivot in that slot) -- confirm for any new
-- call site.
partitionTwoWaysPivotAtEnd
  :: (PrimMonad m, Ord a, GM.MVector v a)
  => a -> Int -> v (PrimState m) a -> m (a, Int)
partitionTwoWaysPivotAtEnd !pv !lastIdx !v =
  go 0 lastIdx
  where
    -- Local helpers deliberately carry no type signatures: the type
    -- variables of the top-level signature are not in scope here.
    go !i !j = do
      !(i', xi) <- goLT i
      !(j', xj) <- goGT j
      if i' < j'
        then do
          -- Both cursors stopped on misplaced elements: swap them.
          GM.unsafeWrite v j' xi
          GM.unsafeWrite v i' xj
          go (i' + 1) (j' - 1)
        else pure (xi, i')
      where
        -- Advance right while the element is below the pivot and we
        -- have not moved past the right cursor.
        goLT !k = do
          !x <- GM.unsafeRead v k
          if x < pv && k <= j
            then goLT (k + 1)
            else pure (k, x)

        -- Retreat left while the element is not below the pivot and we
        -- are still above the left cursor's starting point.
        goGT !k = do
          !x <- GM.unsafeRead v k
          if x >= pv && i < k
            then goGT (k - 1)
            else pure (k, x)
{-# INLINE skipEq #-}
-- | Starting at index @start@, walk forward over elements equal to @x@
-- and return the index of the first element that differs from @x@, or
-- the length of the vector if every remaining element is equal to @x@
-- (or if @start@ is already at or past the end).
skipEq :: (PrimMonad m, Eq a, GM.MVector v a) => a -> Int -> v (PrimState m) a -> m Int
skipEq !x !start !v = go start
  where
    -- One past the final valid index.
    !end = GM.length v

    go !k
      | k < end = do
          !y <- GM.unsafeRead v k
          if y == x
            then go (k + 1)
            else pure k
      | otherwise
      = pure k
{-# INLINE binlog2 #-}
-- | Position of the highest set bit of the argument, i.e. the floor of
-- its binary logarithm for positive inputs.
--
-- For @0@ every bit is a leading zero, so the result is @-1@.
binlog2 :: Int -> Int
binlog2 x = finiteBitSize x - 1 - countLeadingZeros x