{-# LANGUAGE CPP #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
module UnliftIO.Internals.Async where
import Control.Applicative
import Control.Concurrent (threadDelay, getNumCapabilities)
import qualified Control.Concurrent as C
import Control.Concurrent.Async (Async)
import qualified Control.Concurrent.Async as A
import Control.Concurrent.STM
import Control.Exception (Exception, SomeException)
import Control.Monad (forever, liftM, unless, void, (>=>))
import Control.Monad.IO.Unlift
import Data.Foldable (for_, traverse_)
import Data.Typeable (Typeable)
import Data.IORef (IORef, readIORef, atomicWriteIORef, newIORef, atomicModifyIORef')
import qualified UnliftIO.Exception as UE
import qualified Control.Exception as E
import GHC.Generics (Generic)
#if MIN_VERSION_base(4,9,0)
import Data.Semigroup
#else
import Data.Monoid hiding (Alt)
#endif
import Data.Foldable (Foldable, toList)
import Data.Traversable (Traversable, for, traverse)
-- | Unlifted 'A.async': start the action in a new thread and return its
-- 'Async' handle.
async :: MonadUnliftIO m => m a -> m (Async a)
async action = withRunInIO (\runInIO -> A.async (runInIO action))

-- | Unlifted 'A.asyncBound': like 'async', but the thread is started with
-- 'A.asyncBound'.
asyncBound :: MonadUnliftIO m => m a -> m (Async a)
asyncBound action = withRunInIO (\runInIO -> A.asyncBound (runInIO action))

-- | Unlifted 'A.asyncOn': like 'async', but the thread is started with
-- 'A.asyncOn' using the given capability number.
asyncOn :: MonadUnliftIO m => Int -> m a -> m (Async a)
asyncOn cap action = withRunInIO (\runInIO -> A.asyncOn cap (runInIO action))

-- | Unlifted 'A.asyncWithUnmask'. The supplied function receives an
-- unmasking function for @m@, built from the @IO@ one that
-- 'A.asyncWithUnmask' hands to the new thread.
asyncWithUnmask :: MonadUnliftIO m => ((forall b. m b -> m b) -> m a) -> m (Async a)
asyncWithUnmask body =
  withRunInIO (\runInIO ->
    A.asyncWithUnmask (\unmaskIO ->
      runInIO (body (\inner -> liftIO (unmaskIO (runInIO inner))))))

-- | Unlifted 'A.asyncOnWithUnmask': 'asyncWithUnmask' on the given
-- capability number.
asyncOnWithUnmask :: MonadUnliftIO m => Int -> ((forall b. m b -> m b) -> m a) -> m (Async a)
asyncOnWithUnmask cap body =
  withRunInIO (\runInIO ->
    A.asyncOnWithUnmask cap (\unmaskIO ->
      runInIO (body (\inner -> liftIO (unmaskIO (runInIO inner))))))
-- | Unlifted 'A.withAsync': run the first action in a new thread, pass its
-- handle to the continuation, and let 'A.withAsync' manage the thread's
-- lifetime.
withAsync :: MonadUnliftIO m => m a -> (Async a -> m b) -> m b
withAsync action inner =
  withRunInIO (\runInIO ->
    A.withAsync (runInIO action) (\h -> runInIO (inner h)))

-- | Unlifted 'A.withAsyncBound': 'withAsync' via 'A.withAsyncBound'.
withAsyncBound :: MonadUnliftIO m => m a -> (Async a -> m b) -> m b
withAsyncBound action inner =
  withRunInIO (\runInIO ->
    A.withAsyncBound (runInIO action) (\h -> runInIO (inner h)))

-- | Unlifted 'A.withAsyncOn': 'withAsync' on the given capability number.
withAsyncOn :: MonadUnliftIO m => Int -> m a -> (Async a -> m b) -> m b
withAsyncOn cap action inner =
  withRunInIO (\runInIO ->
    A.withAsyncOn cap (runInIO action) (\h -> runInIO (inner h)))

-- | Unlifted 'A.withAsyncWithUnmask'. The forked action receives an
-- unmasking function for @m@ built from the @IO@ one.
withAsyncWithUnmask
  :: MonadUnliftIO m
  => ((forall c. m c -> m c) -> m a)
  -> (Async a -> m b)
  -> m b
withAsyncWithUnmask body inner =
  withRunInIO (\runInIO ->
    A.withAsyncWithUnmask
      (\unmaskIO -> runInIO (body (\act -> liftIO (unmaskIO (runInIO act)))))
      (\h -> runInIO (inner h)))

-- | Unlifted 'A.withAsyncOnWithUnmask': 'withAsyncWithUnmask' on the given
-- capability number.
withAsyncOnWithUnmask
  :: MonadUnliftIO m
  => Int
  -> ((forall c. m c -> m c) -> m a)
  -> (Async a -> m b)
  -> m b
withAsyncOnWithUnmask cap body inner =
  withRunInIO (\runInIO ->
    A.withAsyncOnWithUnmask cap
      (\unmaskIO -> runInIO (body (\act -> liftIO (unmaskIO (runInIO act)))))
      (\h -> runInIO (inner h)))
-- | Lifted 'A.wait'.
wait :: MonadIO m => Async a -> m a
wait handle' = liftIO (A.wait handle')

-- | Lifted 'A.poll'.
poll :: MonadIO m => Async a -> m (Maybe (Either SomeException a))
poll handle' = liftIO (A.poll handle')

