{-# LANGUAGE CPP #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Trustworthy #-}
#endif

#ifndef MIN_VERSION_transformers
#define MIN_VERSION_transformers(x,y,z) 1
#endif

#ifndef MIN_VERSION_mtl
#define MIN_VERSION_mtl(x,y,z) 1
#endif

--------------------------------------------------------------------
-- |
-- Copyright   :  (C) Edward Kmett 2013-2015, (c) Google Inc. 2012
-- License     :  BSD-style (see the file LICENSE)
-- Maintainer  :  Edward Kmett <ekmett@gmail.com>
-- Stability   :  experimental
-- Portability :  non-portable
--
-- This module supplies a \'pure\' monad transformer that can be used for
-- mock-testing code that throws exceptions, so long as those exceptions
-- are always thrown with 'throwM'.
--
-- Do not mix 'CatchT' with 'IO'. Choose one or the other for the
-- bottom of your transformer stack!
--------------------------------------------------------------------


module Control.Monad.Catch.Pure (
    -- * Transformer

    -- $transformer

    CatchT(..), Catch
  , runCatch
  , mapCatchT

  -- * Typeclass

  -- $mtl

  , module Control.Monad.Catch
  ) where

#if defined(__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 706)
import Prelude hiding (foldr)
#else
import Prelude hiding (catch, foldr)
#endif

import Control.Applicative
import Control.Monad.Catch
import qualified Control.Monad.Fail as Fail
import Control.Monad.Fix (MonadFix(..))
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad (MonadPlus(..), ap, liftM)
import Control.Monad.Reader (MonadReader(..))
import Control.Monad.RWS (MonadRWS)
import Control.Monad.State (MonadState(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Writer (MonadWriter(..))
#if __GLASGOW_HASKELL__ < 710
import Data.Foldable
import Data.Monoid (Monoid(..))
#endif
import Data.Functor.Identity
import Data.Traversable as Traversable

------------------------------------------------------------------------------
-- $mtl
-- The mtl style typeclass
------------------------------------------------------------------------------


------------------------------------------------------------------------------
-- $transformer
-- The @transformers@-style monad transformer
------------------------------------------------------------------------------


-- | Add 'Exception' handling abilities to a 'Monad'.
--
-- This should /never/ be used in combination with 'IO'. Think of 'CatchT'
-- as an alternative base monad for use with mocking code that solely throws
-- exceptions via 'throwM'.
--
-- A 'CatchT' computation is just the wrapped computation in @m@ whose result
-- is either the exception that was thrown ('Left') or the value ('Right').
--
-- Note that the 'IO' monad has these abilities already, so stacking 'CatchT'
-- on top of it does not add any value and can possibly be confusing:
--
-- >>> (error "Hello!" :: IO ()) `catch` (\(e :: ErrorCall) -> liftIO $ print e)
-- Hello!
--
-- >>> runCatchT $ (error "Hello!" :: CatchT IO ()) `catch` (\(e :: ErrorCall) -> liftIO $ print e)
-- *** Exception: Hello!
--
-- >>> runCatchT $ (throwM (ErrorCall "Hello!") :: CatchT IO ()) `catch` (\(e :: ErrorCall) -> liftIO $ print e)
-- Hello!
-- Right ()
--
-- In the second example the exception is raised by the underlying 'IO'
-- rather than stored as a 'Left', so the 'CatchT' handler never sees it.
newtype CatchT m a = CatchT { runCatchT :: m (Either SomeException a) }

-- | 'CatchT' over 'Identity': a pure computation that can throw and catch
-- exceptions. Run it with 'runCatch'.
type Catch = CatchT Identity

-- | Run a pure 'Catch' computation.
--
-- Returns @'Left' e@ if an exception @e@ escaped the computation, or
-- @'Right' a@ with its result otherwise.
runCatch :: Catch a -> Either SomeException a
runCatch = runIdentity . runCatchT

-- Maps over the successful result only; a stored exception ('Left') passes
-- through unchanged. 'liftM' is used so that only a 'Monad' constraint on
-- @m@ is needed.
instance Monad m => Functor (CatchT m) where
  fmap f (CatchT m) = CatchT (liftM (fmap f) m)

-- 'pure' produces a success. '<*>' is 'ap', i.e. it is derived from the
-- 'Monad' instance: if the function side fails, the argument is never run.
instance Monad m => Applicative (CatchT m) where
  pure a = CatchT (return (Right a))
  (<*>) = ap

-- Bind short-circuits: once the inner computation yields a 'Left', the
-- continuation is skipped and the exception is passed along unchanged.
instance Monad m => Monad (CatchT m) where
  return = pure
  CatchT m >>= k = CatchT (m >>= either (return . Left) (runCatchT . k))
#if !(MIN_VERSION_base(4,13,0))
  -- Older base still has 'fail' in 'Monad'; route it to the 'MonadFail'
  -- instance so both behave the same.
  fail = Fail.fail
#endif

-- 'fail' throws its message as a 'userError' ('IOError'), stored as a
-- 'Left', so it can be handled like any other exception thrown with 'throwM'.
instance Monad m => Fail.MonadFail (CatchT m) where
  fail = CatchT . return . Left . toException . userError

-- Ties the knot through the underlying monad's 'mfix'. Forcing the recursive
-- argument when the fixed point turned out to be an exception is an 'error'.
instance MonadFix m => MonadFix (CatchT m) where
  mfix f = CatchT $ mfix $ \ea -> runCatchT $ f $ case ea of
    Right r -> r
    Left _  -> error "empty mfix argument"

-- Folds over the successful results inside @m@; each stored exception
-- contributes 'mempty'.
instance Foldable m => Foldable (CatchT m) where
  foldMap f (CatchT m) = foldMap (either (const mempty) f) m

-- Traverses the successful results inside @m@. A stored exception is
-- rewrapped with 'pure' without running the effectful function.
instance (Monad m, Traversable m) => Traversable (CatchT m) where
  traverse f (CatchT m) =
    CatchT <$> Traversable.traverse (either (pure . Left) (fmap Right . f)) m

-- Delegates to the 'MonadPlus' instance: 'empty' is a failed computation and
-- '<|>' runs the right-hand side only if the left-hand side failed.
instance Monad m => Alternative (CatchT m) where
  empty = mzero
  (<|>) = mplus

-- 'mzero' fails with an empty 'userError'. 'mplus' runs the second
-- computation only when the first produced a 'Left'; the first exception is
-- discarded in that case.
instance Monad m => MonadPlus (CatchT m) where
  mzero = CatchT $ return $ Left $ toException $ userError ""
  mplus (CatchT m) (CatchT n) = CatchT (m >>= either (const n) (return . Right))

-- Lifting marks the underlying action's result as a success.
instance MonadTrans CatchT where
  lift m = CatchT (liftM Right m)

-- Lifts the 'IO' action into @m@ and marks its result as a success. An
-- exception raised by the 'IO' action itself is NOT captured as a 'Left'.
instance MonadIO m => MonadIO (CatchT m) where
  liftIO m = CatchT (liftM Right (liftIO m))

-- Throwing stores the exception as a 'Left' in the underlying monad; nothing
-- is raised at the runtime level.
instance Monad m => MonadThrow (CatchT m) where
  throwM = CatchT . return . Left . toException
-- Runs the handler when the computation produced a 'Left' whose exception
-- matches the handler's type (via 'fromException'). Non-matching exceptions
-- are passed along unchanged, as are successful results.
instance Monad m => MonadCatch (CatchT m) where
  catch (CatchT m) c = CatchT $ m >>= \ea -> case ea of
    Left e  -> maybe (return (Left e)) (runCatchT . c) (fromException e)
    Right a -> return (Right a)
-- | Note: This instance is only valid if the underlying monad has a single
-- exit point!
--
-- For example, @IO@ or @Either@ would be invalid base monads, but
-- @Reader@ or @State@ would be acceptable.
--
-- 'mask' and 'uninterruptibleMask' are no-ops (the restore function is 'id').
-- 'generalBracket' only sees failures stored as a 'Left' inside 'CatchT'. If
-- the underlying monad aborts on its own, the release action is skipped,
-- which is why such base monads are invalid.
instance Monad m => MonadMask (CatchT m) where
  mask a = a id
  uninterruptibleMask a = a id
  generalBracket acquire release use = CatchT $ do
    eresource <- runCatchT acquire
    case eresource of
      -- Acquisition failed: nothing to release.
      Left e -> return (Left e)
      Right resource -> do
        eb <- runCatchT (use resource)
        case eb of
          -- Release, then rethrow the original exception. If release itself
          -- throws, that exception propagates instead.
          Left e -> runCatchT $ do
            _ <- release resource (ExitCaseException e)
            throwM e
          Right b -> runCatchT $ do
            c <- release resource (ExitCaseSuccess b)
            return (b, c)

-- State operations are lifted into @m@. Because the state lives below
-- 'CatchT', modifications made before an exception are kept after 'catch'.
instance MonadState s m => MonadState s (CatchT m) where
  get = lift get
  put = lift . put
#if MIN_VERSION_mtl(2,1,0)
  state = lift . state
#endif

-- 'ask' is lifted. 'local' modifies the environment for the wrapped
-- computation, whether it succeeds or fails.
instance MonadReader e m => MonadReader e (CatchT m) where
  ask = lift ask
  local f (CatchT m) = CatchT (local f m)

-- Writer operations are forwarded to @m@.
--
-- * 'listen' pairs the output with a successful result; a 'Left' is left as is.
-- * 'pass' applies the output-modifying function only on success; on
--   failure 'id' is used, so the output is kept unchanged.
instance MonadWriter w m => MonadWriter w (CatchT m) where
  tell = lift . tell
  listen = mapCatchT $ \inner -> do
    (ea, w) <- listen inner
    return $! fmap (\r -> (r, w)) ea
  pass = mapCatchT $ \inner -> pass $ do
    ea <- inner
    return $! case ea of
      Left  l      -> (Left l, id)
      Right (r, f) -> (Right r, f)
#if MIN_VERSION_mtl(2,1,0)
  writer aw = CatchT (Right `liftM` writer aw)
#endif

-- The instance body is empty; it relies on the 'MonadReader', 'MonadWriter'
-- and 'MonadState' instances for 'CatchT' defined in this module.
instance MonadRWS r w s m => MonadRWS r w s (CatchT m)

-- | Map the unwrapped computation using the given function.
--
-- @'runCatchT' ('mapCatchT' f m) = f ('runCatchT' m)@
--
-- The function may change both the underlying monad and the result type.
mapCatchT :: (m (Either SomeException a) -> n (Either SomeException b))
          -> CatchT m a
          -> CatchT n b
mapCatchT f = CatchT . f . runCatchT