{-# LANGUAGE FunctionalDependencies #-}
module Data.Vector.Algorithms.Quicksort.Fork2
(
Fork2(..)
, Sequential(..)
, Parallel
, mkParallel
, waitParallel
, ParStrategies
, defaultParStrategies
, setParStrategiesCutoff
, HasLength
, getLength
) where
import GHC.Conc (par, pseq)
import Control.Concurrent
import Control.Concurrent.STM
import Control.Monad.ST
import Data.Bits
import Data.Vector.Generic.Mutable qualified as GM
import GHC.ST (unsafeInterleaveST)
import System.IO.Unsafe
-- | A way of running the two independent halves of a divide-and-conquer
-- step.
--
-- @a@ is the strategy value, @x@ is a per-task token that the strategy threads
-- through the recursion, and @m@ is the monad the work runs in. The functional
-- dependency @a -> x@ means each strategy determines its own token type.
class Fork2 a x m | a -> x where
  -- | Set up bookkeeping before any work starts and produce the initial token.
  startWork :: a -> m x

  -- | Tear down the bookkeeping associated with a token.
  endWork :: a -> x -> m ()

  -- | Run two sub-tasks, sequentially or concurrently at the strategy's
  -- discretion. Each sub-task receives the token it should continue with.
  fork2
    :: (HasLength b, HasLength d)
    => a                 -- ^ the strategy (may carry synchronisation state)
    -> x                 -- ^ token of the task performing this split
    -> Int               -- ^ current recursion depth
    -> (x -> b -> m ())  -- ^ work for the first piece
    -> (x -> d -> m ())  -- ^ work for the second piece
    -> b                 -- ^ first piece
    -> d                 -- ^ second piece
    -> m ()
-- | Strategy that runs both sub-tasks of 'fork2' one after the other in the
-- calling thread. Its per-task token is @()@.
data Sequential = Sequential
-- | Purely sequential execution: no bookkeeping, and the two halves run in
-- order, both receiving the caller's token unchanged.
instance Monad m => Fork2 Sequential () m where
  {-# INLINE startWork #-}
  {-# INLINE endWork #-}
  {-# INLINE fork2 #-}
  startWork = const (pure ())

  endWork _ = const (pure ())

  fork2 _ token _ runFirst runSecond !first !second = do
    runFirst token first
    runSecond token second
-- | Strategy that forks sub-tasks onto new threads with 'forkIO'.
--
-- Fields:
--
-- * the job count, which bounds the recursion depth at which 'fork2' forks;
-- * a counter of outstanding tasks, which 'waitParallel' waits on until it
--   reaches zero.
--
-- Construct with 'mkParallel'.
data Parallel = Parallel !Int !(TVar Int)
-- | Create a 'Parallel' strategy for the given job count, with its
-- outstanding-task counter starting at zero.
mkParallel :: Int -> IO Parallel
mkParallel jobs = do
  counter <- newTVarIO 0
  pure (Parallel jobs counter)
-- | Atomically increment the outstanding-task counter of a 'Parallel'
-- strategy. The new count is written strictly so no thunk builds up in the
-- 'TVar'.
addPending :: Parallel -> IO ()
addPending (Parallel _ counter) = atomically $ do
  current <- readTVar counter
  writeTVar counter $! current + 1
-- | Atomically decrement the outstanding-task counter of a 'Parallel'
-- strategy. The new count is written strictly.
removePending :: Parallel -> IO ()
removePending (Parallel _ counter) =
  atomically $
    readTVar counter >>= \current -> writeTVar counter $! subtract 1 current
-- | Block until the outstanding-task counter of a 'Parallel' strategy is
-- zero. The STM transaction retries, and so sleeps, until the counter
-- changes.
waitParallel :: Parallel -> IO ()
waitParallel (Parallel _ counter) = atomically $ do
  outstanding <- readTVar counter
  if outstanding /= 0
    then retry
    else pure ()
-- | Thread-based parallelism.
--
-- The token is @(isSeq, shouldDecrement)@:
--
-- * @isSeq@: when 'True', 'fork2' never forks and runs both halves in the
--   current thread.
-- * @shouldDecrement@: when 'True', 'endWork' decrements the outstanding-task
--   counter.
instance Fork2 Parallel (Bool, Bool) IO where
  {-# INLINE startWork #-}
  {-# INLINE endWork   #-}
  {-# INLINE fork2     #-}
  -- Registers the top-level task. The returned token has shouldDecrement set,
  -- so an 'endWork' on it balances this increment.
  startWork !p = do
    addPending p
    pure (False, True)

  endWork p (_, shouldDecrement)
    | shouldDecrement
    = removePending p
    | otherwise
    = pure ()

  fork2
    :: forall b d. (HasLength b, HasLength d)
    => Parallel
    -> (Bool, Bool)
    -> Int
    -> ((Bool, Bool) -> b -> IO ())
    -> ((Bool, Bool) -> d -> IO ())
    -> b
    -> d
    -> IO ()
  fork2 !p@(Parallel jobs _) tok@(!isSeq, shouldDecrement) !depth f g !b !d
    | isSeq
    = f (True, False) b *> g tok d
    | canFork && mn > 10_000
    = do
      -- The child gets its own shouldDecrement flag so its 'endWork'
      -- balances this increment.
      -- NOTE(review): if 'f' throws inside the forked thread, the exception
      -- is not propagated and the counter is never decremented, so
      -- 'waitParallel' would block.
      addPending p
      _ <- forkIO $ f (False, True) b
      g tok d
    -- No fork: run the larger half first. The smaller half runs second and
    -- inherits this task's shouldDecrement flag, so only the last piece of
    -- work in this thread carries it.
    | bLen > dLen
    = f (False, False) b *> g (True, shouldDecrement) d
    | otherwise
    = g (False, False) d *> f (True, shouldDecrement) b
    where
      bLen, dLen :: Int
      !bLen = getLength b
      !dLen = getLength d

      !mn = min bLen dLen

      -- Fork only while 2^(depth+1) < jobs. The explicit depth bounds keep
      -- the shift well defined and non-overflowing. Without them, at
      -- depth >= finiteBitSize - 2 the result wraps to minBound or 0 (or is
      -- undefined for counts >= the word size), the comparison becomes True,
      -- and deep recursion levels would fork without limit.
      canFork :: Bool
      canFork =
        depth >= 0
          && depth < finiteBitSize jobs - 2
          && 2 `unsafeShiftL` depth < jobs
-- | Spark-based strategy (via 'par' / 'pseq'). The single field is the
-- cutoff: 'fork2' only sparks when both halves are longer than it.
data ParStrategies = ParStrategies !Int
-- | A 'ParStrategies' whose sparking cutoff is 10,000 elements.
defaultParStrategies :: ParStrategies
defaultParStrategies = ParStrategies defaultCutoff
  where
    defaultCutoff = 10_000
-- | Replace the sparking cutoff of a 'ParStrategies'. The previous value is
-- ignored entirely.
setParStrategiesCutoff :: Int -> ParStrategies -> ParStrategies
setParStrategiesCutoff newCutoff = const (ParStrategies newCutoff)
-- | Spark-based parallelism in 'IO'. No bookkeeping is needed, because
-- 'fork2' itself waits for both halves before returning.
instance Fork2 ParStrategies () IO where
  {-# INLINE startWork #-}
  {-# INLINE endWork   #-}
  {-# INLINE fork2     #-}
  startWork _ = pure ()
  endWork _ _ = pure ()

  fork2
    :: forall b d. (HasLength b, HasLength d)
    => ParStrategies
    -> ()
    -> Int
    -> (() -> b -> IO ())
    -> (() -> d -> IO ())
    -> b
    -> d
    -> IO ()
  fork2 !(ParStrategies cutoff) _ _ f g !b !d
    | mn > cutoff
    = do
      -- Wrap the first half as a pure thunk so it can be sparked and picked
      -- up by an idle capability. NOTE(review): this relies on f and g
      -- touching independent data (presumably disjoint slices); confirm at
      -- call sites.
      let b' :: ()
          b' = unsafePerformIO $ f () b
      d' <- b' `par` g () d
      -- Force both results before this action completes. Returning
      -- @pure (b' `pseq` ...)@ would only hand back an unevaluated thunk.
      -- If the caller discards it and the spark is not converted (or is
      -- garbage collected), the first half would never run and its
      -- exceptions would be lost. If the spark is still running, forcing
      -- here waits for it.
      b' `pseq` d' `pseq` pure ()
    | otherwise
    = f () b *> g () d
    where
      bLen, dLen :: Int
      !bLen = getLength b
      !dLen = getLength d

      !mn = min bLen dLen
-- | Spark-based parallelism in strict 'ST'. No bookkeeping is needed,
-- because 'fork2' itself waits for both halves before returning.
instance Fork2 ParStrategies () (ST s) where
  {-# INLINE startWork #-}
  {-# INLINE endWork   #-}
  {-# INLINE fork2     #-}
  startWork _ = pure ()
  endWork _ _ = pure ()

  fork2
    :: forall b d. (HasLength b, HasLength d)
    => ParStrategies
    -> ()
    -> Int
    -> (() -> b -> ST s ())
    -> (() -> d -> ST s ())
    -> b
    -> d
    -> ST s ()
  fork2 !(ParStrategies cutoff) _ _ f g !b !d
    | mn > cutoff
    = do
      -- Defer the first half into a thunk so it can be sparked.
      -- NOTE(review): this relies on f and g touching independent data
      -- (presumably disjoint slices); confirm at call sites.
      b' <- unsafeInterleaveST $ f () b
      d' <- b' `par` g () d
      -- Force both results before this action completes. Returning
      -- @pure (b' `pseq` ...)@ would only hand back an unevaluated thunk.
      -- If the caller discards it and the spark is not converted (or is
      -- garbage collected), the first half would never run.
      b' `pseq` d' `pseq` pure ()
    | otherwise
    = f () b *> g () d
    where
      bLen, dLen :: Int
      !bLen = getLength b
      !dLen = getLength d

      !mn = min bLen dLen
-- | Values with a measurable element count. 'fork2' implementations compare
-- these counts to decide how (and whether) to split work.
class HasLength a where
  -- | Number of elements.
  getLength :: a -> Int
-- | Every generic mutable vector reports its length via 'GM.length'.
instance GM.MVector v a => HasLength (v s a) where
  {-# INLINE getLength #-}
  getLength vec = GM.length vec