-- | Lifted 'A.waitCatch'.
waitCatch :: MonadIO m => Async a -> m (Either SomeException a)
waitCatch handle' = liftIO (A.waitCatch handle')

-- | Lifted 'A.cancel'.
cancel :: MonadIO m => Async a -> m ()
cancel handle' = liftIO (A.cancel handle')

-- | Lifted 'A.uninterruptibleCancel'.
uninterruptibleCancel :: MonadIO m => Async a -> m ()
uninterruptibleCancel handle' = liftIO (A.uninterruptibleCancel handle')

-- | Lifted 'A.cancelWith'. The exception is first wrapped with
-- 'UE.toAsyncException' before being thrown to the thread.
cancelWith :: (Exception e, MonadIO m) => Async a -> e -> m ()
cancelWith handle' exc = liftIO $ A.cancelWith handle' $ UE.toAsyncException exc

-- | Lifted 'A.waitAny'.
waitAny :: MonadIO m => [Async a] -> m (Async a, a)
waitAny handles = liftIO (A.waitAny handles)

-- | Lifted 'A.waitAnyCatch'.
waitAnyCatch :: MonadIO m => [Async a] -> m (Async a, Either SomeException a)
waitAnyCatch handles = liftIO (A.waitAnyCatch handles)

-- | Lifted 'A.waitAnyCancel'.
waitAnyCancel :: MonadIO m => [Async a] -> m (Async a, a)
waitAnyCancel handles = liftIO (A.waitAnyCancel handles)

-- | Lifted 'A.waitAnyCatchCancel'.
waitAnyCatchCancel :: MonadIO m => [Async a] -> m (Async a, Either SomeException a)
waitAnyCatchCancel handles = liftIO (A.waitAnyCatchCancel handles)

-- | Lifted 'A.waitEither'.
waitEither :: MonadIO m => Async a -> Async b -> m (Either a b)
waitEither left right = liftIO $ A.waitEither left right

-- | Lifted 'A.waitEitherCatch'.
waitEitherCatch :: MonadIO m => Async a -> Async b -> m (Either (Either SomeException a) (Either SomeException b))
waitEitherCatch left right = liftIO $ A.waitEitherCatch left right

-- | Lifted 'A.waitEitherCancel'.
waitEitherCancel :: MonadIO m => Async a -> Async b -> m (Either a b)
waitEitherCancel left right = liftIO $ A.waitEitherCancel left right

-- | Lifted 'A.waitEitherCatchCancel'.
waitEitherCatchCancel :: MonadIO m => Async a -> Async b -> m (Either (Either SomeException a) (Either SomeException b))
waitEitherCatchCancel left right = liftIO $ A.waitEitherCatchCancel left right

-- | Lifted 'A.waitEither_'.
waitEither_ :: MonadIO m => Async a -> Async b -> m ()
waitEither_ left right = liftIO $ A.waitEither_ left right

-- | Lifted 'A.waitBoth'.
waitBoth :: MonadIO m => Async a -> Async b -> m (a, b)
waitBoth left right = liftIO $ A.waitBoth left right

-- | Lifted 'A.link'.
link :: MonadIO m => Async a -> m ()
link handle' = liftIO (A.link handle')

-- | Lifted 'A.link2'.
link2 :: MonadIO m => Async a -> Async b -> m ()
link2 left right = liftIO $ A.link2 left right
-- | Unlifted 'A.race': run both actions concurrently and return the result
-- of whichever 'A.race' reports first.
race :: MonadUnliftIO m => m a -> m b -> m (Either a b)
race left right = withRunInIO (\runInIO -> A.race (runInIO left) (runInIO right))

-- | Unlifted 'A.race_': like 'race', discarding the result.
race_ :: MonadUnliftIO m => m a -> m b -> m ()
race_ left right = withRunInIO (\runInIO -> A.race_ (runInIO left) (runInIO right))

