-- |
-- Module:     Data.Vector.Algorithms.Quicksort.Fork2
-- Copyright:  (c) Sergey Vinokurov 2023
-- License:    Apache-2.0 (see LICENSE)
-- Maintainer: serg.foo@gmail.com
--
-- This module defines how quicksort is parallelised using the 'Fork2' class.

{-# LANGUAGE FunctionalDependencies #-}

module Data.Vector.Algorithms.Quicksort.Fork2
  (

  -- * Main interface
    Fork2(..)

  -- * No parallelisation
  , Sequential(..)

  -- * Parallelisation with threads
  , Parallel
  , mkParallel
  , waitParallel

  -- * Parallelisation with sparks
  , ParStrategies
  , defaultParStrategies
  , setParStrategiesCutoff

  -- * Helpers
  , HasLength
  , getLength
  ) where

import GHC.Conc (par, pseq)

import Control.Concurrent
import Control.Concurrent.STM
import Control.Monad.ST
import Data.Bits
import Data.Vector.Generic.Mutable qualified as GM
import GHC.ST (unsafeInterleaveST)
import System.IO.Unsafe

-- | Parallelisation strategy for the quicksort algorithm with
-- single-pivot partitioning. Specifies how to apply a pair of functions
-- to their respective inputs (which will be recursive quicksort calls).
--
-- NB the name @Fork2@ refers to the two branches (one per subarray
-- produced by partitioning) that a strategy may fork. Whether anything
-- actually runs in parallel is up to each instance, e.g. 'Sequential'
-- runs both branches in the current thread.
--
-- Parameter meanings:
--
-- * @a@ - the parallelisation strategy the instance is defined for
-- * @x@ - type of tokens that the strategy can pass around to track
--   recursive calls
-- * @m@ - monad the strategy operates in. Some strategies only make
--   sense in a particular monad, e.g. parallelisation via 'forkIO'
--   requires 'IO'
--
-- The functional dependency @a -> x@ means that the strategy type alone
-- determines its token type.
class Fork2 a x m | a -> x where
  -- | Will get called /only once/ by quicksort when sorting starts,
  -- returns token to be passed around. Other tokens, e.g. for newly
  -- spawned threads, are created by the strategy in the corresponding
  -- class instance.
  startWork :: a -> m x
  -- | Will get called by quicksort when it finishes sorting its array.
  -- Will receive the previously created token.
  endWork   :: a -> x -> m ()
  -- | Run the two recursive quicksort calls on the two subarrays
  -- produced by partitioning, in whatever order and with whatever
  -- degree of parallelism the strategy chooses.
  fork2
    :: (HasLength b, HasLength d)
    => a                -- ^ Parallelisation algorithm that can carry
                        -- extra info, e.g. for synchronization
    -> x                -- ^ Token for current execution thread,
                        -- will be passed to 'endWork' when done
    -> Int              -- ^ Recursion depth
    -> (x -> b -> m ()) -- ^ One recursive quicksort call
    -> (x -> d -> m ()) -- ^ The other recursive quicksort call
    -> b                -- ^ One of the subarrays after partitioning to be sorted
    -> d                -- ^ The other subarray to be sorted
    -> m ()

-- | Trivial parallelisation strategy that executes everything
-- sequentially in the current thread. Good default overall.
data Sequential = Sequential

-- | Works in any monad: the call on @b@ runs first, then the call on @d@,
-- both receiving the same unit token.
instance Monad m => Fork2 Sequential () m where
  {-# INLINE startWork #-}
  {-# INLINE endWork   #-}
  {-# INLINE fork2     #-}
  startWork _ = pure ()
  endWork _ _ = pure ()
  -- Both subarrays are forced to WHNF (bang patterns) before either
  -- recursive call runs.
  fork2 _ tok _ f g !b !d = f tok b *> g tok d

-- | At most N concurrent jobs will be spawned to evaluate recursive calls after quicksort
-- partitioning.
--
-- Warning: currently not as fast as the sparks-based 'ParStrategies'
-- strategy, take care to benchmark before using.
--
-- Fields: the job limit @N@, and a counter of outstanding units of work
-- that 'waitParallel' waits to drop to zero.
data Parallel = Parallel !Int !(TVar Int)

-- | Make parallelisation strategy with at most @N@ threads. The pending
-- work counter starts at zero.
mkParallel :: Int -> IO Parallel
mkParallel jobs = Parallel jobs <$> newTVarIO 0

-- | Register one more outstanding unit of work with the strategy by
-- atomically incrementing its pending counter.
addPending :: Parallel -> IO ()
addPending (Parallel _ pending) =
  atomically $ modifyTVar' pending (+ 1)

-- | Mark one outstanding unit of work as finished by atomically
-- decrementing the strategy's pending counter.
removePending :: Parallel -> IO ()
removePending (Parallel _ pending) =
  atomically $ modifyTVar' pending (subtract 1)

-- | Wait until all threads related to a particular 'Parallel' instance finish.
--
-- Blocks (the STM transaction retries) until the pending work counter
-- reaches zero.
waitParallel :: Parallel -> IO ()
waitParallel (Parallel _ pending) = atomically $ do
  m <- readTVar pending
  -- 'check' retries the transaction while the condition is False.
  check (m == 0)

-- | Thread-based strategy. The token is @(isSeq, shouldDecrement)@:
--
-- * @isSeq@ - when 'True', 'fork2' runs both calls in the current thread
--   without considering a fork;
-- * @shouldDecrement@ - when 'True', 'endWork' releases one unit of the
--   pending counter (via 'removePending') for this token.
instance Fork2 Parallel (Bool, Bool) IO where
  {-# INLINE startWork #-}
  {-# INLINE endWork   #-}
  {-# INLINE fork2     #-}
  -- The initial token owns the pending unit registered here.
  startWork !p = do
    addPending p
    pure (False, True)

  endWork p (_, shouldDecrement)
    | shouldDecrement
    = removePending p
    | otherwise
    = pure ()

  fork2
    :: forall b d. (HasLength b, HasLength d)
    => Parallel
    -> (Bool, Bool)
    -> Int
    -> ((Bool, Bool) -> b -> IO ())
    -> ((Bool, Bool) -> d -> IO ())
    -> b
    -> d
    -> IO ()
  fork2 !p@(Parallel jobs _) tok@(!isSeq, shouldDecrement) !depth f g !b !d
    | isSeq
    = f (True, False) b *> g tok d
    | depth < maxShift && 2 `unsafeShiftL` depth < jobs && mn > 10_000
    = do
      addPending p
      -- The forked thread gets a token that owns the unit added above.
      -- NOTE(review): presumably, if @f@ throws in the forked thread,
      -- 'endWork' is never reached for that token and 'waitParallel'
      -- would block forever - confirm against the quicksort driver.
      _ <- forkIO $ f (False, True) b
      g tok d
    -- No fork: the larger subarray keeps the ability to fork later, the
    -- smaller one runs in sequential mode. The call that runs last
    -- inherits the obligation to decrement.
    | bLen > dLen
    = f (False, False) b *> g (True, shouldDecrement) d
    | otherwise
    = g (False, False) d *> f (True, shouldDecrement) b
    where
      bLen, dLen :: Int
      !bLen = getLength b
      !dLen = getLength d

      !mn = min bLen dLen

      -- Largest exclusive depth for which @2 `unsafeShiftL` depth@ is
      -- still a positive 'Int'. Beyond it the shift overflows (and is
      -- undefined once the amount reaches the word size), which could
      -- make the comparison with @jobs@ spuriously succeed and fork
      -- unboundedly many threads.
      maxShift :: Int
      maxShift = finiteBitSize depth - 2

-- | Parallelise with sparks. After partitioning, if sides are
-- sufficiently big then a spark will be created to evaluate one of the
-- parts while the other will continue to be evaluated in the current
-- execution thread.
--
-- This strategy works in both 'IO' and 'ST' monads (see docs for
-- relevant instance for some discussion on how that works).
--
-- Sparks will seamlessly use all available RTS capabilities
-- (configured with @+RTS -N@ flag) and according to benchmarks in
-- this package have pretty low synchronization overhead as opposed to
-- thread-based parallelisation that 'Parallel' offers. These benefits
-- allow sparks to work on much smaller chunks and exercise more
-- parallelism.
--
-- The field is the cutoff: a spark is only created when both subarrays
-- are longer than it.
data ParStrategies = ParStrategies !Int

-- | Parallelise with sparks for reasonably big vectors (both subarrays
-- longer than 10,000 elements).
defaultParStrategies :: ParStrategies
defaultParStrategies = ParStrategies 10_000

-- | Adjust length of vectors for which parallelisation will be performed:
-- a spark is only created when both subarrays are longer than @n@. The
-- previous cutoff is discarded.
setParStrategiesCutoff :: Int -> ParStrategies -> ParStrategies
setParStrategiesCutoff n _ = ParStrategies n

-- | Spark-based strategy in 'IO'.
--
-- When both subarrays are longer than the cutoff, the call on @b@ is
-- wrapped in 'unsafePerformIO' and sparked with 'par', while the call on
-- @d@ runs in the current thread. Otherwise both calls run in order.
--
-- The returned unit is built with 'pseq' over both results. The sparked
-- call is only guaranteed to have completed once that unit is forced.
instance Fork2 ParStrategies () IO where
  {-# INLINE startWork #-}
  {-# INLINE endWork   #-}
  {-# INLINE fork2     #-}
  startWork _ = pure ()
  endWork _ _ = pure ()

  fork2
    :: forall b d. (HasLength b, HasLength d)
    => ParStrategies
    -> ()
    -> Int
    -> (() -> b -> IO ())
    -> (() -> d -> IO ())
    -> b
    -> d
    -> IO ()
  fork2 !(ParStrategies cutoff) _ _ f g !b !d
    | mn > cutoff
    = do
      -- Turn the call on @b@ into a pure thunk so it can be sparked;
      -- whoever forces it first (spark or this thread) performs the sort.
      let b' = unsafePerformIO $ f () b
      d' <- b' `par` g () d
      pure (b' `pseq` (d' `pseq` ()))
    | otherwise
    = do
      b' <- f () b
      d' <- g () d
      pure (b' `pseq` (d' `pseq` ()))
    where
      bLen, dLen :: Int
      !bLen = getLength b
      !dLen = getLength d

      !mn = min bLen dLen

-- | This instance is a bit surprising: the ST monad, after all, has no
-- concurrency, and it threads its @s@ parameter everywhere to signal,
-- among other things, that there is a single execution thread.
--
-- Still, quicksort in this package hopefully doesn't do anything
-- funny that may break under parallelism. Use of this instance for
-- other purposes has at least the same caveats as use of
-- 'unsafeInterleaveST' (i.e. not recommended, especially considering
-- that the instance may change).
--
-- When both subarrays are longer than the cutoff, the call on @b@ is
-- deferred with 'unsafeInterleaveST' and sparked with 'par', while the
-- call on @d@ runs in the current thread. Otherwise both calls run in
-- order. The deferred call is only guaranteed to have run once the
-- returned unit (built with 'pseq' over both results) is forced.
instance Fork2 ParStrategies () (ST s) where
  {-# INLINE startWork #-}
  {-# INLINE endWork   #-}
  {-# INLINE fork2     #-}
  startWork _ = pure ()
  endWork _ _ = pure ()

  fork2
    :: forall b d. (HasLength b, HasLength d)
    => ParStrategies
    -> ()
    -> Int
    -> (() -> b -> ST s ())
    -> (() -> d -> ST s ())
    -> b
    -> d
    -> ST s ()
  fork2 !(ParStrategies cutoff) _ _ f g !b !d
    | mn > cutoff
    = do
      -- @b'@ is a lazy result: the call on @b@ runs when it is forced,
      -- either by the spark or later by this thread.
      b' <- unsafeInterleaveST $ f () b
      d' <- b' `par` g () d
      pure (b' `pseq` (d' `pseq` ()))
    | otherwise
    = do
      b' <- f () b
      d' <- g () d
      pure (b' `pseq` (d' `pseq` ()))
    where
      bLen, dLen :: Int
      !bLen = getLength b
      !dLen = getLength d

      !mn = min bLen dLen

-- | Helper that can be used to estimate sizes of subproblems.
--
-- For instance, a very small array will not benefit from being sorted in
-- parallel because parallelisation overhead will likely outweigh any
-- time savings.
class HasLength a where
  -- | Length of item
  getLength :: a -> Int

-- | Mutable vectors report their number of elements via 'GM.length'.
instance GM.MVector v a => HasLength (v s a) where
  {-# INLINE getLength #-}
  getLength = GM.length