module Control.Monad.CatchIO
(
MonadCatchIO(..)
, E.Exception(..)
, throw
, try, tryJust
, Handler(..), catches
, bracket
, bracket_
, bracketOnError
, finally
, onException
)
where
import Prelude hiding ( catch )
import Control.Applicative ((<$>))
import qualified Control.Exception.Extensible as E
import Control.Monad.IO.Class (MonadIO,liftIO)
import Control.Monad.Trans.Cont (ContT(ContT) ,runContT ,mapContT )
import Control.Monad.Trans.Error (ErrorT ,runErrorT ,mapErrorT ,Error)
import Control.Monad.Trans.Identity (IdentityT ,runIdentityT,mapIdentityT)
import Control.Monad.Trans.List (ListT(ListT) ,runListT ,mapListT )
import Control.Monad.Trans.Maybe (MaybeT ,runMaybeT ,mapMaybeT )
import Control.Monad.Trans.RWS (RWST(RWST) ,runRWST ,mapRWST )
import qualified Control.Monad.Trans.RWS.Strict as Strict (RWST(RWST) ,runRWST ,mapRWST )
import Control.Monad.Trans.Reader (ReaderT(ReaderT),runReaderT ,mapReaderT )
import Control.Monad.Trans.State (StateT(StateT) ,runStateT ,mapStateT )
import qualified Control.Monad.Trans.State.Strict as Strict (StateT(StateT) ,runStateT ,mapStateT )
import Control.Monad.Trans.Writer (WriterT ,runWriterT ,mapWriterT )
import qualified Control.Monad.Trans.Writer.Strict as Strict (WriterT ,runWriterT ,mapWriterT )
import Data.Monoid (Monoid)
import GHC.Base (maskAsyncExceptions#)
import GHC.IO (unsafeUnmask,IO(IO))
-- | Monads built on top of 'IO' in which exceptions can be caught and in
-- which the masking of asynchronous exceptions can be adjusted.
class MonadIO m => MonadCatchIO m where
  -- | Generalised form of 'E.catch': run the first argument and hand an
  -- exception of type @e@ to the handler given as second argument.
  catch
    :: E.Exception e
    => m a
    -> (e -> m a)
    -> m a

  -- | Run an action with asynchronous exceptions masked. The 'IO' instance
  -- implements this with @maskAsyncExceptions#@; the transformer instances
  -- lift it through the transformer.
  block
    :: m a
    -> m a

  -- | Counterpart of 'block'. The 'IO' instance implements this with
  -- @unsafeUnmask@; the transformer instances lift it through the
  -- transformer.
  unblock
    :: m a
    -> m a
-- | Base instance: every operation maps directly onto its 'IO' primitive.
instance MonadCatchIO IO where
  catch action handler = E.catch action handler

  -- 'IO' is unwrapped so the masking primop can be applied to the raw
  -- state-passing function, then wrapped again.
  block (IO io) = IO (maskAsyncExceptions# io)

  unblock action = unsafeUnmask action
-- | Lifts the operations through 'ContT'.
instance MonadCatchIO m => MonadCatchIO (ContT r m) where
  -- The protected action and the handler are both run with the same
  -- continuation @k@. Because @k@ is invoked inside the underlying 'catch',
  -- exceptions raised while running the continuation are also in scope of
  -- the handler.
  catch action handler =
    ContT $ \k ->
      catch (runContT action k) (\e -> runContT (handler e) k)

  block action = mapContT block action

  unblock action = mapContT unblock action
-- | Lifts the operations through 'ErrorT'. Only exceptions of the
-- underlying monad are intercepted; the handler's own 'ErrorT' result
-- replaces that of the failed action.
instance (MonadCatchIO m, Error e) => MonadCatchIO (ErrorT e m) where
  catch action handler = mapErrorT recover action
    where
      -- Catch in the underlying monad, unwrapping the handler so both
      -- sides have the same (unwrapped) type.
      recover inner = catch inner (runErrorT . handler)

  block action = mapErrorT block action

  unblock action = mapErrorT unblock action
-- | Lifts the operations through 'IdentityT' by delegating to the
-- underlying monad.
instance (MonadCatchIO m) => MonadCatchIO (IdentityT m) where
  catch action handler =
    mapIdentityT (\inner -> catch inner (runIdentityT . handler)) action

  block action = mapIdentityT block action

  unblock action = mapIdentityT unblock action
-- | Lifts the operations through 'ListT'. The handler's list of results
-- replaces that of the failed action.
instance MonadCatchIO m => MonadCatchIO (ListT m) where
  catch action handler =
    ListT (catch (runListT action) (runListT . handler))

  block action = mapListT block action

  unblock action = mapListT unblock action
-- | Lifts the operations through 'MaybeT'. Only exceptions of the
-- underlying monad are intercepted.
instance (MonadCatchIO m) => MonadCatchIO (MaybeT m) where
  catch action handler = mapMaybeT recover action
    where
      -- Run the handler unwrapped so it has the same type as the
      -- unwrapped action.
      recover inner = catch inner (\e -> runMaybeT (handler e))

  block action = mapMaybeT block action

  unblock action = mapMaybeT unblock action
-- | Lifts the operations through the lazy 'RWST'.
instance (Monoid w, MonadCatchIO m) => MonadCatchIO (RWST r w s m) where
  -- The handler is started with the same environment and the same initial
  -- state that the protected action was given.
  catch action handler =
    RWST $ \env st ->
      catch (runRWST action env st) (\e -> runRWST (handler e) env st)

  block action = mapRWST block action

  unblock action = mapRWST unblock action
-- | Lifts the operations through the strict 'Strict.RWST'.
instance (Monoid w, MonadCatchIO m) => MonadCatchIO (Strict.RWST r w s m) where
  -- The handler is started with the same environment and the same initial
  -- state that the protected action was given.
  catch action handler =
    Strict.RWST $ \env st ->
      catch
        (Strict.runRWST action env st)
        (\e -> Strict.runRWST (handler e) env st)

  block action = Strict.mapRWST block action

  unblock action = Strict.mapRWST unblock action
-- | Lifts the operations through 'ReaderT'. The handler runs in the same
-- environment as the protected action.
instance MonadCatchIO m => MonadCatchIO (ReaderT r m) where
  catch action handler =
    ReaderT $ \env ->
      catch (runReaderT action env) (\e -> runReaderT (handler e) env)

  block action = mapReaderT block action

  unblock action = mapReaderT unblock action
-- | Lifts the operations through the lazy 'StateT'.
instance MonadCatchIO m => MonadCatchIO (StateT s m) where
  -- The handler is started from the state that was current when 'catch'
  -- was entered, not from any state the failed action had reached.
  catch action handler =
    StateT $ \st ->
      catch (runStateT action st) (\e -> runStateT (handler e) st)

  block action = mapStateT block action

  unblock action = mapStateT unblock action
-- | Lifts the operations through the strict 'Strict.StateT'.
instance MonadCatchIO m => MonadCatchIO (Strict.StateT s m) where
  -- The handler is started from the state that was current when 'catch'
  -- was entered, not from any state the failed action had reached.
  catch action handler =
    Strict.StateT $ \st ->
      catch
        (Strict.runStateT action st)
        (\e -> Strict.runStateT (handler e) st)

  block action = Strict.mapStateT block action

  unblock action = Strict.mapStateT unblock action
-- | Lifts the operations through the lazy 'WriterT'. When the handler
-- runs, its result and output are what the combined action yields.
instance (Monoid w, MonadCatchIO m) => MonadCatchIO (WriterT w m) where
  catch action handler = mapWriterT recover action
    where
      -- Catch in the underlying monad on the unwrapped (value, output)
      -- computation.
      recover inner = catch inner (runWriterT . handler)

  block action = mapWriterT block action

  unblock action = mapWriterT unblock action
-- | Lifts the operations through the strict 'Strict.WriterT'. When the
-- handler runs, its result and output are what the combined action yields.
instance (Monoid w, MonadCatchIO m) => MonadCatchIO (Strict.WriterT w m) where
  catch action handler =
    Strict.mapWriterT
      (\inner -> catch inner (\e -> Strict.runWriterT (handler e)))
      action

  block action = Strict.mapWriterT block action

  unblock action = Strict.mapWriterT unblock action
-- | Raise an exception from any 'MonadIO' monad by lifting 'E.throwIO'.
throw :: (MonadIO m, E.Exception e) => e -> m a
throw exc = liftIO (E.throwIO exc)
-- | Run an action and reify its outcome: 'Right' with the result when it
-- completes, 'Left' with the exception when one of type @e@ is caught.
try
  :: (MonadCatchIO m, Functor m, E.Exception e)
  => m a
  -> m (Either e a)
try action = fmap Right action `catch` \exc -> return (Left exc)
-- | Like 'try', but only exceptions accepted by the predicate are
-- returned.
--
-- The action is run under 'try'. If it completes, its result is returned
-- as 'Right'. If an exception of type @e@ is caught, the predicate is
-- applied to it: a 'Just' payload is returned as 'Left', while 'Nothing'
-- causes the original exception to be re-thrown with 'throw'.
tryJust
  :: (MonadCatchIO m, Functor m, E.Exception e)
  => (e -> Maybe b)
  -> m a
  -> m (Either b a)
tryJust p a = do
  r <- try a
  case r of
    Right v -> return (Right v)
    -- 'throw' is polymorphic in its result, so the re-throw already has
    -- the type fixed by the other branch; no type-pinning helper needed.
    Left e -> maybe (throw e) (return . Left) (p e)
-- | A handler for one particular exception type @e@, with that type hidden
-- so that handlers for different exception types can share a list (as
-- consumed by 'catches').
data Handler m a
  = forall e. E.Exception e
  => Handler (e -> m a)
-- | Run an action with a list of alternative handlers.
--
-- Any exception is caught as 'E.SomeException' and offered to the handlers
-- in list order; the first handler whose exception type matches (via
-- 'E.fromException') is run. If none matches, the exception is re-thrown.
catches :: MonadCatchIO m => m a -> [Handler m a] -> m a
catches action handlers = catch action dispatch
  where
    dispatch exc = foldr pick (throw exc) handlers
      where
        -- Try this handler; fall through to the remaining ones when the
        -- exception cannot be converted to the type it expects.
        pick (Handler h) fallback =
          case E.fromException exc of
            Just specific -> h specific
            Nothing       -> fallback
-- | Acquire a resource, use it, and release it.
--
-- The whole sequence runs under 'block'; only the use step is wrapped in
-- 'unblock'. The release step runs either from the 'onException' handler
-- (when the use step raises, after which the exception propagates) or in
-- sequence after the use step completes.
--
-- NOTE(review): the release after normal completion is sequenced with the
-- monad's own bind. For transformers that can short-circuit without
-- raising an exception in the underlying monad, that step may be skipped;
-- confirm for the monad in use.
bracket
  :: MonadCatchIO m
  => m a
  -> (a -> m b)
  -> (a -> m c)
  -> m c
bracket before after thing = block (before >>= guarded)
  where
    guarded resource = do
      result <- unblock (thing resource) `onException` after resource
      _ <- after resource
      return result
-- | Run an action; if it raises any exception, run the second action and
-- then re-throw the same exception. The second action's result is
-- discarded, and it is not run when the first action completes normally.
onException :: MonadCatchIO m => m a -> m b -> m a
onException action cleanup =
  action `catch` \exc ->
    -- The annotation fixes the caught type to 'E.SomeException', so every
    -- exception is intercepted.
    cleanup >> throw (exc :: E.SomeException)
-- | Variant of @bracket@ for when the setup result is not needed: run the
-- first action, then the third, and run the second afterwards.
--
-- The whole sequence runs under 'block'; only the main action is wrapped
-- in 'unblock'. The final action runs either from the 'onException'
-- handler (when the main action raises) or in sequence after it completes.
bracket_ :: MonadCatchIO m => m a -> m b -> m c -> m c
bracket_ before after thing = block (before >>= \_ -> guarded)
  where
    guarded = do
      result <- unblock thing `onException` after
      _ <- after
      return result
-- | Run the first action, then the second, whether the first completes or
-- raises an exception. The first action's result is returned; the second
-- action's result is discarded.
--
-- The sequence runs under 'block', with the first action wrapped in
-- 'unblock'.
finally :: MonadCatchIO m => m a -> m b -> m a
finally thing after =
  block $
    (unblock thing `onException` after) >>= \result ->
      after >>= \_ ->
        return result
-- | Like @bracket@, except the release action runs only when the use step
-- raises an exception; on normal completion the resource is not released
-- here.
--
-- The sequence runs under 'block', with the use step wrapped in 'unblock'.
bracketOnError
  :: MonadCatchIO m
  => m a
  -> (a -> m b)
  -> (a -> m c)
  -> m c
bracketOnError before after thing =
  block (before >>= \resource ->
    unblock (thing resource) `onException` after resource)