-- | Unlifted 'A.concurrently': run both actions concurrently and return
-- both results.
concurrently :: MonadUnliftIO m => m a -> m b -> m (a, b)
concurrently left right =
  withRunInIO (\runInIO -> A.concurrently (runInIO left) (runInIO right))

-- | Unlifted 'A.concurrently_': like 'concurrently', discarding the results.
concurrently_ :: MonadUnliftIO m => m a -> m b -> m ()
concurrently_ left right =
  withRunInIO (\runInIO -> A.concurrently_ (runInIO left) (runInIO right))
-- | Wrapper whose 'Applicative' combines actions with 'concurrently' and whose
-- 'Alternative' combines them with 'race'.
newtype Concurrently m a = Concurrently
  { runConcurrently :: m a
  }

-- Uses 'liftM' so only 'Monad' is required of @m@.
instance Monad m => Functor (Concurrently m) where
  fmap f (Concurrently act) = Concurrently (liftM f act)

instance MonadUnliftIO m => Applicative (Concurrently m) where
  pure x = Concurrently (return x)
  -- Both sides run at the same time; the function is applied to the value
  -- once 'concurrently' returns the pair.
  Concurrently mf <*> Concurrently mx =
    Concurrently (liftM (\(f, x) -> f x) (concurrently mf mx))

instance MonadUnliftIO m => Alternative (Concurrently m) where
  -- Never returns: it waits on 'threadDelay' forever.
  empty = Concurrently (liftIO (forever (threadDelay maxBound)))
  -- Keep the result of whichever side 'race' reports.
  Concurrently left <|> Concurrently right =
    Concurrently (liftM (either id id) (race left right))

#if MIN_VERSION_base(4,9,0)
instance (MonadUnliftIO m, Semigroup a) => Semigroup (Concurrently m a) where
  x <> y = (<>) <$> x <*> y

instance (Semigroup a, Monoid a, MonadUnliftIO m) => Monoid (Concurrently m a) where
  mempty = pure mempty
  mappend = (<>)
#else
instance (Monoid a, MonadUnliftIO m) => Monoid (Concurrently m a) where
  mempty = pure mempty
  mappend x y = mappend <$> x <*> y
