{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UnboxedTuples #-}
{-# OPTIONS_GHC -fno-warn-warnings-deprecations #-}
module Control.Monad.Exception (
E.Exception(..),
E.SomeException,
MonadException(..),
onException,
MonadAsyncException(..),
bracket,
bracket_,
ExceptionT(..),
mapExceptionT,
liftException
) where
#if !MIN_VERSION_base(4,6,0)
import Prelude hiding (catch)
#endif /*!MIN_VERSION_base(4,6,0) */
import Control.Applicative
import qualified Control.Exception as E (Exception(..),
                                         SomeException,
                                         catch,
                                         throw,
                                         throwIO,
                                         finally)
import qualified Control.Exception as E (mask)
import Control.Monad (MonadPlus(..))
import Control.Monad.Fix (MonadFix(..))
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Error (Error(..),
                                  ErrorT(..),
                                  mapErrorT,
                                  runErrorT)
import Control.Monad.Trans.Except (ExceptT(..),
                                   mapExceptT,
                                   runExceptT)
import Control.Monad.Trans.Identity (IdentityT(..),
                                     mapIdentityT,
                                     runIdentityT)
import Control.Monad.Trans.List (ListT(..),
                                 mapListT,
                                 runListT)
import Control.Monad.Trans.Maybe (MaybeT(..),
                                  mapMaybeT,
                                  runMaybeT)
import Control.Monad.Trans.RWS.Lazy as Lazy (RWST(..),
                                             mapRWST,
                                             runRWST)
import Control.Monad.Trans.RWS.Strict as Strict (RWST(..),
                                                 mapRWST,
                                                 runRWST)
import Control.Monad.Trans.Reader (ReaderT(..),
                                   mapReaderT)
import Control.Monad.Trans.State.Lazy as Lazy (StateT(..),
                                               mapStateT,
                                               runStateT)
import Control.Monad.Trans.State.Strict as Strict (StateT(..),
                                                   mapStateT,
                                                   runStateT)
import Control.Monad.Trans.Writer.Lazy as Lazy (WriterT(..),
                                                mapWriterT,
                                                runWriterT)
import Control.Monad.Trans.Writer.Strict as Strict (WriterT(..),
                                                    mapWriterT,
                                                    runWriterT)
#if !MIN_VERSION_base(4,8,0)
import Data.Monoid (Monoid)
#endif /* !MIN_VERSION_base(4,8,0) */
import GHC.Conc.Sync (STM(..),
                      catchSTM,
                      throwSTM)
-- | Monads in which exceptions (instances of 'E.Exception') can be raised
-- with 'throw' and intercepted with 'catch'.
class (Monad m) => MonadException m where
    -- | Raise an exception.
    throw :: E.Exception e => e -> m a

    -- | Run an action; if it raises an exception of the handler's argument
    -- type, run the handler on that exception instead.
    catch :: E.Exception e
          => m a          -- ^ action to run
          -> (e -> m a)   -- ^ handler
          -> m a

    -- | Run an action and then a finalizer, keeping the action's result.
    --
    -- The default runs the finalizer after normal completion and, through
    -- 'onException', before rethrowing any exception the action raises.
    -- Instances may override it, e.g. so that the finalizer also runs when
    -- a transformer ends in its own failure value rather than an exception.
    finally :: m a   -- ^ action to run
            -> m b   -- ^ finalizer
            -> m a
    finally act sequel =
        onException act sequel >>= \result ->
        sequel >>= \_ ->
        return result
-- | Run @act@. If it raises any exception (caught as 'E.SomeException'),
-- run @what@ and then rethrow the original exception. When @act@ completes
-- normally, @what@ is not run.
onException :: MonadException m
            => m a   -- ^ action to run
            -> m b   -- ^ cleanup, run only when the action throws
            -> m a
onException act what = catch act rethrowAfterCleanup
  where
    -- The annotation makes the handler match every exception type.
    rethrowAfterCleanup e = what >> throw (e :: E.SomeException)
-- | Exception monads that can also control asynchronous-exception masking.
class (MonadIO m, MonadException m) => MonadAsyncException m where
    -- | Run a computation with asynchronous exceptions masked. The function
    -- handed to the body restores the masking state that was in effect
    -- outside 'mask' (for 'IO' this is 'E.mask').
    mask :: ((forall x. m x -> m x) -> m b) -> m b
-- | Acquire a resource with @before@, use it with @thing@, and release it
-- with @after@, returning @thing@'s result.
--
-- The whole sequence runs under 'mask'. Only @thing@ runs with the outer
-- masking state restored. @after@ is attached with 'finally', so it also
-- runs when @thing@ throws.
bracket :: MonadAsyncException m
        => m a          -- ^ acquire
        -> (a -> m b)   -- ^ release
        -> (a -> m c)   -- ^ use
        -> m c
bracket before after thing =
    mask $ \restore ->
        before >>= \resource ->
            restore (thing resource) `finally` after resource
-- | Variant of 'bracket' in which neither the release nor the use step needs
-- the acquired value.
bracket_ :: MonadAsyncException m
         => m a   -- ^ acquire
         -> m b   -- ^ release
         -> m c   -- ^ use
         -> m c
bracket_ before after thing =
    bracket before (\_ -> after) (\_ -> thing)
-- | A monad transformer that reifies exceptions as values. The underlying
-- computation yields either an 'E.SomeException' ('Left') or a result
-- ('Right').
newtype ExceptionT m a = ExceptionT
    { runExceptionT :: m (Either E.SomeException a)
      -- ^ Run the computation, exposing a raised exception as 'Left'.
    }
-- | Transform the unwrapped computation of an 'ExceptionT', including its
-- 'Either' result, and wrap the result again.
mapExceptionT :: (m (Either E.SomeException a) -> n (Either E.SomeException b))
              -> ExceptionT m a
              -> ExceptionT n b
mapExceptionT f m = ExceptionT (f (runExceptionT m))
-- | Turn a reified result back into a computation. A 'Left' is raised with
-- 'throw', and a 'Right' value is returned.
liftException :: MonadException m => Either E.SomeException a -> m a
liftException = either throw return
-- | Lifting runs the inner action and tags its result as 'Right'.
instance MonadTrans ExceptionT where
    lift m = ExceptionT (m >>= return . Right)
-- | Runs the function computation first. If it yields 'Left', the argument
-- computation is skipped and that 'Left' is the result.
instance (Functor m, Monad m) => Applicative (ExceptionT m) where
    pure = ExceptionT . return . Right
    ExceptionT mf <*> ExceptionT mv = ExceptionT $
        mf >>= either (return . Left)
                      (\k -> mv >>= either (return . Left) (return . Right . k))
-- | Maps over the 'Right' result inside the underlying functor.
instance (Functor m) => Functor (ExceptionT m) where
    fmap f (ExceptionT m) = ExceptionT (fmap (fmap f) m)
-- | Binding short-circuits on 'Left': once a step yields an exception, the
-- continuation is not run.
instance (Monad m) => Monad (ExceptionT m) where
#if MIN_VERSION_base(4,8,0)
    return = pure
#else /* !MIN_VERSION_base(4,8,0) */
    return = ExceptionT . return . Right
#endif /* !MIN_VERSION_base(4,8,0) */
    ExceptionT m >>= k = ExceptionT $
        m >>= either (return . Left) (runExceptionT . k)
#if MIN_VERSION_base(4,13,0)
instance (Monad m) => MonadFail (ExceptionT m) where
#endif
    -- The failure message is reified as a 'userError' inside 'Left'.
    fail = ExceptionT . return . Left . E.toException . userError
-- | 'mzero' yields 'Left' carrying @userError ""@. 'mplus' runs the second
-- computation only when the first yields 'Left', and discards that
-- exception.
instance (Monad m) => MonadPlus (ExceptionT m) where
    mzero = ExceptionT . return . Left . E.toException $ userError ""
    m `mplus` n = ExceptionT $
        runExceptionT m >>= either (const (runExceptionT n)) (return . Right)
-- | Delegates to the 'MonadPlus' instance.
instance (Functor m, Monad m) => Alternative (ExceptionT m) where
    empty   = mzero
    x <|> y = x `mplus` y
-- | Value recursion through the underlying monad, with the knot tied on the
-- 'Right' result. Demanding the fixed point when the computation yielded
-- 'Left' fails with @error "empty mfix argument"@.
instance (MonadFix m) => MonadFix (ExceptionT m) where
    mfix f = ExceptionT $ mfix (runExceptionT . f . unRight)
      where
        -- Must stay lazy: f receives this as an unevaluated thunk.
        unRight (Right r) = r
        unRight _         = error "empty mfix argument"
-- | Exceptions are stored as 'Left' values. 'catch' runs the handler when the
-- stored exception converts (via 'E.fromException') to the handler's type.
-- Otherwise the 'Left' is passed on unchanged.
instance (Monad m) => MonadException (ExceptionT m) where
    throw = ExceptionT . return . Left . E.toException
    m `catch` h = ExceptionT $ runExceptionT m >>= either recover (return . Right)
      where
        recover l = maybe (return (Left l)) (runExceptionT . h) (E.fromException l)
-- | Lifts an 'IO' action and reifies any exception it raises (caught as
-- 'E.SomeException') as 'Left'.
--
-- NOTE(review): catching 'E.SomeException' also captures asynchronous
-- exceptions delivered while the action runs; confirm that is intended.
instance (MonadIO m) => MonadIO (ExceptionT m) where
    liftIO io = ExceptionT . liftIO $
        E.catch (fmap Right io) (\e -> return (Left (e :: E.SomeException)))
-- | Masks in the underlying monad. The restore function is lifted into
-- 'ExceptionT' with 'mapExceptionT'.
instance (MonadAsyncException m) => MonadAsyncException (ExceptionT m) where
    mask body =
        ExceptionT (mask (\unmask -> runExceptionT (body (mapExceptionT unmask))))
-- | 'IO' uses the runtime's native exceptions from "Control.Exception".
instance MonadException IO where
    -- 'E.throwIO', not the pure 'E.throw': the exception is raised when the
    -- action is executed, not when it is merely evaluated, so it is
    -- ordered with respect to the surrounding IO operations.
    throw   = E.throwIO
    catch   = E.catch
    finally = E.finally
-- | 'IO' masking is 'E.mask' from "Control.Exception".
--
-- There is no pre-GHC-7.0 fallback. This module imports 'E.mask'
-- unconditionally, and 'E.mask' only exists from GHC 7.0 onward. The old
-- @block@/@unblock@ fallback also referred to names that were never imported.
instance MonadAsyncException IO where
    mask = E.mask
-- | 'STM' exceptions use 'throwSTM' and 'catchSTM'. 'finally' is the class
-- default.
instance MonadException STM where
    throw ex    = throwSTM ex
    m `catch` h = catchSTM m h
-- | Exceptions are thrown and caught in the underlying monad. 'finally'
-- applies the underlying 'finally' to the unwrapped computations, so an
-- 'ErrorT' failure ('Left') does not stop the finalizer from running.
instance (MonadException m, Error e) =>
         MonadException (ErrorT e m) where
    throw ex = lift (throw ex)
    m `catch` h = ErrorT $ runErrorT m `catch` (runErrorT . h)
    act `finally` sequel = ErrorT $ runErrorT act `finally` runErrorT sequel
-- | Exceptions are thrown and caught in the underlying monad. 'finally'
-- applies the underlying 'finally' to the unwrapped computations, so an
-- 'ExceptT' failure ('Left') does not stop the finalizer from running.
instance (MonadException m) =>
         MonadException (ExceptT e' m) where
    throw ex = lift (throw ex)
    m `catch` h = ExceptT $ runExceptT m `catch` (runExceptT . h)
    act `finally` sequel = ExceptT $ runExceptT act `finally` runExceptT sequel
-- | Exceptions are thrown and caught in the underlying monad.
instance (MonadException m) =>
         MonadException (IdentityT m) where
    throw ex = lift (throw ex)
    m `catch` h = IdentityT $ runIdentityT m `catch` (runIdentityT . h)
-- | Exceptions are caught around the whole unwrapped list computation. On an
-- exception, the handler's result list replaces the failed computation's.
instance MonadException m =>
         MonadException (ListT m) where
    throw ex = lift (throw ex)
    m `catch` h = ListT $ runListT m `catch` (runListT . h)
-- | Exceptions are thrown and caught in the underlying monad. 'finally'
-- applies the underlying 'finally' to the unwrapped computations, so a
-- 'Nothing' result does not stop the finalizer from running.
instance (MonadException m) =>
         MonadException (MaybeT m) where
    throw ex = lift (throw ex)
    m `catch` h = MaybeT $ runMaybeT m `catch` (runMaybeT . h)
    act `finally` sequel = MaybeT $ runMaybeT act `finally` runMaybeT sequel
-- | The handler runs from the environment and state in effect when 'catch'
-- was entered. State changes and output of the failed action are discarded.
instance (Monoid w, MonadException m) =>
         MonadException (Lazy.RWST r w s m) where
    throw ex = lift (throw ex)
    m `catch` h = Lazy.RWST (\r s ->
        catch (Lazy.runRWST m r s) (\ex -> Lazy.runRWST (h ex) r s))
-- | The handler runs from the environment and state in effect when 'catch'
-- was entered. State changes and output of the failed action are discarded.
instance (Monoid w, MonadException m) =>
         MonadException (Strict.RWST r w s m) where
    throw ex = lift (throw ex)
    m `catch` h = Strict.RWST (\r s ->
        catch (Strict.runRWST m r s) (\ex -> Strict.runRWST (h ex) r s))
-- | The handler runs with the same environment as the failed action.
instance (MonadException m) =>
         MonadException (ReaderT r m) where
    throw ex = lift (throw ex)
    m `catch` h = ReaderT (\r ->
        catch (runReaderT m r) (\ex -> runReaderT (h ex) r))
-- | The handler resumes from the state in effect when 'catch' was entered.
-- State changes made by the failed action are discarded.
instance (MonadException m) =>
         MonadException (Lazy.StateT s m) where
    throw ex = lift (throw ex)
    m `catch` h = Lazy.StateT (\s ->
        catch (Lazy.runStateT m s) (\ex -> Lazy.runStateT (h ex) s))
-- | The handler resumes from the state in effect when 'catch' was entered.
-- State changes made by the failed action are discarded.
instance (MonadException m) =>
         MonadException (Strict.StateT s m) where
    throw ex = lift (throw ex)
    m `catch` h = Strict.StateT (\s ->
        catch (Strict.runStateT m s) (\ex -> Strict.runStateT (h ex) s))
-- | Exceptions are caught around the unwrapped computation. Output written
-- by the failed action is discarded in favour of the handler's.
instance (Monoid w, MonadException m) =>
         MonadException (Lazy.WriterT w m) where
    throw ex = lift (throw ex)
    m `catch` h = Lazy.WriterT (catch (Lazy.runWriterT m) (Lazy.runWriterT . h))
-- | Exceptions are caught around the unwrapped computation. Output written
-- by the failed action is discarded in favour of the handler's.
instance (Monoid w, MonadException m) =>
         MonadException (Strict.WriterT w m) where
    throw ex = lift (throw ex)
    m `catch` h = Strict.WriterT (catch (Strict.runWriterT m) (Strict.runWriterT . h))
-- | Masks in the underlying monad. The restore function is lifted with
-- 'mapErrorT'.
instance (MonadAsyncException m, Error e) =>
         MonadAsyncException (ErrorT e m) where
    mask body = ErrorT (mask (\unmask -> runErrorT (body (mapErrorT unmask))))
-- | Masks in the underlying monad. The restore function is lifted with
-- 'mapExceptT'.
instance (MonadAsyncException m) =>
         MonadAsyncException (ExceptT e' m) where
    mask body = ExceptT (mask (\unmask -> runExceptT (body (mapExceptT unmask))))
-- | Masks in the underlying monad. The restore function is lifted with
-- 'mapIdentityT'.
instance (MonadAsyncException m) =>
         MonadAsyncException (IdentityT m) where
    mask body = IdentityT (mask (\unmask -> runIdentityT (body (mapIdentityT unmask))))
-- | Masks in the underlying monad. The restore function is lifted with
-- 'mapListT'.
instance (MonadAsyncException m) =>
         MonadAsyncException (ListT m) where
    mask body = ListT (mask (\unmask -> runListT (body (mapListT unmask))))
-- | Masks in the underlying monad. The restore function is lifted with
-- 'mapMaybeT'.
instance (MonadAsyncException m) =>
         MonadAsyncException (MaybeT m) where
    mask body = MaybeT (mask (\unmask -> runMaybeT (body (mapMaybeT unmask))))
-- | Takes the environment and state first, then masks in the underlying
-- monad. The restore function is lifted with 'Lazy.mapRWST'.
instance (Monoid w, MonadAsyncException m) =>
         MonadAsyncException (Lazy.RWST r w s m) where
    mask body = Lazy.RWST (\r s ->
        mask (\unmask -> Lazy.runRWST (body (Lazy.mapRWST unmask)) r s))
-- | Takes the environment and state first, then masks in the underlying
-- monad. The restore function is lifted with 'Strict.mapRWST'.
instance (Monoid w, MonadAsyncException m) =>
         MonadAsyncException (Strict.RWST r w s m) where
    mask body = Strict.RWST (\r s ->
        mask (\unmask -> Strict.runRWST (body (Strict.mapRWST unmask)) r s))
-- | Takes the environment first, then masks in the underlying monad. The
-- restore function is lifted with 'mapReaderT'.
instance (MonadAsyncException m) =>
         MonadAsyncException (ReaderT r m) where
    mask body = ReaderT (\r ->
        mask (\unmask -> runReaderT (body (mapReaderT unmask)) r))
-- | Takes the state first, then masks in the underlying monad. The restore
-- function is lifted with 'Lazy.mapStateT'.
instance (MonadAsyncException m) =>
         MonadAsyncException (Lazy.StateT s m) where
    mask body = Lazy.StateT (\s ->
        mask (\unmask -> Lazy.runStateT (body (Lazy.mapStateT unmask)) s))
-- | Takes the state first, then masks in the underlying monad. The restore
-- function is lifted with 'Strict.mapStateT'.
instance (MonadAsyncException m) =>
         MonadAsyncException (Strict.StateT s m) where
    mask body = Strict.StateT (\s ->
        mask (\unmask -> Strict.runStateT (body (Strict.mapStateT unmask)) s))
-- | Masks in the underlying monad. The restore function is lifted with
-- 'Lazy.mapWriterT'.
instance (Monoid w, MonadAsyncException m) =>
         MonadAsyncException (Lazy.WriterT w m) where
    mask body = Lazy.WriterT (mask (\unmask ->
        Lazy.runWriterT (body (Lazy.mapWriterT unmask))))
-- | Masks in the underlying monad. The restore function is lifted with
-- 'Strict.mapWriterT'.
instance (Monoid w, MonadAsyncException m) =>
         MonadAsyncException (Strict.WriterT w m) where
    mask body = Strict.WriterT (mask (\unmask ->
        Strict.runWriterT (body (Strict.mapWriterT unmask))))