{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Control.Scheduler
(
Scheduler
, numWorkers
, scheduleWork
, scheduleWork_
, terminate
, terminate_
, terminateWith
, withScheduler
, withScheduler_
, trivialScheduler_
, Comp(..)
, getCompWorkers
, replicateConcurrently
, replicateConcurrently_
, traverseConcurrently
, traverseConcurrently_
, traverse_
) where
import Control.Concurrent
import Control.Exception
import Control.Monad
import Control.Monad.IO.Unlift
import Control.Scheduler.Computation
import Control.Scheduler.Queue
import Data.Atomics (atomicModifyIORefCAS, atomicModifyIORefCAS_)
import qualified Data.Foldable as F (foldl', traverse_)
import Data.IORef
import Data.Traversable
import Data.Maybe (catMaybes)
-- | Shared state of one scheduler run, used by the job-submission functions.
data Jobs m a = Jobs
  { jobsNumWorkers :: {-# UNPACK #-} !Int
    -- ^ Number of workers. When the outstanding-job count drops to zero, this
    -- many retire markers are pushed onto the queue.
  , jobsQueue :: !(JQueue m a)
    -- ^ Queue that submitted jobs are pushed onto and workers pop from.
  , jobsCountRef :: !(IORef Int)
    -- ^ Number of submitted jobs that have not finished yet. It is incremented
    -- on submission and decremented after each job completes.
  }
-- | Handle given to user code for submitting work and for stopping early.
data Scheduler m a = Scheduler
  { numWorkers :: {-# UNPACK #-} !Int
    -- ^ Number of workers the scheduler runs jobs on.
  , scheduleWork :: m a -> m ()
    -- ^ Submit an action to be executed by the scheduler.
  , terminate :: a -> m a
    -- ^ Request early termination and return the argument. For schedulers
    -- created through 'withScheduler', the final result is the supplied value
    -- together with the results collected so far.
  , terminateWith :: a -> m a
    -- ^ Request early termination and return the argument. For schedulers
    -- created through 'withScheduler', the final result is only the supplied
    -- value (a one-element list).
  }
-- | 'scheduleWork' specialised to schedulers whose jobs return no result.
scheduleWork_ :: Scheduler m () -> m () -> m ()
scheduleWork_ scheduler job = scheduleWork scheduler job
-- | Stop a unit-result scheduler early by calling its 'terminateWith' with @()@.
terminate_ :: Scheduler m () -> m ()
terminate_ scheduler = terminateWith scheduler ()
-- | A single-worker scheduler that does no scheduling at all.
--
-- A submitted job is returned unchanged, so it runs exactly where the
-- scheduling action itself is run. Both termination functions ignore their
-- argument and return @pure ()@.
trivialScheduler_ :: Applicative f => Scheduler f ()
trivialScheduler_ =
  Scheduler
    { numWorkers = 1
    , scheduleWork = \job -> job
    , terminate = \_ -> pure ()
    , terminateWith = \_ -> pure ()
    }
-- | How a scheduler run ended. It is written into the run's outcome 'MVar'
-- with 'tryPutMVar', so only the first outcome recorded is kept.
data SchedulerOutcome a
  = SchedulerFinished
    -- ^ The worker counter reached zero: every worker retired normally.
  | SchedulerTerminatedEarly ![a]
    -- ^ 'terminate' or 'terminateWith' was called; carries the result list to return.
  | SchedulerWorkerException WorkerException
    -- ^ A worker died with an exception other than 'WorkerTerminateException'.
-- | Run an action for every element of a 'Foldable', left to right, and
-- discard the results.
--
-- The actions are chained with '*>' using a strict left fold ('F.foldl''),
-- starting from @pure ()@.
traverse_ :: (Applicative f, Foldable t) => (a -> f ()) -> t a -> f ()
traverse_ f xs = F.foldl' step (pure ()) xs
  where
    step acc x = acc *> f x
-- | Apply @f@ to every element of @xs@ as a separate scheduler job, and rebuild
-- the collected results into the shape of @xs@.
--
-- The results returned by 'withScheduler' are put back into the positions of
-- @xs@ left to right by 'transList', which calls 'error' if there are fewer
-- results than elements.
-- NOTE(review): correct placement relies on 'withScheduler' returning results
-- in submission order (via @readResults@ and 'reverse'); confirm against
-- Control.Scheduler.Queue.
traverseConcurrently :: (MonadUnliftIO m, Traversable t) => Comp -> (a -> m b) -> t a -> m (t b)
traverseConcurrently comp f xs =
  (`transList` xs) <$> withScheduler comp submitAll
  where
    submitAll scheduler = traverse_ (scheduleWork scheduler . f) xs
-- | Fill the positions of a traversable structure, left to right, with the
-- values of a list. The original contents of the structure are ignored.
--
-- Surplus list elements are dropped. A list that is too short is treated as
-- an internal error and reported with 'error'.
transList :: Traversable t => [a] -> t b -> t a
transList values shape = snd (mapAccumL fill values shape)
  where
    fill [] _ = error "Impossible<traverseConcurrently> - Mismatched sizes"
    fill (v : rest) _ = (rest, v)
-- | Apply @f@ to every element of @xs@ as separate scheduler jobs and discard
-- the results.
--
-- The loop that submits the per-element jobs is itself submitted as a job,
-- so it is run by the scheduler rather than directly by the caller.
traverseConcurrently_ :: (MonadUnliftIO m, Foldable t) => Comp -> (a -> m b) -> t a -> m ()
traverseConcurrently_ comp f xs = withScheduler_ comp submitAll
  where
    submitAll scheduler =
      scheduleWork scheduler (F.traverse_ (\x -> scheduleWork scheduler (void (f x))) xs)
-- | Schedule @n@ copies of the action @f@ and return every result, as
-- collected by 'withScheduler'. A non-positive @n@ schedules nothing.
replicateConcurrently :: MonadUnliftIO m => Comp -> Int -> m a -> m [a]
replicateConcurrently comp n f = withScheduler comp submitAll
  where
    submitAll scheduler = replicateM_ n (scheduleWork scheduler f)
-- | Schedule @n@ copies of the action @f@ and discard the results.
--
-- The loop that submits the @n@ jobs is itself submitted as a job.
replicateConcurrently_ :: MonadUnliftIO m => Comp -> Int -> m a -> m ()
replicateConcurrently_ comp n f = withScheduler_ comp submitAll
  where
    submitAll scheduler =
      scheduleWork scheduler (replicateM_ n (scheduleWork scheduler (void f)))
-- | Submit a job whose result is kept. The job is wrapped with @mkJob@ from the
-- queue module (presumably so the result can be collected later; see
-- Control.Scheduler.Queue).
scheduleJobs :: MonadIO m => Jobs m a -> m a -> m ()
scheduleJobs jobs action = scheduleJobsWith mkJob jobs action
-- | Submit a job whose result is discarded: the action is voided and wrapped
-- in 'Job_'.
scheduleJobs_ :: MonadIO m => Jobs m a -> m b -> m ()
scheduleJobs_ jobs action =
  scheduleJobsWith (\act -> pure (Job_ (void act))) jobs action
-- | Common submission path for all jobs.
--
-- Increments the outstanding-job counter, wraps @action@ so that it does the
-- following when it finishes, and then pushes the wrapped job onto the queue:
--
-- * forces its result to WHNF;
-- * decrements the counter;
-- * if the counter reached zero, pushes one retire marker per worker.
scheduleJobsWith :: MonadIO m => (m b -> m (Job m a)) -> Jobs m a -> m b -> m ()
scheduleJobsWith mkJob' jobs action = do
  -- Count the job before it becomes visible to workers, so the counter can
  -- never reach zero while this job is still waiting in the queue.
  liftIO (atomicModifyIORefCAS_ countRef (+ 1))
  job <- mkJob' (action >>= finish)
  pushJQueue queue job
  where
    countRef = jobsCountRef jobs
    queue = jobsQueue jobs
    -- Force the result, mark this job as done (retiring workers if it was the
    -- last one), then hand the result back.
    finish res = do
      res `seq` dropCounterOnZero countRef (retireWorkersN queue (jobsNumWorkers jobs))
      pure res
-- | Push @n@ 'Retire' markers onto the queue, one for each worker that should
-- stop. A non-positive @n@ pushes nothing.
retireWorkersN :: MonadIO m => JQueue m a -> Int -> m ()
retireWorkersN queue n = replicateM_ n (pushJQueue queue Retire)
-- | Atomically decrement the counter and, if the new value is exactly zero,
-- run @onZero@.
dropCounterOnZero :: MonadIO m => IORef Int -> m () -> m ()
dropCounterOnZero counterRef onZero = do
  remaining <- liftIO (atomicModifyIORefCAS counterRef decrement)
  if remaining == 0 then onZero else pure ()
  where
    -- Store and return the decremented value, both forced.
    decrement !current =
      let !next = current - 1
       in (next, next)
-- | Worker loop. It pops jobs from the queue and runs them one after another
-- until the queue yields 'Nothing' (presumably a 'Retire' marker; see
-- Control.Scheduler.Queue), then runs @onRetire@ and returns.
runWorker :: MonadIO m => JQueue m a -> m () -> m ()
runWorker jQueue onRetire = loop
  where
    loop = popJQueue jQueue >>= maybe onRetire (\job -> job >> loop)
-- | Create a scheduler, let the callback submit work to it, and wait for the
-- run to end.
--
-- Returns the results that were kept. Results are read with @readResults@ and
-- then reversed. If the run was terminated early, the list stored by the
-- termination function is returned instead.
withScheduler :: MonadUnliftIO m => Comp -> (Scheduler m a -> m b) -> m [a]
withScheduler comp onScheduler =
  withSchedulerInternal comp scheduleJobs readResults reverse onScheduler
-- | Like 'withScheduler', but discards results.
--
-- Jobs are submitted without keeping their results, collection always yields
-- an empty list, and the final list is thrown away.
withScheduler_ :: MonadUnliftIO m => Comp -> (Scheduler m a -> m b) -> m ()
withScheduler_ comp onScheduler =
  void (withSchedulerInternal comp scheduleJobs_ (\_ -> pure []) id onScheduler)
-- | Shared implementation of 'withScheduler' and 'withScheduler_'.
--
-- 1. Creates the job queue, the job and worker counters, and the outcome 'MVar'.
-- 2. Runs @onScheduler@ so it can submit work.
-- 3. Starts the workers.
-- 4. Blocks until the run finishes, is terminated early, or a worker fails.
--    A worker failure is rethrown to the caller.
withSchedulerInternal ::
     MonadUnliftIO m
  => Comp
  -- ^ Computation strategy: decides the worker count and how workers are forked.
  -> (Jobs m a -> m a -> m ())
  -- ^ Used as 'scheduleWork' (result-keeping or result-discarding submission).
  -> (JQueue m a -> m [Maybe a])
  -- ^ Reads the results currently held in the queue.
  -> ([a] -> [a])
  -- ^ Post-processing applied to the collected results before they are returned.
  -> (Scheduler m a -> m b)
  -- ^ User callback that submits the work.
  -> m [a]
withSchedulerInternal comp submitWork collect adjust onScheduler = do
  jobsNumWorkers <- getCompWorkers comp
  -- Workers that have not retired yet; the last one to retire records SchedulerFinished.
  sWorkersCounterRef <- liftIO $ newIORef jobsNumWorkers
  jobsQueue <- newJQueue
  -- Submitted jobs that have not finished yet.
  jobsCountRef <- liftIO $ newIORef 0
  -- Holds the single SchedulerOutcome. Every writer uses tryPutMVar, so the first one wins.
  workDoneMVar <- liftIO newEmptyMVar
  let jobs = Jobs {..}
      scheduler =
        Scheduler
          { numWorkers = jobsNumWorkers
          , scheduleWork = submitWork jobs
          -- Early stop that keeps @a@ plus the results collected so far.
          , terminate =
              \a -> do
                mas <- collect jobsQueue
                let as = adjust (a : catMaybes mas)
                liftIO $ void $ tryPutMVar workDoneMVar $ SchedulerTerminatedEarly as
                pure a
          -- Early stop whose result is just @[a]@.
          , terminateWith =
              \a -> do
                liftIO $ void $ tryPutMVar workDoneMVar $ SchedulerTerminatedEarly [a]
                pure a
          }
      -- Run by each worker when it retires; the last one records a normal finish.
      onRetire =
        dropCounterOnZero sWorkersCounterRef $
        void $ liftIO (tryPutMVar workDoneMVar SchedulerFinished)
  -- The callback runs before any worker thread is forked; jobs only get queued here.
  _ <- onScheduler scheduler
  jc <- liftIO $ readIORef jobsCountRef
  -- Nothing was scheduled: submit a no-op job. When it completes, the job count
  -- drops to zero and the retire markers the workers need in order to stop are pushed.
  when (jc == 0) $ scheduleJobs_ jobs (pure ())
  -- spawnWorkers is run as the "before" action of safeBracketOnError, i.e. under
  -- 'mask'. Each worker therefore starts masked and only unmasks its loop once
  -- the 'catch' handler is in place.
  let spawnWorkersWith fork ws =
        withRunInIO $ \run ->
          forM ws $ \w ->
            fork w $ \unmask ->
              catch
                (unmask $ run $ runWorker jobsQueue onRetire)
                (run . handleWorkerException jobsQueue workDoneMVar jobsNumWorkers)
      spawnWorkers =
        case comp of
          -- No threads: the calling thread runs the worker loop in doWork.
          Seq -> return []
          -- One thread per capability number 1..n, pinned with forkOn.
          Par -> spawnWorkersWith forkOnWithUnmask [1 .. jobsNumWorkers]
          -- One thread per listed capability.
          ParOn ws -> spawnWorkersWith forkOnWithUnmask ws
          -- n unpinned threads.
          ParN _ -> spawnWorkersWith (\_ -> forkIOWithUnmask) [1 .. jobsNumWorkers]
      -- Stop workers by throwing an asynchronous WorkerTerminateException to each;
      -- handleWorkerException treats that exception as a silent exit.
      terminateWorkers = liftIO . traverse_ (`throwTo` SomeAsyncException WorkerTerminateException)
      doWork tids = do
        -- NOTE(review): in Seq mode the outcome is read only after this loop has
        -- retired, so terminate/terminateWith do not appear to stop a Seq run
        -- early. Confirm this is intended.
        when (comp == Seq) $ runWorker jobsQueue onRetire
        mExc <- liftIO $ readMVar workDoneMVar
        case mExc of
          SchedulerFinished -> adjust . catMaybes <$> collect jobsQueue
          SchedulerTerminatedEarly as -> terminateWorkers tids >> pure as
          -- Rethrow the worker's original exception. safeBracketOnError then
          -- terminates the remaining workers before propagating it.
          SchedulerWorkerException (WorkerException exc) -> liftIO $ throwIO exc
  safeBracketOnError spawnWorkers terminateWorkers doWork
-- | Exception handler installed around every forked worker.
--
-- A 'WorkerTerminateException' means the scheduler is shutting the workers
-- down, so the worker just exits. For any other exception the handler:
--
-- * records a 'SchedulerWorkerException' (unless an outcome is already set);
-- * pushes retire markers for the remaining @nWorkers - 1@ workers.
handleWorkerException ::
     MonadIO m => JQueue m a -> MVar (SchedulerOutcome a) -> Int -> SomeException -> m ()
handleWorkerException jQueue workDoneMVar nWorkers exc
  | Just WorkerTerminateException <- asyncExceptionFromException exc = pure ()
  | otherwise = do
      _ <- liftIO (tryPutMVar workDoneMVar (SchedulerWorkerException (WorkerException exc)))
      retireWorkersN jQueue (nWorkers - 1)
-- | Carries an exception raised inside a worker thread to the thread waiting
-- on the scheduler, through the outcome 'MVar'. That thread unwraps it and
-- rethrows the original 'SomeException', so callers never see this wrapper.
newtype WorkerException =
  WorkerException SomeException
  deriving (Show)

instance Exception WorkerException
-- | Thrown asynchronously to every worker thread when the scheduler shuts its
-- workers down: after an early termination, or when the scheduled work fails.
--
-- It is declared as an asynchronous exception:
--
-- * 'toException' wraps it in 'SomeAsyncException';
-- * 'fromException' unwraps it from there.
--
-- With this instance, matching on this type (with @catch@ or @fromException@)
-- recognises the @SomeAsyncException WorkerTerminateException@ value the
-- scheduler actually throws. A bare @throwTo tid WorkerTerminateException@ is
-- also classified as asynchronous by async-aware handlers, instead of being
-- mistaken for a synchronous exception.
data WorkerTerminateException =
  WorkerTerminateException
  deriving (Show)

instance Exception WorkerTerminateException where
  toException = asyncExceptionToException
  fromException = asyncExceptionFromException
-- | Acquire with @before@ and run @thing@ with asynchronous exceptions restored.
--
-- If @thing@ throws, @after@ is run with asynchronous exceptions uninterruptibly
-- masked, and the original exception is rethrown. Any exception raised by
-- @after@ is ignored, so it can never replace the original failure. On success
-- @after@ is NOT run.
safeBracketOnError :: MonadUnliftIO m => m a -> (a -> m b) -> (a -> m c) -> m c
safeBracketOnError before after thing =
  withRunInIO $ \runInIO ->
    mask $ \restore -> do
      acquired <- runInIO before
      outcome <- try (restore (runInIO (thing acquired)))
      case outcome of
        Right result -> pure result
        Left (failure :: SomeException) -> do
          (uninterruptibleMask_ (runInIO (after acquired)) >> pure ())
            `catch` \(_ :: SomeException) -> pure ()
          throwIO failure