#endif
-- | 'mapConcurrently' with its arguments swapped.
forConcurrently :: MonadUnliftIO m => Traversable t => t a -> (a -> m b) -> m (t b)
forConcurrently xs f = mapConcurrently f xs
{-# INLINE forConcurrently #-}

-- | 'mapConcurrently_' with its arguments swapped.
forConcurrently_ :: MonadUnliftIO m => Foldable f => f a -> (a -> m b) -> m ()
forConcurrently_ xs f = mapConcurrently_ f xs
{-# INLINE forConcurrently_ #-}
-- | Run the action @cnt@ times concurrently and collect the results in order.
--
-- * A count below 1 returns @[]@ without running the action.
-- * A count of exactly 1 runs the action directly in the calling thread.
-- * Larger counts go through 'mapConcurrently'.
--
-- The body uses only 'Monad' operations ('return', 'liftM'), so this single
-- signature type-checks on every supported @base@. Previously the binding had
-- no signature at all on base >= 4.7, and a @pure@-using body under a
-- 'Functor'-only signature on older base.
replicateConcurrently :: MonadUnliftIO m => Int -> m a -> m [a]
replicateConcurrently cnt m =
  case compare cnt 1 of
    LT -> return []
    EQ -> liftM (: []) m
    GT -> mapConcurrently id (replicate cnt m)
{-# INLINE replicateConcurrently #-}
-- | Run the action @cnt@ times concurrently, discarding the results.
--
-- * A count below 1 does nothing.
-- * A count of exactly 1 runs the action directly in the calling thread.
-- * Larger counts go through 'mapConcurrently_'.
--
-- Only 'Monad' operations are used, so one signature works on every supported
-- @base@. The old pre-4.7 branch declared @MonadUnliftIO m@ but used
-- @pure@/@void@, which that constraint does not provide before the AMP.
replicateConcurrently_ :: MonadUnliftIO m => Int -> m a -> m ()
replicateConcurrently_ cnt m =
  case compare cnt 1 of
    LT -> return ()
    EQ -> m >> return ()
    GT -> mapConcurrently_ id (replicate cnt m)
{-# INLINE replicateConcurrently_ #-}
#if MIN_VERSION_base(4,8,0)
-- | Apply the function to every element concurrently.
--
-- Each element's action is wrapped as a 'FlatAction', the actions are
-- combined with 'traverse' in the 'Flat' applicative, and the resulting tree
-- is executed by 'runFlat'.
mapConcurrently :: MonadUnliftIO m => Traversable t => (a -> m b) -> t a -> m (t b)
mapConcurrently f t = withRunInIO $ \runInIO ->
  runFlat (traverse (\x -> FlatApp (FlatAction (runInIO (f x)))) t)
{-# INLINE mapConcurrently #-}

-- | Like 'mapConcurrently', discarding the results (uses 'traverse_').
mapConcurrently_ :: MonadUnliftIO m => Foldable f => (a -> m b) -> f a -> m ()
mapConcurrently_ f t = withRunInIO $ \runInIO ->
  runFlat (traverse_ (\x -> FlatApp (FlatAction (runInIO (f x)))) t)
{-# INLINE mapConcurrently_ #-}
-- | A description of concurrent work.
--
-- Values are built from actions ('conc'), the 'Applicative'/'Alternative'
-- combinators, and are executed by 'runConc'.
data Conc m a where
  -- A single action to run.
  Action :: m a -> Conc m a
  -- Apply the result of the first description to the result of the second.
  Apply :: Conc m (x -> a) -> Conc m x -> Conc m a
  -- Combine two results with a binary function.
  LiftA2 :: (x -> y -> a) -> Conc m x -> Conc m y -> Conc m a
  -- A value that needs no work.
  Pure :: a -> Conc m a
  -- A choice between two descriptions.
  Alt :: Conc m a -> Conc m a -> Conc m a
  -- No alternative at all.
  Empty :: Conc m a

deriving instance Functor m => Functor (Conc m)
-- | Wrap a single action as a 'Conc' value ('Action' node).
conc :: m a -> Conc m a
conc action = Action action
{-# INLINE conc #-}
-- | Execute a 'Conc' description.
--
-- The tree is first compiled with 'flatten', then executed in 'IO' by
-- 'runFlat'.
runConc :: MonadUnliftIO m => Conc m a -> m a
runConc c = flatten c >>= liftIO . runFlat
{-# INLINE runConc #-}
-- Every combinator just builds a node of the 'Conc' tree; nothing runs until
-- 'runConc'.
instance MonadUnliftIO m => Applicative (Conc m) where
  pure x = Pure x
  {-# INLINE pure #-}
  cf <*> cx = Apply cf cx
  {-# INLINE (<*>) #-}
#if MIN_VERSION_base(4,11,0)
  liftA2 g cx cy = LiftA2 g cx cy
  {-# INLINE liftA2 #-}
#endif
  -- Keeps only the right-hand result, but still builds a node over both sides.
  cx *> cy = LiftA2 (const id) cx cy
  {-# INLINE (*>) #-}

instance MonadUnliftIO m => Alternative (Conc m) where
  empty = Empty
  {-# INLINE empty #-}
  cl <|> cr = Alt cl cr
  {-# INLINE (<|>) #-}

#if MIN_VERSION_base(4, 11, 0)
instance (MonadUnliftIO m, Semigroup a) => Semigroup (Conc m a) where
  cl <> cr = liftA2 (<>) cl cr
  {-# INLINE (<>) #-}
#endif

instance (Monoid a, MonadUnliftIO m) => Monoid (Conc m a) where
  mempty = Pure mempty
  {-# INLINE mempty #-}
  mappend cl cr = liftA2 mappend cl cr
  {-# INLINE mappend #-}
-- | A 'Conc' tree after 'flatten': every action is already an 'IO' action.
--
-- 'FlatAlt' holds a choice. Its two mandatory fields plus the trailing list
-- mean it always has at least two alternatives.
data Flat a
  = FlatApp !(FlatApp a)
  | FlatAlt !(FlatApp a) !(FlatApp a) ![FlatApp a]

deriving instance Functor Flat

instance Applicative Flat where
  pure x = FlatApp (FlatPure x)
  ff <*> fx = FlatApp (FlatLiftA2 id ff fx)
#if MIN_VERSION_base(4,11,0)
  liftA2 g fx fy = FlatApp (FlatLiftA2 g fx fy)
#endif

-- | One non-choice node of a 'Flat' tree.
data FlatApp a where
  -- A value that needs no work.
  FlatPure :: a -> FlatApp a
  -- An action to run.
  FlatAction :: IO a -> FlatApp a
  -- Apply the first subtree's function to the second subtree's value.
  FlatApply :: Flat (x -> a) -> Flat x -> FlatApp a
  -- Combine two subtrees' values with a binary function.
  FlatLiftA2 :: (x -> y -> a) -> Flat x -> Flat y -> FlatApp a

deriving instance Functor FlatApp

instance Applicative FlatApp where
  pure = FlatPure
  fa <*> xa = FlatApply (FlatApp fa) (FlatApp xa)
#if MIN_VERSION_base(4,11,0)
  liftA2 g xa ya = FlatLiftA2 g (FlatApp xa) (FlatApp ya)
#endif
-- | Thrown by 'flatten' when an 'Empty' is reached where a value is required,
-- or when every branch of an 'Alt' is 'Empty'.
data ConcException = EmptyWithNoAlternative
  deriving (Eq, Ord, Show, Generic, Typeable)

instance E.Exception ConcException
-- | Difference list: a list represented as a "prepend these elements"
-- function, so concatenation is just function composition.
type DList a = [a] -> [a]

-- | Concatenate two difference lists (left elements come first).
dlistConcat :: DList a -> DList a -> DList a
dlistConcat front back = \xs -> front (back xs)
{-# INLINE dlistConcat #-}

-- | Prepend a single element.
dlistCons :: a -> DList a -> DList a
dlistCons x rest = \xs -> x : rest xs
{-# INLINE dlistCons #-}

-- | Concatenate many difference lists, preserving order.
dlistConcatAll :: [DList a] -> DList a
dlistConcatAll parts = foldr dlistConcat dlistEmpty parts
{-# INLINE dlistConcatAll #-}

-- | Materialise the difference list as an ordinary list.
dlistToList :: DList a -> [a]
dlistToList dl = dl []
{-# INLINE dlistToList #-}

-- | A difference list holding exactly one element.
dlistSingleton :: a -> DList a
dlistSingleton x = \xs -> x : xs
{-# INLINE dlistSingleton #-}

-- | The empty difference list.
dlistEmpty :: DList a
dlistEmpty = \xs -> xs
{-# INLINE dlistEmpty #-}
-- | Compile a 'Conc' tree into a 'Flat' tree. Every 'Action' is turned into
-- an @IO@ action using the unlifting function from 'withRunInIO'.
--
-- * Nested 'Alt' nodes are gathered into a single 'FlatAlt' node.
-- * An 'Empty' that is a direct 'Alt' branch is dropped.
-- * Throws 'EmptyWithNoAlternative' when an 'Empty' appears anywhere else.
-- * Also throws it when an 'Alt' ends up with no branches at all.
flatten :: forall m a. MonadUnliftIO m => Conc m a -> m (Flat a)
flatten c0 = withRunInIO $ \runInIO -> do
  let -- Flatten a node that must yield a Flat tree of its own.
      single :: forall k. Conc m k -> IO (Flat k)
      single node = case node of
        Empty -> E.throwIO EmptyWithNoAlternative
        Pure x -> pure (FlatApp (FlatPure x))
        Action act -> pure (FlatApp (FlatAction (runInIO act)))
        Apply cf cx -> do
          ff <- single cf
          fx <- single cx
          pure (FlatApp (FlatApply ff fx))
        LiftA2 g cx cy -> do
          fx <- single cx
          fy <- single cy
          pure (FlatApp (FlatLiftA2 g fx fy))
        Alt cl cr -> do
          ls <- branches cl
          rs <- branches cr
          case dlistToList (ls `dlistConcat` rs) of
            [] -> E.throwIO EmptyWithNoAlternative
            [only] -> pure (FlatApp only)
            x1 : x2 : xs -> pure (FlatAlt x1 x2 xs)

      -- Collect the alternatives a node contributes to an enclosing choice.
      -- Uses a difference list so concatenation is cheap; 'Empty'
      -- contributes nothing.
      branches :: forall k. Conc m k -> IO (DList (FlatApp k))
      branches node = case node of
        Empty -> pure dlistEmpty
        Pure x -> pure (dlistSingleton (FlatPure x))
        Action act -> pure (dlistSingleton (FlatAction (runInIO act)))
        Apply cf cx -> do
          ff <- single cf
          fx <- single cx
          pure (dlistSingleton (FlatApply ff fx))
        LiftA2 g cx cy -> do
          fx <- single cx
          fy <- single cy
          pure (dlistSingleton (FlatLiftA2 g fx fy))
        Alt cl cr -> do
          ls <- branches cl
          rs <- branches cr
          pure (ls `dlistConcat` rs)

  single c0
-- | Execute a 'Flat' tree in 'IO'.
--
-- Every 'FlatAction' runs in its own forked thread. For a 'FlatAlt', a helper
-- thread waits for the first branch to produce a result or fail, then kills
-- the branches' threads.
--
-- On exit (result, child failure, or an exception delivered to the caller)
-- any threads that have not reported completion are killed. The function then
-- waits until every forked thread has reported before returning or rethrowing.
runFlat :: Flat a -> IO a
-- Fast paths: a lone action runs directly in the calling thread, and a lone
-- pure value is returned without forking anything.
runFlat (FlatApp (FlatAction io)) = io
runFlat (FlatApp (FlatPure x)) = pure x
-- General case. Async exceptions stay (uninterruptibly) masked except inside
-- 'restore' / 'restore1', so the bookkeeping below cannot be interrupted.
runFlat f0 = E.uninterruptibleMask $ \restore -> do
  -- Number of forked threads (workers and Alt helpers) that have finished.
  -- Each one increments it once, inside an 'atomically', after its work.
  resultCountVar <- newTVarIO 0
  -- Fork the threads needed for a subtree. Returns:
  --   * an STM action that yields the subtree's result (retrying until ready)
  --   * the ids of every thread forked for it
  -- Failures are not returned; they are recorded in the given excVar.
  let go :: forall a.
            TMVar E.SomeException
         -> Flat a
         -> IO (STM a, DList C.ThreadId)
      go _excVar (FlatApp (FlatPure x)) = pure (pure x, dlistEmpty)
      go excVar (FlatApp (FlatAction io)) = do
        resVar <- newEmptyTMVarIO
        tid <- C.forkIOWithUnmask $ \restore1 -> do
          -- Only the user action runs unmasked.
          res <- E.try $ restore1 io
          atomically $ do
            modifyTVar' resultCountVar (+ 1)
            case res of
              -- Only the first failure is kept; later ones are ignored.
              Left e -> void $ tryPutTMVar excVar e
              Right x -> putTMVar resVar x
        pure (readTMVar resVar, dlistSingleton tid)
      go excVar (FlatApp (FlatApply cf ca)) = do
        (f, tidsf) <- go excVar cf
        (a, tidsa) <- go excVar ca
        pure (f <*> a, tidsf `dlistConcat` tidsa)
      go excVar (FlatApp (FlatLiftA2 f a b)) = do
        (a', tidsa) <- go excVar a
        (b', tidsb) <- go excVar b
        pure (liftA2 f a' b', tidsa `dlistConcat` tidsb)
      go excVar0 (FlatAlt x y z) = do
        -- Fresh exception slot for this choice's branches. Exceptions from
        -- branches killed after the choice is decided land here, not in the
        -- caller's excVar0.
        excVar <- newEmptyTMVarIO
        resVar <- newEmptyTMVarIO
        pairs <- traverse (go excVar . FlatApp) (x:y:z)
        let (blockers, workerTids) = unzip pairs
        -- Helper thread: wait for the first branch result (earlier branches
        -- checked first), or for a branch failure recorded in excVar.
        helperTid <- C.forkIOWithUnmask $ \restore1 -> do
          eres <- E.try $ restore1 $ atomically $ foldr
            (\blocker rest -> (Right <$> blocker) <|> rest)
            (Left <$> readTMVar excVar)
            blockers
          atomically $ do
            modifyTVar' resultCountVar (+ 1)
            case eres of
              -- The wait itself was interrupted (e.g. this helper was
              -- killed): record nothing.
              Left (_ :: E.SomeException) -> pure ()
              -- A branch failed: propagate to the enclosing exception slot.
              Right (Left e) -> void $ tryPutTMVar excVar0 e
              -- A branch succeeded: publish its result.
              Right (Right res) -> putTMVar resVar res
          -- The choice is decided either way; stop every branch's threads.
          for_ workerTids $ \tids' ->
            for_ (dlistToList tids') $ \workerTid -> C.killThread workerTid
        pure ( readTMVar resVar
             , helperTid `dlistCons` dlistConcatAll workerTids
             )
  excVar <- newEmptyTMVarIO
  (getRes, tids0) <- go excVar f0
  let tids = dlistToList tids0
      tidCount = length tids
      -- True once every forked thread has reported. More reports than
      -- threads indicates an internal bug.
      allDone count =
        if count > tidCount
          then error ("allDone: count ("
                      <> show count
                      <> ") should never be greater than tidCount ("
                      <> show tidCount
                      <> ")")
          else count == tidCount
  -- If the wait below is hit with BlockedIndefinitelyOnSTM, simply retry it.
  let autoRetry action =
        action `E.catch`
        \E.BlockedIndefinitelyOnSTM -> autoRetry action
  -- Wait (with the caller's original masking state restored, so the caller
  -- can be interrupted) for the first failure or the full result. An
  -- exception delivered to the caller here is captured in the outer Left.
  res <- E.try $ restore $ autoRetry $ atomically $
         (Left <$> readTMVar excVar) <|>
         (Right <$> getRes)
  -- Clean up: kill any threads that have not reported yet, then wait (with
  -- the caller's masking state restored) until all of them have reported.
  count0 <- atomically $ readTVar resultCountVar
  unless (allDone count0) $ do
    for_ tids $ \tid -> C.killThread tid
    restore $ atomically $ do
      count <- readTVar resultCountVar
      check $ allDone count
  -- Rethrow the caller's exception or the first child failure, or return
  -- the result.
  case res of
    Left e -> E.throwIO (e :: E.SomeException)
    Right (Left e) -> E.throwIO e
    Right (Right x) -> pure x
{-# INLINEABLE runFlat #-}
#else
-- | Unlifted 'A.mapConcurrently' (used on base < 4.8).
mapConcurrently :: MonadUnliftIO m => Traversable t => (a -> m b) -> t a -> m (t b)
mapConcurrently f t =
  withRunInIO (\runInIO -> A.mapConcurrently (\x -> runInIO (f x)) t)
{-# INLINE mapConcurrently #-}

-- | Unlifted 'A.mapConcurrently_' (used on base < 4.8).
mapConcurrently_ :: MonadUnliftIO m => Foldable f => (a -> m b) -> f a -> m ()
mapConcurrently_ f t =
  withRunInIO (\runInIO -> A.mapConcurrently_ (\x -> runInIO (f x)) t)
{-# INLINE mapConcurrently_ #-}
#endif
-- | Map over the container with at most @numProcs@ workers, keeping the
-- container's shape. Delegates to 'pooledMapConcurrentlyIO'.
pooledMapConcurrentlyN :: (MonadUnliftIO m, Traversable t)
                       => Int
                       -> (a -> m b) -> t a -> m (t b)
pooledMapConcurrentlyN numProcs f xs =
  withRunInIO (\runInIO -> pooledMapConcurrentlyIO numProcs (\x -> runInIO (f x)) xs)

-- | 'pooledMapConcurrentlyN' with the worker count taken from
-- 'getNumCapabilities'.
pooledMapConcurrently :: (MonadUnliftIO m, Traversable t) => (a -> m b) -> t a -> m (t b)
pooledMapConcurrently f xs = withRunInIO $ \runInIO ->
  getNumCapabilities >>= \numProcs ->
    pooledMapConcurrentlyIO numProcs (\x -> runInIO (f x)) xs

-- | 'pooledMapConcurrentlyN' with the container and function swapped.
pooledForConcurrentlyN :: (MonadUnliftIO m, Traversable t)
                       => Int
                       -> t a -> (a -> m b) -> m (t b)
pooledForConcurrentlyN numProcs xs f = pooledMapConcurrentlyN numProcs f xs

-- | 'pooledMapConcurrently' with the container and function swapped.
pooledForConcurrently :: (MonadUnliftIO m, Traversable t) => t a -> (a -> m b) -> m (t b)
pooledForConcurrently xs f = pooledMapConcurrently f xs
-- | 'IO' worker behind the @pooledMapConcurrently*@ family.
--
-- Maps the function over the container using at most @numProcs@ workers and
-- keeps the container's shape.
--
-- Throws an 'E.ErrorCall' when @numProcs < 1@. It uses 'E.throwIO' so the
-- failure happens when the action runs, not when it is evaluated. The message
-- now spells the function's name correctly.
pooledMapConcurrentlyIO :: Traversable t => Int -> (a -> IO b) -> t a -> IO (t b)
pooledMapConcurrentlyIO numProcs f xs
  | numProcs < 1 =
      E.throwIO (E.ErrorCall "pooledMapConcurrentlyIO: number of threads < 1")
  | otherwise = pooledMapConcurrentlyIO' numProcs f xs
-- | Drain the shared job list in @jobsVar@ and run @f@ on each job.
--
-- @numProcs@ identical workers are started with 'replicateConcurrently_'.
-- Each worker repeatedly pops the head of the list with
-- 'atomicModifyIORef'' and stops once the list is empty.
pooledConcurrently
  :: Int
  -> IORef [a]
  -> (a -> IO ())
  -> IO ()
pooledConcurrently numProcs jobsVar f =
  replicateConcurrently_ numProcs worker
  where
    worker = do
      next <- atomicModifyIORef' jobsVar popJob
      case next of
        Nothing -> return ()
        Just job -> f job >> worker
    -- Remove and return the first pending job, if any.
    popJob [] = ([], Nothing)
    popJob (j : js) = (js, Just j)
-- | Shape-preserving pooled map.
--
-- Each element is paired with a fresh 'IORef' result slot. The pairs are
-- handed to 'pooledConcurrently' as a job list, and the slots are then read
-- back in the container's original shape. The placeholder stored in each slot
-- is only evaluated if a slot were read without having been written.
pooledMapConcurrentlyIO' ::
     Traversable t => Int
  -> (a -> IO b)
  -> t a
  -> IO (t b)
pooledMapConcurrentlyIO' numProcs f xs = do
  slots <- traverse
    (\x -> fmap ((,) x) (newIORef (error "pooledMapConcurrentlyIO': empty IORef")))
    xs
  queue <- newIORef (toList slots)
  pooledConcurrently numProcs queue (\(x, slot) -> f x >>= atomicWriteIORef slot)
  traverse (\(_, slot) -> readIORef slot) slots
-- | Pooled traversal for effects only. The container's elements become the
-- shared job list for 'pooledConcurrently'.
pooledMapConcurrentlyIO_' ::
  Foldable t => Int -> (a -> IO ()) -> t a -> IO ()
pooledMapConcurrentlyIO_' numProcs f jobs =
  newIORef (toList jobs) >>= \queue -> pooledConcurrently numProcs queue f
-- | 'IO' worker behind the @pooledMapConcurrently*_@ family.
--
-- Runs the function on every element using at most @numProcs@ workers and
-- discards the results.
--
-- Throws an 'E.ErrorCall' when @numProcs < 1@. It uses 'E.throwIO' so the
-- failure happens when the action runs. The message now spells the
-- function's name correctly.
pooledMapConcurrentlyIO_ :: Foldable t => Int -> (a -> IO b) -> t a -> IO ()
pooledMapConcurrentlyIO_ numProcs f xs
  | numProcs < 1 =
      E.throwIO (E.ErrorCall "pooledMapConcurrentlyIO_: number of threads < 1")
  | otherwise = pooledMapConcurrentlyIO_' numProcs (void . f) xs
-- | Run the function on every element with at most @numProcs@ workers,
-- discarding results. Delegates to 'pooledMapConcurrentlyIO_'.
pooledMapConcurrentlyN_ :: (MonadUnliftIO m, Foldable f)
                        => Int
                        -> (a -> m b) -> f a -> m ()
pooledMapConcurrentlyN_ numProcs f t =
  withRunInIO (\runInIO -> pooledMapConcurrentlyIO_ numProcs (\x -> runInIO (f x)) t)

-- | 'pooledMapConcurrentlyN_' with the worker count taken from
-- 'getNumCapabilities'.
pooledMapConcurrently_ :: (MonadUnliftIO m, Foldable f) => (a -> m b) -> f a -> m ()
pooledMapConcurrently_ f t = withRunInIO $ \runInIO ->
  getNumCapabilities >>= \numProcs ->
    pooledMapConcurrentlyIO_ numProcs (\x -> runInIO (f x)) t

-- | 'pooledMapConcurrently_' with the container and function swapped.
pooledForConcurrently_ :: (MonadUnliftIO m, Foldable f) => f a -> (a -> m b) -> m ()
pooledForConcurrently_ t f = pooledMapConcurrently_ f t

-- | 'pooledMapConcurrentlyN_' with the container and function swapped.
pooledForConcurrentlyN_ :: (MonadUnliftIO m, Foldable t)
                        => Int
                        -> t a -> (a -> m b) -> m ()
pooledForConcurrentlyN_ numProcs t f = pooledMapConcurrentlyN_ numProcs f t
-- | Run the task @cnt@ times with at most @numProcs@ workers, collecting the
-- results. A count below 1 returns @[]@ without touching the pool.
pooledReplicateConcurrentlyN :: (MonadUnliftIO m)
                             => Int
                             -> Int
                             -> m a -> m [a]
pooledReplicateConcurrentlyN numProcs cnt task
  | cnt < 1 = return []
  | otherwise = pooledMapConcurrentlyN numProcs (const task) [1 .. cnt]

-- | Like 'pooledReplicateConcurrentlyN', with the worker count chosen by
-- 'pooledMapConcurrently'.
pooledReplicateConcurrently :: (MonadUnliftIO m)
                            => Int
                            -> m a -> m [a]
pooledReplicateConcurrently cnt task
  | cnt < 1 = return []
  | otherwise = pooledMapConcurrently (const task) [1 .. cnt]

-- | Run the task @cnt@ times with at most @numProcs@ workers, discarding the
-- results. A count below 1 does nothing.
pooledReplicateConcurrentlyN_ :: (MonadUnliftIO m)
                              => Int
                              -> Int
                              -> m a -> m ()
pooledReplicateConcurrentlyN_ numProcs cnt task
  | cnt < 1 = return ()
  | otherwise = pooledMapConcurrentlyN_ numProcs (const task) [1 .. cnt]

-- | Like 'pooledReplicateConcurrentlyN_', with the worker count chosen by
-- 'pooledMapConcurrently_'.
pooledReplicateConcurrently_ :: (MonadUnliftIO m)
                             => Int
                             -> m a -> m ()
pooledReplicateConcurrently_ cnt task
  | cnt < 1 = return ()
  | otherwise = pooledMapConcurrently_ (const task) [1 .. cnt]