{-# OPTIONS_HADDOCK not-home #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Safe #-}
module BroadcastChan.Extra
( Action(..)
, BracketOnError(..)
, Handler(..)
, mapHandler
, runParallel
, runParallel_
) where
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<*))
#endif
import Control.Concurrent (ThreadId, forkFinally, mkWeakThreadId, myThreadId)
import Control.Concurrent.MVar
import Control.Concurrent.QSem
import Control.Concurrent.QSemN
import Control.Exception (Exception(..), SomeException(..))
import qualified Control.Exception as Exc
import Control.Monad ((>=>), replicateM, void)
import Control.Monad.IO.Unlift (MonadIO(..))
import Data.Typeable (Typeable)
import System.Mem.Weak (Weak, deRefWeak)
import BroadcastChan.Internal
-- | Write a value to a 'BroadcastChan' even if it has already been closed.
--
-- If the current write hole already holds 'Closed', that marker is moved into
-- a freshly allocated hole and the value is stored in the old hole, directly
-- in front of it. The value therefore ends up in the channel ahead of the
-- close marker. 'parallelCore' uses this to re-queue elements whose 'Handler'
-- chose 'Retry', since the input channel may already be closed by then.
--
-- NOTE(review): "unsafe" presumably because this bypasses the closed check
-- performed by the regular write; confirm against "BroadcastChan.Internal".
unsafeWriteBChan :: MonadIO m => BroadcastChan In a -> a -> m ()
unsafeWriteBChan (BChan writeVar) val = liftIO $ do
    new_hole <- newEmptyMVar
    -- Masked so no asynchronous exception can arrive between taking writeVar
    -- and putting the new write end back.
    Exc.mask_ $ do
        old_hole <- takeMVar writeVar
        -- The write hole is expected to be empty (open channel) or to hold
        -- 'Closed'; any other content is treated as impossible.
        item <- tryTakeMVar old_hole
        case item of
            Nothing -> return ()
            Just Closed -> putMVar new_hole Closed
            Just _ -> error "unsafeWriteBChan hit an impossible condition!"
        putMVar old_hole (ChItem val new_hole)
        putMVar writeVar new_hole
{-# INLINE unsafeWriteBChan #-}
-- | Exception thrown into worker threads to make them stop.
--
-- Worker exception handlers rethrow it instead of passing it to the user's
-- 'Handler', and a worker that dies from it does not forward it to the thread
-- that started the workers.
data Shutdown = Shutdown
    deriving (Show, Typeable)

instance Exception Shutdown
-- | What to do with an element whose processing threw an exception.
data Action
    = Drop
    -- ^ Skip the element and continue with the remaining input.
    | Retry
    -- ^ Put the element back into the input queue so it is processed again.
    | Terminate
    -- ^ Rethrow the exception; it is propagated to the thread that started
    -- the workers.
    deriving (Show, Eq)
-- | How a parallel operation responds when processing an element throws.
data Handler m a
    = Simple Action
    -- ^ Respond to every failure with the same 'Action'.
    | Handle (a -> SomeException -> m Action)
    -- ^ Choose an 'Action' from the failing element and the exception it
    -- raised.
-- | The three phases of a parallel operation, packaged for bracket-style use.
--
-- NOTE(review): presumably run as @bracketOnError allocate cleanup (const
-- action)@ by the caller; the consumer is not part of this module.
data BracketOnError m r = Bracket
    { allocate :: IO [Weak ThreadId]
      -- ^ Fork the worker threads and return weak references to them.
    , cleanup :: [Weak ThreadId] -> IO ()
      -- ^ Stop the given worker threads.
    , action :: m r
      -- ^ Feed input to the workers and produce the final result.
    }
-- | Move a 'Handler' to a different monad.
--
-- The supplied function converts the 'Action'-producing computation of a
-- 'Handle' handler. 'Simple' handlers run no computation and are carried over
-- unchanged.
mapHandler :: (m Action -> n Action) -> Handler m a -> Handler n a
mapHandler mmorph hndl = case hndl of
    Simple act -> Simple act
    Handle decide -> Handle (\val -> mmorph . decide val)
-- | Shared worker machinery behind 'runParallel' and 'runParallel_'.
--
-- Creates an input channel consumed by @threads@ worker threads. Each worker
-- repeatedly reads an element and runs @f@ on it until the channel is closed
-- and drained. Exceptions from @f@ are dispatched through the 'Handler'.
--
-- The returned tuple contains, in order:
--
-- * an action that forks the workers and returns weak references to them;
-- * a cleanup action that throws 'Shutdown' to every still-reachable worker,
--   then waits until all @threads@ workers have exited;
-- * a function that buffers one element into the input channel;
-- * a wait action that closes the input channel, then blocks until all
--   @threads@ workers have seen the channel closed.
parallelCore
    :: forall a m
     . MonadIO m
    => Handler IO a
    -- ^ reaction to exceptions thrown by @f@
    -> Int
    -- ^ number of worker threads
    -> IO ()
    -- ^ run whenever the handler decides to 'Drop' an element
    -> (a -> IO ())
    -- ^ work performed on each element
    -> m (IO [Weak ThreadId], [Weak ThreadId] -> IO (), a -> IO (), m ())
parallelCore hndl threads onDrop f = liftIO $ do
    -- Unexpected worker exceptions are rethrown to this thread.
    originTid <- myThreadId
    inChanIn <- newBroadcastChan
    inChanOut <- newBChanListener inChanIn
    -- Signalled once by each worker when it exits, for any reason.
    shutdownSem <- newQSemN 0
    -- Signalled once by each worker that reads a closed input channel.
    endSem <- newQSemN 0
    -- Queue one element for the workers; writeBChan's result is ignored.
    let bufferValue :: a -> IO ()
        bufferValue = void . writeBChan inChanIn
        -- Carry out the 'Action' chosen for a failed element.
        simpleHandler :: a -> SomeException -> Action -> IO ()
        simpleHandler val exc act = case act of
            Drop -> onDrop
            -- The input may already have been closed by wait, hence the
            -- write that ignores the closed state.
            Retry -> unsafeWriteBChan inChanIn val
            Terminate -> Exc.throwIO exc
        -- 'Shutdown' is rethrown immediately so the user's 'Handler' never
        -- sees it and cannot keep a worker alive during cleanup.
        handler :: a -> SomeException -> IO ()
        handler _ exc | Just Shutdown <- fromException exc = Exc.throwIO exc
        handler val exc = case hndl of
            Simple a -> simpleHandler val exc a
            Handle h -> h val exc >>= simpleHandler val exc
        -- Worker loop: process elements until the channel reports closed,
        -- then signal endSem once and return.
        processInput :: IO ()
        processInput = do
            x <- readBChan inChanOut
            case x of
                Nothing -> signalQSemN endSem 1
                Just a -> do
                    f a `Exc.catch` handler a
                    processInput
        -- NOTE(review): forkFinally/forkIO threads inherit the caller's
        -- masking state, so if this runs inside a masked region (e.g. as a
        -- bracket acquire) the workers run masked as well; confirm intended.
        allocate :: IO [Weak ThreadId]
        allocate = liftIO $ do
            tids <- replicateM threads . forkFinally processInput $ \exit -> do
                -- Signalled first, before the potentially blocking throwTo
                -- below, so cleanup's wait counts this worker as exited.
                signalQSemN shutdownSem 1
                case exit of
                    Left exc
                      | Just Shutdown <- fromException exc -> return ()
                      | otherwise ->
                          -- throwTo is interruptible while blocked, so this
                          -- worker may receive 'Shutdown' from cleanup here;
                          -- that is deliberately ignored.
                          Exc.throwTo originTid exc `Exc.catch` shutdownHandler
                    Right () -> return ()
            -- Weak references, presumably so holding the list does not keep
            -- finished threads reachable.
            mapM mkWeakThreadId tids
          where
            shutdownHandler Shutdown = return ()
        -- Kill all workers and wait for every one of them to have exited;
        -- uninterruptibly masked so that wait cannot be cut short.
        cleanup :: [Weak ThreadId] -> IO ()
        cleanup threadIds = liftIO . Exc.uninterruptibleMask_ $ do
            mapM_ killWeakThread threadIds
            waitQSemN shutdownSem threads
        -- Close the input so workers stop once it is drained, then wait for
        -- all of them to report completion.
        wait :: m ()
        wait = do
            closeBChan inChanIn
            liftIO $ waitQSemN endSem threads
    return (allocate, cleanup, bufferValue, wait)
  where
    -- Throw 'Shutdown' to a worker unless its weak reference is already dead.
    killWeakThread :: Weak ThreadId -> IO ()
    killWeakThread wTid = do
        tid <- deRefWeak wTid
        case tid of
            Nothing -> return ()
            Just t -> Exc.throwTo t Shutdown
-- | Build a parallel pipeline whose results are handed back to the caller.
--
-- The returned 'BracketOnError' combines worker allocation and cleanup with
-- an 'action'. That action runs @pipe@, closes the output once all workers
-- have finished, and folds every result still queued into the value @pipe@
-- produced.
--
-- @pipe@ receives two functions:
--
-- * the first only buffers an element for processing;
-- * the second waits for the next available output, then buffers the given
--   element, and returns that output ('Nothing' when the 'Handler' dropped
--   the corresponding element).
runParallel
    :: forall a b m n r
     . (MonadIO m, MonadIO n)
    => Either (b -> n r) (r -> b -> n r)
    -- ^ how leftover results are folded after @pipe@ returns: 'Left' ignores
    -- the accumulator, 'Right' combines it with each result
    -> Handler IO a
    -- ^ reaction to exceptions thrown by @work@
    -> Int
    -- ^ number of worker threads
    -> (a -> IO b)
    -- ^ work performed on each element
    -> ((a -> m ()) -> (a -> m (Maybe b)) -> n r)
    -- ^ consumer driving the pipeline
    -> n (BracketOnError n r)
runParallel yielder hndl threads work pipe = do
    outChanIn <- newBroadcastChan
    outChanOut <- newBChanListener outChanIn

    -- Each processed element produces exactly one output message: the result
    -- wrapped in 'Just' on success, or 'Nothing' when the 'Handler' drops it.
    let runWork :: a -> IO ()
        runWork input = do
            output <- work input
            void (writeBChan outChanIn (Just output))

        announceDrop :: IO ()
        announceDrop = void (writeBChan outChanIn Nothing)

    (startWorkers, stopWorkers, bufferValue, awaitWorkers) <-
        parallelCore hndl threads announceDrop runWork

    -- Read one output first and only then buffer the new element, so each
    -- call trades one finished element for one new one.
    let queueAndYield :: a -> m (Maybe b)
        queueAndYield x = liftIO $ do
            ~(Just v) <- readBChan outChanOut <* bufferValue x
            return v

        -- Fold all outputs still queued, skipping dropped elements, until
        -- the output channel is closed and empty.
        drain :: r -> n r
        drain acc = readBChan outChanOut >>= maybe (return acc) (absorb acc)

        absorb :: r -> Maybe b -> n r
        absorb acc = maybe (drain acc) (foldFun acc >=> drain)

        runPipeline :: n r
        runPipeline = do
            partial <- pipe (liftIO . bufferValue) queueAndYield
            awaitWorkers
            _ <- closeBChan outChanIn
            drain partial

    return Bracket
        { allocate = startWorkers
        , cleanup = stopWorkers
        , action = runPipeline
        }
  where
    foldFun :: r -> b -> n r
    foldFun = either const id yielder
-- | Build a parallel pipeline whose per-element results are discarded.
--
-- Back-pressure comes from a 'QSem' created with @threads@ units. The
-- function handed to @processElems@ waits on it before buffering an element,
-- and a worker signals it just before running @workFun@ on an element. As a
-- result, at most @threads@ elements wait in the buffer.
--
-- NOTE(review): an element re-queued because the 'Handler' chose 'Retry' is
-- buffered without waiting on the semaphore. It signals the semaphore again
-- when reprocessed, so each retry loosens that bound by one.
runParallel_
    :: (MonadIO m, MonadIO n)
    => Handler IO a
    -- ^ reaction to exceptions thrown by @workFun@
    -> Int
    -- ^ number of worker threads
    -> (a -> IO ())
    -- ^ work performed on each element
    -> ((a -> m ()) -> n r)
    -- ^ consumer that feeds elements in through the supplied function
    -> n (BracketOnError n r)
runParallel_ hndl threads workFun processElems = do
    slots <- liftIO (newQSem threads)

    (startWorkers, stopWorkers, bufferValue, awaitWorkers) <-
        parallelCore hndl threads (return ()) $ \x -> do
            signalQSem slots
            workFun x

    let feed v = liftIO $ do
            waitQSem slots
            bufferValue v

        runPipeline = processElems feed >>= \result ->
            awaitWorkers >> return result

    return Bracket
        { allocate = startWorkers
        , cleanup = stopWorkers
        , action = runPipeline
        }