-- |
-- Module:     Data.Vector.Algorithms.Quicksort.Parameterised
-- Copyright:  (c) Sergey Vinokurov 2023
-- License:    Apache-2.0 (see LICENSE)
-- Maintainer: serg.foo@gmail.com
--
-- This module provides a fully generic quicksort that, for now, lets the
-- caller decide how to parallelise and how to select the median. More
-- things may be parameterised in the future, likely by introducing
-- new functions that take more arguments.
--
-- === Example
-- This is how you’d define a parallel sort that uses sparks on unboxed vectors of integers:
--
-- >>> import Control.Monad.ST
-- >>> import Data.Int
-- >>> import Data.Vector.Algorithms.Quicksort.Parameterised
-- >>> import Data.Vector.Unboxed qualified as U
-- >>> :{
-- let myParallelSort :: U.MVector s Int64 -> ST s ()
--     myParallelSort = sortInplaceFM defaultParStrategies (Median3or5 @Int64)
-- in U.modify myParallelSort $ U.fromList @Int64 [20, 19 .. 0]
-- :}
-- [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
--
-- === Design considerations
-- Because of its reliance on specialisation, this package doesn't provide
-- sort functions that take a comparator function as an argument; they rely
-- on the 'Ord' instance instead. While somewhat limiting, this makes it
-- possible to offload optimisation to @SPECIALIZE@ pragmas even if the
-- compiler isn't smart enough to monomorphise automatically.
--
-- === Performance considerations
-- Compared to the default sort, this one is even more sensitive to
-- specialisation. Users who care about performance are advised to dump
-- Core and ensure that the sort is monomorphised. GHC 9.6.1 was seen
-- to specialise automatically, but 9.4 wasn't as good and required
-- pragmas both for the main sort function and for its helpers, like this:
--
-- > -- Either use the flag to specialize everything, ...
-- > {-# OPTIONS_GHC -fspecialise-aggressively #-}
-- >
-- > -- ... or the pragmas for specific functions
-- > import Control.Monad.ST
-- > import Data.Int
-- > import Data.Vector.Algorithms.FixedSort
-- > import Data.Vector.Algorithms.Heapsort
-- > import Data.Vector.Algorithms.Quicksort.Parameterised
-- > import Data.Vector.Unboxed qualified as U
-- >
-- > {-# SPECIALIZE heapSort      :: U.MVector s Int64 -> ST s ()        #-}
-- > {-# SPECIALIZE bitonicSort   :: Int -> U.MVector s Int64 -> ST s () #-}
-- > {-# SPECIALIZE sortInplaceFM :: Sequential -> Median3 Int64 -> U.MVector s Int64 -> ST s () #-}
--
-- === Speeding up compilation
-- To speed up compilation, it's a good idea to introduce a
-- dedicated module where all the sorts reside and to import it
-- instead of calling @sort@ or @sortInplaceFM@ in modules with other logic.
-- This way the sort functions, which can take a while to compile, will be
-- recompiled rarely.
--
-- > module MySorts (mySequentialSort) where
-- >
-- > import Control.Monad.ST
-- > import Data.Int
-- > import Data.Vector.Unboxed qualified as U
-- >
-- > import Data.Vector.Algorithms.Quicksort.Parameterised
-- >
-- > {-# NOINLINE mySequentialSort #-}
-- > mySequentialSort :: U.MVector s Int64 -> ST s ()
-- > mySequentialSort = sortInplaceFM Sequential (Median3or5 @Int64)
--
-- === Reducing code bloat
-- Avoid using the same sort with both the 'ST' and 'IO' monads. Stick to the 'ST'
-- monad as much as possible because it can easily be converted to
-- 'IO' via the safe 'stToIO' function. Using the same sort in both 'IO' and
-- 'ST' monads will compile two versions of it along with all its
-- helper sorts, which can be pretty big (especially the bitonic sort).

-- So that haddock will resolve references in the documentation.
{-# OPTIONS_GHC -Wno-unused-imports #-}

module Data.Vector.Algorithms.Quicksort.Parameterised
  ( sortInplaceFM
  -- * Reexports
  , module E
  ) where

import Prelude hiding (last, pi)

import Control.Monad
import Control.Monad.Primitive
import Data.Bits
import Data.Vector.Generic.Mutable qualified as GM

import Data.Vector.Algorithms.FixedSort
import Data.Vector.Algorithms.Heapsort
import Data.Vector.Algorithms.Quicksort.Fork2 as E
import Data.Vector.Algorithms.Quicksort.Median as E

-- For haddock
import Control.Monad.ST

{-# INLINABLE sortInplaceFM #-}
-- | Quicksort parameterised by median selection method and
-- parallelisation strategy.
--
-- Sorts the vector in place into ascending order according to the 'Ord'
-- instance. Subvectors shorter than 17 elements are finished with
-- 'bitonicSort'. Once recursion gets too deep (see the guards of
-- @qsortLoop@ below), the current subvector is handed to 'heapSort'
-- instead of being partitioned further.
sortInplaceFM
  :: forall p med x m a v.
     (Fork2 p x m, Median med a m (PrimState m), PrimMonad m, Ord a, GM.MVector v a)
  => p
  -> med
  -> v (PrimState m) a
  -> m ()
sortInplaceFM !p !med !vector = do
  !releaseToken <- startWork p
  -- ParStrategies requires forcing the unit, otherwise we may return
  -- while some sparks are still working.
  () <- qsortLoop 0 releaseToken vector
  pure ()
  where
    -- Length of the whole input. A subvector at depth @logLen + d@ (d > 0)
    -- is considered suspiciously large, and is heapsorted, when it still
    -- holds more than @cutoffLen / 2^d@ elements.
    !cutoffLen = GM.length vector

    -- floor (log2 n) of the input length.
    !logLen = binlog2 cutoffLen

    -- Absolute recursion depth at which we always fall back to heapsort.
    !threshold = 2 * logLen

    qsortLoop :: Int -> x -> v (PrimState m) a -> m ()
    qsortLoop !depth !releaseToken !v
      | len < 17
      = bitonicSort len v *> endWork p releaseToken

      -- Comparing against a right-shifted 'cutoffLen' is equivalent to
      -- @len * 2^depthDiff > cutoffLen@ but, unlike shifting @len@ left,
      -- cannot overflow 'Int' on very large inputs.
      | depth == threshold || depthDiff > 0 && len > (cutoffLen `unsafeShiftR` depthDiff)
      = heapSort v *> endWork p releaseToken

      | otherwise = do
        let !last = len - 1
            v'    = GM.unsafeSlice 0 last v
        -- The median is chosen among all elements but the last one.
        res <- selectMedian med v'

        -- pivotIdx is the final sorted position of the pivot value pv.
        (!pivotIdx, !pv) <- case res of
          ExistingValue pv pi -> do
            -- Park the pivot in the last slot so partitioning never moves it;
            -- it also serves as the sentinel partitionTwoWaysPivotAtEnd reads.
            when (pi /= last) $ do
              GM.unsafeWrite v pi =<< GM.unsafeRead v last
              GM.unsafeWrite v last pv
            (!xi, !pi') <- partitionTwoWaysPivotAtEnd pv (last - 1) v
            -- Swap the pivot into the boundary slot: everything before pi'
            -- is < pv and everything after it is >= pv.
            GM.unsafeWrite v pi' pv
            GM.unsafeWrite v last xi
            pure (pi', pv)

        -- Elements equal to the pivot right after it are already in place.
        !rightStart <- skipEq pv (pivotIdx + 1) v

        -- The pivot itself is in its final position, so neither half
        -- needs to include it.
        let !left   = GM.unsafeSlice 0 pivotIdx v
            !right  = GM.unsafeSlice rightStart (len - rightStart) v
            !depth' = depth + 1
        fork2
          p
          releaseToken
          depth
          (qsortLoop depth')
          (qsortLoop depth')
          left
          right
      where
        !len       = GM.length v
        !depthDiff = depth - logLen

{-# INLINE partitionTwoWaysPivotAtEnd #-}
-- | Two-pointer (Hoare-style) partition of @v[0 .. lastIdx]@ around the
-- pivot value @pv@.
--
-- Elements strictly less than @pv@ are moved towards the front and elements
-- greater than or equal to @pv@ towards the back. Returns @(xi, i)@, where
-- @i@ is the first index whose element is not less than @pv@ after
-- partitioning and @xi@ is the element read at that index.
--
-- Preconditions: @0 <= lastIdx@ and index @lastIdx + 1@ is valid in @v@.
-- The forward scan reads one slot past its current right bound before it
-- checks that bound, so on the first pass it may read @v[lastIdx + 1]@.
-- 'sortInplaceFM' stores the pivot there. As a consequence, @i@ may equal
-- @lastIdx + 1@.
partitionTwoWaysPivotAtEnd
  :: (PrimMonad m, Ord a, GM.MVector v a)
  => a -> Int -> v (PrimState m) a -> m (a, Int)
partitionTwoWaysPivotAtEnd !pv !lastIdx !v =
  go 0 lastIdx
  where
    -- Invariant: v[0 .. i-1] < pv and v[j+1 .. lastIdx] >= pv.
    go !i !j = do
      !(i', xi) <- scanRight i
      !(j', xj) <- scanLeft j
      if i' < j'
        then do
          -- Both scans stopped on misplaced elements: swap them and continue.
          GM.unsafeWrite v j' xi
          GM.unsafeWrite v i' xj
          go (i' + 1) (j' - 1)
        else pure (xi, i')
      where
        -- Move forward past elements smaller than the pivot.
        scanRight !k = do
          !x <- GM.unsafeRead v k
          if x < pv && k <= j
            then scanRight (k + 1)
            else pure (k, x)
        -- Move backward past elements not smaller than the pivot, never below i.
        scanLeft !k = do
          !x <- GM.unsafeRead v k
          if x >= pv && i < k
            then scanLeft (k - 1)
            else pure (k, x)

{-# INLINE skipEq #-}
-- | Identify a run of elements equal to the pivot we were partitioning with,
-- so that the whole run of equal pivots can be excluded from recursion.
--
-- Starting at index @start@ (which must be non-negative), scan forward
-- while elements equal @x@. Return the index of the first element that
-- differs, or @GM.length v@ if the run reaches the end of @v@. If
-- @start >= GM.length v@, @start@ is returned unchanged.
skipEq :: (PrimMonad m, Eq a, GM.MVector v a) => a -> Int -> v (PrimState m) a -> m Int
skipEq !x !start !v = go start
  where
    -- One past the last valid index.
    !end = GM.length v
    go !k
      | k < end = do
          !y <- GM.unsafeRead v k
          if y == x then go (k + 1) else pure k
      | otherwise = pure k

{-# INLINE binlog2 #-}
-- | Index of the highest set bit of @x@, i.e. @floor (logBase 2 x)@ for
-- positive @x@. @binlog2 0 == -1@. For negative @x@ the sign bit is set,
-- so the result is @finiteBitSize x - 1@.
binlog2 :: Int -> Int
binlog2 x = finiteBitSize x - 1 - countLeadingZeros x