{-# LANGUAGE RankNTypes, TypeFamilies, FlexibleContexts, FlexibleInstances,
MultiParamTypeClasses, UndecidableInstances, ScopedTypeVariables,
GeneralizedNewtypeDeriving, CPP, Trustworthy #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Pipes.Safe
(
SafeT
, runSafeT
, runSafeP
, ReleaseKey
, MonadSafe(..)
, onException
, tryP
, catchP
, finally
, bracket
, bracket_
, bracketOnError
, module Control.Monad.Catch
, module Control.Exception
) where
import Control.Applicative (Applicative, Alternative)
import Control.Exception(Exception(..), SomeException(..))
import qualified Control.Monad.Catch as C
import Control.Monad.Catch
( MonadCatch(..)
, MonadThrow(..)
, MonadMask(..)
#if MIN_VERSION_exceptions(0,10,0)
, ExitCase(..)
#endif
, mask_
, uninterruptibleMask_
, catchAll
, catchIOError
, catchJust
, catchIf
, Handler(..)
, catches
, handle
, handleAll
, handleIOError
, handleJust
, handleIf
, tryJust
, Exception(..)
, SomeException
)
import Control.Monad (MonadPlus, liftM)
import Control.Monad.Fail (MonadFail)
import Control.Monad.Fix (MonadFix)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Control.Monad.Trans.Control (MonadBaseControl(..))
import Control.Monad.Trans.Class (MonadTrans(lift))
import qualified Control.Monad.Base as B
import qualified Control.Monad.Catch.Pure as E
import qualified Control.Monad.Trans.Identity as I
import qualified Control.Monad.Cont.Class as CC
import qualified Control.Monad.Error.Class as EC
import qualified Control.Monad.Primitive as Prim
import qualified Control.Monad.Trans.Reader as R
import qualified Control.Monad.Trans.RWS.Lazy as RWS
import qualified Control.Monad.Trans.RWS.Strict as RWS'
import qualified Control.Monad.Trans.State.Lazy as S
import qualified Control.Monad.Trans.State.Strict as S'
import qualified Control.Monad.State.Class as SC
import qualified Control.Monad.Trans.Writer.Lazy as W
import qualified Control.Monad.Trans.Writer.Strict as W'
import qualified Control.Monad.Writer.Class as WC
#if MIN_VERSION_base(4,6,0)
import Data.IORef (IORef, newIORef, readIORef, writeIORef, atomicModifyIORef')
#else
import Data.IORef (IORef, newIORef, readIORef, writeIORef, atomicModifyIORef)
#endif
import qualified Data.Map as M
import Data.Monoid (Monoid)
import Pipes (Proxy, Effect, Effect', runEffect)
import Pipes.Internal (Proxy(..))
-- | The unmask function most recently captured from a base-monad masking
-- call by 'liftMask', or 'Unmasked' if none has been captured yet.
data Restore m = Unmasked | Masked (forall x . m x -> m x)
-- | Lift a base-monad masking primitive (such as 'mask') to 'Proxy'.
--
-- Each run of consecutive base-monad steps ('M' nodes) is merged and executed
-- inside a single call to @maskVariant@. 'Request' and 'Respond' steps are not
-- themselves run under the mask. The unmask function handed to @k@ re-applies
-- the most recently captured base-monad unmask around its own merged steps.
liftMask
    :: forall m a' a b' b r . (MonadIO m, MonadCatch m)
    => (forall s . ((forall x . m x -> m x) -> m s) -> m s)
    -> ((forall x . Proxy a' a b' b m x -> Proxy a' a b' b m x)
        -> Proxy a' a b' b m r)
    -> Proxy a' a b' b m r
liftMask maskVariant k = do
    -- Holds the unmask function captured by the latest 'maskVariant' call.
    ioref <- liftIO $ newIORef Unmasked
    let
        -- Wrap each merged run of base-monad steps in the mask.
        loop :: Proxy a' a b' b m r -> Proxy a' a b' b m r
        loop (Request a' fa ) = Request a' (loop . fa )
        loop (Respond b fb') = Respond b (loop . fb')
        loop (M m) = M $ maskVariant $ \unmaskVariant -> do
            -- Stash the base monad's unmask so 'unmask' can pick it up.
            liftIO $ writeIORef ioref $ Masked unmaskVariant
            m >>= chunk >>= return . loop
        loop (Pure r) = Pure r
        -- Wrap each merged run of base-monad steps in the stashed unmask.
        unmask :: forall q. Proxy a' a b' b m q -> Proxy a' a b' b m q
        unmask (Request a' fa ) = Request a' (unmask . fa )
        unmask (Respond b fb') = Respond b (unmask . fb')
        unmask (M m) = M $ do
            -- NOTE(review): partial match. This relies on 'loop' having
            -- written 'Masked' before this step runs. That holds for steps
            -- reached through 'loop', because they execute inside its
            -- 'maskVariant' callback after the write. If it ever failed,
            -- the error would surface as an IO pattern-match failure.
            unmaskVariant <- liftIO $ do
                Masked unmaskVariant <- readIORef ioref
                return unmaskVariant
            unmaskVariant (m >>= chunk >>= return . unmask)
        unmask (Pure q) = Pure q
        -- Run consecutive base-monad steps back to back, stopping at the
        -- first node that is not an 'M'.
        chunk :: forall s. Proxy a' a b' b m s -> m (Proxy a' a b' b m s)
        chunk (M m) = m >>= chunk
        chunk s = return s
    loop $ k unmask
-- Masking for proxies: 'mask' and 'uninterruptibleMask' are lifted with
-- 'liftMask'. With exceptions >= 0.10, 'generalBracket' is built from this
-- 'mask' together with 'catch' and 'throwM'.
instance (MonadMask m, MonadIO m) => MonadMask (Proxy a' a b' b m) where
    mask                = liftMask mask
    uninterruptibleMask = liftMask uninterruptibleMask
#if MIN_VERSION_exceptions(0,10,0)
    generalBracket acquire release_ use = mask $ \unmasked -> do
      a <- acquire
      -- Pair the public 'ExitCase' (passed to @release_@) with the private
      -- 'ExitCase_' (used below to decide whether to rethrow).
      let action = do
              b <- use a
              return (ExitCaseSuccess b, ExitCaseSuccess_ b)
      let handler e = return (ExitCaseException e, ExitCaseException_ e)
      -- Only @use a@ runs unmasked. An exception from it is caught here so
      -- that @release_@ can see it. 'ExitCaseAbort' is never produced.
      (exitCase, exitCase_) <- unmasked action `catch` handler
      c <- release_ a exitCase
      -- After release has run, rethrow the original exception, if any.
      case exitCase_ of
          ExitCaseException_ e -> throwM e
          ExitCaseSuccess_ b -> return (b, c)

-- | Private mirror of 'ExitCase' restricted to the two outcomes
-- 'generalBracket' produces here, so its final case analysis is total.
data ExitCase_ a = ExitCaseSuccess_ a | ExitCaseException_ SomeException
#endif
-- | Registry of finalizers owned by one 'SafeT' block.
data Finalizers m = Finalizers
    { _nextKey    :: !Integer
      -- ^ key that the next registered finalizer will receive
    , _finalizers :: !(M.Map Integer (m ()))
      -- ^ finalizers still registered; keys increase with each registration
    }
-- | Monad transformer that tracks registered finalizers.
--
-- The environment is an 'IORef' holding the live 'Finalizers'. It holds
-- 'Nothing' once 'runSafeT' has taken the finalizers and closed the block.
-- Most instances are derived straight from the underlying 'R.ReaderT'.
newtype SafeT m r = SafeT { unSafeT :: R.ReaderT (IORef (Maybe (Finalizers m))) m r }
    deriving
    ( Functor
    , Applicative
    , Alternative
    , Monad
#if MIN_VERSION_transformers(0,5,0)
    , MonadFail
#endif
    , MonadPlus
    , MonadFix
    , EC.MonadError e
    , SC.MonadState s
    , WC.MonadWriter w
    , CC.MonadCont
    , MonadThrow
    , MonadCatch
    , MonadMask
    , MonadIO
    , B.MonadBase b
    )
-- | Lift a base-monad action into 'SafeT'. The finalizer registry is not
-- consulted.
instance MonadTrans SafeT where
    lift = SafeT . lift
-- | 'SafeT' is a reader over its finalizer registry, so base-monad control
-- is inherited from @m@ by threading that environment through explicitly.
instance MonadBaseControl b m => MonadBaseControl b (SafeT m) where
#if MIN_VERSION_monad_control(1,0,0)
    type StM (SafeT m) a = StM m a
    -- Capture the current registry and give the caller a runner that
    -- supplies it before delegating to @m@'s runner.
    liftBaseWith f = SafeT $ R.ReaderT $ \reader' ->
        liftBaseWith $ \runInBase ->
            f $ runInBase . (\(SafeT r) -> R.runReaderT r reader' )
    -- The reader layer adds no state of its own, so restoring ignores the
    -- environment and defers to @m@.
    restoreM = SafeT . R.ReaderT . const . restoreM
#else
    newtype StM (SafeT m) a = StMT (StM m a)
    -- Same as above, but wrapping @m@'s state in the 'StMT' newtype.
    liftBaseWith f = SafeT $ R.ReaderT $ \reader' ->
        liftBaseWith $ \runInBase ->
            f $ liftM StMT . runInBase . \(SafeT r) -> R.runReaderT r reader'
    restoreM (StMT base) = SafeT $ R.ReaderT $ const $ restoreM base
#endif
-- | Primitive state operations run directly in the underlying monad.
instance Prim.PrimMonad m => Prim.PrimMonad (SafeT m) where
    type PrimState (SafeT m) = Prim.PrimState m
    primitive f = lift (Prim.primitive f)
    {-# INLINE primitive #-}
-- | Run a 'SafeT' block and execute every finalizer still registered when
-- it exits, whether it returns normally or throws.
--
-- Finalizers run newest-first (descending key order). If one of them
-- throws, the remaining ones are skipped and the exception propagates.
-- Once the block is closed, 'register' and 'release' on it fail with an
-- error.
runSafeT :: (MonadMask m, MonadIO m) => SafeT m r -> m r
runSafeT m = C.bracket
    (liftIO $ newIORef $! Just $! Finalizers 0 M.empty)
    (\ioref -> do
        -- Atomically take the finalizers and mark the block as closed.
#if MIN_VERSION_base(4,6,0)
        mres <- liftIO $ atomicModifyIORef' ioref $ \val ->
#else
        mres <- liftIO $ atomicModifyIORef ioref $ \val ->
#endif
            (Nothing, val)
        case mres of
            Nothing -> error "runSafeT's resources were freed by another"
            -- 'mapM_' rather than 'mapM': the finalizers' results are
            -- discarded, so there is no need to build a list of them.
            Just (Finalizers _ fs) -> mapM_ snd (M.toDescList fs) )
    (R.runReaderT (unSafeT m))
{-# INLINABLE runSafeT #-}
-- | Discharge the 'SafeT' layer of a closed 'Effect'. The effect is run to
-- completion with 'runEffect' inside 'runSafeT', and the result is lifted
-- back into a pipeline over the base monad.
runSafeP :: (MonadMask m, MonadIO m) => Effect (SafeT m) r -> Effect' m r
runSafeP effect = lift (runSafeT (runEffect effect))
{-# INLINABLE runSafeP #-}
-- | Token returned by 'register'. Pass it to 'release' to unregister that
-- finalizer.
newtype ReleaseKey = ReleaseKey { unlock :: Integer }
-- | Monads that can register finalizers with an enclosing 'SafeT' block.
class (MonadCatch m, MonadMask m, MonadIO m, MonadIO (Base m)) => MonadSafe m where
    -- | The monad in which finalizers run (the monad wrapped by 'SafeT').
    type Base (m :: * -> *) :: * -> *
    -- | Lift an action from the 'Base' monad.
    liftBase :: Base m r -> m r
    -- | Register a finalizer and return a key that can later unregister it.
    register :: Base m () -> m ReleaseKey
    -- | Unregister a finalizer without running it.
    release  :: ReleaseKey -> m ()
-- | Finalizers are stored in the block's shared 'IORef'. 'register' and
-- 'release' both fail with an 'error' once 'runSafeT' has closed the block.
--
-- In both methods the atomic update only decides what to do; any error is
-- raised afterwards. A closed block therefore leaves the 'IORef' holding
-- 'Nothing' instead of a bottom thunk. The error also fires under the lazy
-- 'atomicModifyIORef' branch, where 'release''s @()@ result was never forced.
instance (MonadIO m, MonadCatch m, MonadMask m) => MonadSafe (SafeT m) where
    type Base (SafeT m) = m

    liftBase = lift

    register io = do
        ioref <- SafeT R.ask
        liftIO $ do
#if MIN_VERSION_base(4,6,0)
            mkey <- atomicModifyIORef' ioref $ \val ->
#else
            mkey <- atomicModifyIORef ioref $ \val ->
#endif
                case val of
                    Nothing -> (Nothing, Nothing)
                    Just (Finalizers n fs) ->
                        (Just $! Finalizers (n + 1) (M.insert n io fs), Just n)
            case mkey of
                Nothing -> error "register: SafeT block is closed"
                Just n  -> return (ReleaseKey n)

    release key = do
        ioref <- SafeT R.ask
        liftIO $ do
#if MIN_VERSION_base(4,6,0)
            wasOpen <- atomicModifyIORef' ioref $ \val ->
#else
            wasOpen <- atomicModifyIORef ioref $ \val ->
#endif
                case val of
                    Nothing -> (Nothing, False)
                    Just (Finalizers n fs) ->
                        (Just $! Finalizers n (M.delete (unlock key) fs), True)
            if wasOpen
                then return ()
                else error "release: SafeT block is closed"
-- The instances below forward 'MonadSafe' through one transformer layer:
-- 'Base' is inherited unchanged and each method is 'lift'ed one level.
instance MonadSafe m => MonadSafe (Proxy a' a b' b m) where
    type Base (Proxy a' a b' b m) = Base m
    liftBase action    = lift (liftBase action)
    register finalizer = lift (register finalizer)
    release key        = lift (release key)

instance (MonadSafe m) => MonadSafe (I.IdentityT m) where
    type Base (I.IdentityT m) = Base m
    liftBase action    = lift (liftBase action)
    register finalizer = lift (register finalizer)
    release key        = lift (release key)

instance (MonadSafe m) => MonadSafe (E.CatchT m) where
    type Base (E.CatchT m) = Base m
    liftBase action    = lift (liftBase action)
    register finalizer = lift (register finalizer)
    release key        = lift (release key)

instance (MonadSafe m) => MonadSafe (R.ReaderT i m) where
    type Base (R.ReaderT i m) = Base m
    liftBase action    = lift (liftBase action)
    register finalizer = lift (register finalizer)
    release key        = lift (release key)

instance (MonadSafe m) => MonadSafe (S.StateT s m) where
    type Base (S.StateT s m) = Base m
    liftBase action    = lift (liftBase action)
    register finalizer = lift (register finalizer)
    release key        = lift (release key)

instance (MonadSafe m) => MonadSafe (S'.StateT s m) where
    type Base (S'.StateT s m) = Base m
    liftBase action    = lift (liftBase action)
    register finalizer = lift (register finalizer)
    release key        = lift (release key)

instance (MonadSafe m, Monoid w) => MonadSafe (W.WriterT w m) where
    type Base (W.WriterT w m) = Base m
    liftBase action    = lift (liftBase action)
    register finalizer = lift (register finalizer)
    release key        = lift (release key)

instance (MonadSafe m, Monoid w) => MonadSafe (W'.WriterT w m) where
    type Base (W'.WriterT w m) = Base m
    liftBase action    = lift (liftBase action)
    register finalizer = lift (register finalizer)
    release key        = lift (release key)

instance (MonadSafe m, Monoid w) => MonadSafe (RWS.RWST i w s m) where
    type Base (RWS.RWST i w s m) = Base m
    liftBase action    = lift (liftBase action)
    register finalizer = lift (register finalizer)
    release key        = lift (release key)

instance (MonadSafe m, Monoid w) => MonadSafe (RWS'.RWST i w s m) where
    type Base (RWS'.RWST i w s m) = Base m
    liftBase action    = lift (liftBase action)
    register finalizer = lift (register finalizer)
    release key        = lift (release key)
-- | Analogous to 'C.onException', except that the handler is held as a
-- registered finalizer while the action runs.
--
-- If the action returns normally, the finalizer is released without being
-- run. If the action does not complete (for example because it throws), the
-- finalizer stays registered. For 'SafeT' it then runs when 'runSafeT'
-- closes the block, not at the point of failure.
onException :: (MonadSafe m) => m a -> Base m b -> m a
onException action handler =
    register (handler >> return ()) >>= \key ->
    action >>= \result ->
    release key >> return result
{-# INLINABLE onException #-}
-- | Run an action, then the given base-monad finalizer. If the action does
-- not complete, the finalizer stays registered with the enclosing block
-- (see 'bracket').
finally :: (MonadSafe m) => m a -> Base m b -> m a
finally action sequel = bracket_ (return ()) sequel action
{-# INLINABLE finally #-}
-- | Acquire a resource, use it, then release it.
--
-- Acquisition and release run masked; only the use is unmasked. After a
-- successful use the release action runs immediately. If the use does not
-- complete, the release action remains registered via 'onException'.
bracket :: (MonadSafe m) => Base m a -> (a -> Base m b) -> (a -> m c) -> m c
bracket acquire dispose use = mask $ \unmask -> do
    resource <- liftBase acquire
    let cleanup = dispose resource
    result <- unmask (use resource) `onException` cleanup
    _ <- liftBase cleanup
    return result
{-# INLINABLE bracket #-}
-- | 'bracket' for an acquire, release and body that ignore the resource.
bracket_ :: (MonadSafe m) => Base m a -> Base m b -> m c -> m c
bracket_ before after action = bracket before (const after) (const action)
{-# INLINABLE bracket_ #-}
-- | Like 'bracket', but the release action is only registered as an
-- 'onException' finalizer: it is not run after a successful use.
bracketOnError
    :: (MonadSafe m) => Base m a -> (a -> Base m b) -> (a -> m c) -> m c
bracketOnError acquire onFailure use = mask $ \unmask ->
    liftBase acquire >>= \resource ->
        unmask (use resource) `onException` onFailure resource
{-# INLINABLE bracketOnError #-}
-- | 'C.try' for proxies. It catches an exception of type @e@ raised by any
-- base-monad step. On such an exception the proxy stops with 'Left';
-- otherwise its result is wrapped in 'Right'. 'Request' and 'Respond'
-- steps are passed through unchanged.
tryP :: (MonadSafe m, Exception e)
     => Proxy a' a b' b m r -> Proxy a' a b' b m (Either e r)
tryP (Request a' fa ) = Request a' (tryP . fa )
tryP (Respond b  fb') = Respond b  (tryP . fb')
tryP (M m)            = M (liftM (either (Pure . Left) tryP) (C.try m))
tryP (Pure r)         = Pure (Right r)
-- | 'C.catch' for proxies. If a base-monad step throws an exception of type
-- @e@, the rest of the proxy is replaced by the handler's proxy. Exceptions
-- raised by the handler's proxy are not caught by this same handler.
catchP :: (MonadSafe m, Exception e)
       => Proxy a' a b' b m r -> (e -> Proxy a' a b' b m r)
       -> Proxy a' a b' b m r
catchP p0 f = recover p0
  where
    recover (Request a' fa ) = Request a' (recover . fa )
    recover (Respond b  fb') = Respond b  (recover . fb')
    recover (M m)            = M (C.catch (liftM recover m) (return . f))
    recover (Pure r)         = Pure r