{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
module Control.Monad.Trans.RSS.Lazy (
    -- * The RSS monad
    RSS,
    rss,
    runRSS,
    evalRSS,
    execRSS,
    withRSS,
    -- * The RSST monad transformer
    RSST,
    runRSST,
    evalRSST,
    execRSST,
    withRSST,
    -- * Helpers
    liftCatch
  ) where

import Control.Applicative
import Control.Monad
import Control.Monad.Fix
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Except
import Control.Monad.Signatures
import Data.Functor.Identity

import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Writer
import Control.Monad.RWS

-- | A monad containing an environment of type @r@, output of type @w@
-- and an updatable state of type @s@.
--
-- This is 'RSST' specialised to the 'Identity' base monad.
type RSS r w s = RSST r w s Identity

-- | Build an 'RSS' computation out of a pure step function.
-- (The inverse of 'runRSS'.)
--
-- The step function receives the environment and the current state and
-- yields a result, a new state and some output.  That output is appended
-- (with '<>') to the output already accumulated by earlier actions.
rss :: Monoid w => (r -> s -> (a, s, w)) -> RSS r w s a
rss step = RSST go
  where
    go env (st, acc) = Identity (val, (st', acc <> out))
      where
        -- Lazy pattern binding: @step@ is evaluated only when one of its
        -- three results is actually demanded.
        (val, st', out) = step env st

-- | Unwrap an RSS computation as a function.
-- (The inverse of 'rss'.)
--
-- Takes the environment and the initial state and returns the result, the
-- final state and the collected output.  This is 'runRSST' with the
-- 'Identity' wrapper removed.
runRSS :: Monoid w => RSS r w s a -> r -> s -> (a, s, w)
runRSS m r s = runIdentity (runRSST m r s)

-- | Run a computation from the given environment and initial state,
-- keeping the final value and the output and dropping the final state.
evalRSS :: Monoid w
        => RSS r w s a  -- ^ RSS computation to execute
        -> r            -- ^ initial environment
        -> s            -- ^ initial state
        -> (a, w)       -- ^ final value and output
evalRSS m r s = (result, output)
  where
    -- Lazy pattern binding; the final state is simply never looked at.
    (result, _, output) = runRSS m r s

-- | Run a computation from the given environment and initial state,
-- keeping the final state and the output and dropping the final value.
execRSS :: Monoid w
        => RSS r w s a  -- ^ RSS computation to execute
        -> r            -- ^ initial environment
        -> s            -- ^ initial state
        -> (s, w)       -- ^ final state and output
execRSS m r s = (finalState, output)
  where
    -- Lazy pattern binding; the final value is simply never looked at.
    (_, finalState, output) = runRSS m r s

-- | @'withRSS' f m@ executes action @m@ with an initial environment
-- and state modified by applying @f@.
--
-- * @'runRSS' ('withRSS' f m) r s = 'uncurry' ('runRSS' m) (f r s)@
--
-- This is 'withRSST' used at the 'RSS' type.
withRSS :: (r' -> s -> (r, s)) -> RSS r w s a -> RSS r' w s a
withRSS = withRSST

---------------------------------------------------------------------------
-- | A monad transformer adding reading an environment of type @r@,
-- collecting an output of type @w@ and updating a state of type @s@
-- to an inner monad @m@.
--
-- The state and the output collected so far travel together as a pair:
-- a computation receives that pair and returns an updated one next to
-- its result.
newtype RSST r w s m a = RSST
    { runRSST' :: r -> (s, w) -> m (a, (s, w))
      -- ^ Run the computation with an environment and an explicit
      -- (state, accumulated output) pair.
    }

-- | Unwrap an 'RSST' computation as a function of the environment and the
-- initial state.  The output accumulator starts out as 'mempty'; the
-- result, final state and collected output are returned in the inner monad.
runRSST :: (Monoid w, Monad m) => RSST r w s m a -> r -> s -> m (a, s, w)
runRSST m r s =
    -- The irrefutable pattern keeps the result tuple from being forced here.
    runRSST' m r (s, mempty) >>= \ ~(a, (s', w)) -> return (a, s', w)

-- | Run a computation from the given environment and initial state,
-- keeping the final value and the output and dropping the final state.
evalRSST :: (Monad m, Monoid w)
            => RSST r w s m a     -- ^ computation to execute
            -> r                  -- ^ initial environment
            -> s                  -- ^ initial state
            -> m (a, w)           -- ^ computation yielding final value and output
evalRSST m r s = runRSST' m r (s, mempty) >>= keep
  where
    -- Irrefutable match: the inner result is not forced by this wrapper.
    keep ~(a, (_, w)) = return (a, w)

-- | Run a computation from the given environment and initial state,
-- keeping the final state and the output and dropping the final value.
execRSST :: (Monad m, Monoid w)
            => RSST r w s m a      -- ^ computation to execute
            -> r                   -- ^ initial environment
            -> s                   -- ^ initial state
            -> m (s, w)            -- ^ computation yielding final state and output
execRSST m r s = runRSST' m r (s, mempty) >>= keep
  where
    -- Irrefutable match: the inner result is not forced by this wrapper.
    keep ~(_, (s', w)) = return (s', w)

-- | @'withRSST' f m@ runs @m@ after @f@ has turned the outer environment
-- and the incoming state into the environment and state that @m@ expects.
-- The output accumulated so far is handed to @m@ untouched.
--
-- * @'runRSST' ('withRSST' f m) r s = 'uncurry' ('runRSST' m) (f r s)@
withRSST :: (r' -> s -> (r, s)) -> RSST r w s m a -> RSST r' w s m a
withRSST f m = RSST run
  where
    run outerEnv (st, out) = runRSST' m innerEnv (st', out)
      where
        -- Lazy pattern binding: @f@ runs only once @m@ demands its
        -- environment or its state.
        (innerEnv, st') = f outerEnv st

-- Maps over the result only; state and output are passed through unchanged.
instance (Functor m) => Functor (RSST r w s m) where
    fmap f m = RSST $ \env st -> onResult <$> runRSST' m env st
      where
        -- Irrefutable match, and the (state, output) pair is rebuilt, so
        -- mapping never forces the inner result.
        onResult ~(a, (s', w)) = (f a, (s', w))

-- Sequencing threads the (state, output) pair from the first action into
-- the second; both actions see the same environment.
instance (Monad m) => Monad (RSST r w s m) where
    return = pure
    m >>= k = RSST $ \env st ->
        -- Irrefutable match: the first action's result tuple is only
        -- forced if the continuation needs it.
        runRSST' m env st >>= \ ~(a, (st', out)) ->
            runRSST' (k a) env (st', out)

-- Failure is delegated to the inner monad; the environment and the
-- (state, output) pair are ignored.
instance (MonadFail m) => MonadFail (RSST r w s m) where
    fail msg = RSST failing
      where
        failing _ _ = fail msg

-- 'mzero' and 'mplus' are the 'Alternative' operations of this transformer.
instance (MonadPlus m) => MonadPlus (RSST r w s m) where
    mzero = empty
    mplus left right = left <|> right

-- 'pure' returns its argument and leaves the (state, output) pair as it
-- was; '<*>' is derived from the 'Monad' instance via 'ap'.
instance (Functor m, Monad m) => Applicative (RSST r w s m) where
    pure a = RSST inject
      where
        inject _ st = pure (a, st)
    mf <*> mx = ap mf mx

-- Choice is delegated to the inner monad.  Both alternatives are started
-- from the same environment and the same (state, output) pair.
instance (Functor m, MonadPlus m) => Alternative (RSST r w s m) where
    empty = RSST nothing
      where
        nothing _ _ = empty
    m <|> n = RSST choose
      where
        choose env st = runRSST' m env st <|> runRSST' n env st

-- The fixed point is taken in the inner monad over the whole result tuple;
-- only its first component (the value) is fed back into @f@.
instance (MonadFix m) => MonadFix (RSST r w s m) where
    mfix f = RSST tie
      where
        -- The match has to be irrefutable: the tuple does not exist yet
        -- when the function passed to 'mfix' is entered.
        tie env st = mfix (\ ~(a, _) -> runRSST' (f a) env st)

-- A lifted action runs in the inner monad and returns the (state, output)
-- pair exactly as it received it.
instance MonadTrans (RSST r w s) where
    lift m = RSST $ \_ st -> m >>= \a -> return (a, st)

-- 'IO' actions are first lifted into the inner monad, then into 'RSST'.
instance (MonadIO m) => MonadIO (RSST r w s m) where
    liftIO io = lift (liftIO io)

-- State operations touch only the first component of the threaded pair;
-- the accumulated output is carried along unchanged.
instance Monad m => MonadState s (RSST r w s m) where
    -- Return the current state and leave it in place.
    get = RSST fetch
      where
        fetch _ (st, out) = return (st, (st, out))

    -- Discard the old state and install the new one.
    put newState = RSST store
      where
        store _ (_, out) = return ((), (newState, out))

    -- Apply a pure state transition; the transition's result pair is
    -- forced by the @case@ before anything is returned.
    state f = RSST step
      where
        step _ (st, out) =
            case f st of
                (a, st') -> return (a, (st', out))



-- Reader operations consult the environment only; the (state, output)
-- pair is returned as it came in.
instance Monad m => MonadReader r (RSST r w s m) where
    -- Return the environment itself.
    ask = RSST fetch
      where
        fetch env st = return (env, st)

    -- Run the inner computation under an environment modified by @f@.
    local f rw = RSST scoped
      where
        scoped env st = runRSST' rw (f env) st

    -- Return a projection of the environment.
    reader f = RSST project
      where
        project env st = return (f env, st)

-- Writer operations work on the second component of the threaded pair:
-- new output is appended on the right of what has been accumulated so far.
instance (Monoid w, Monad m) => MonadWriter w (RSST r w s m) where
    -- Emit the output, then return the value.
    writer (a, w) = tell w >> return a

    -- Append @w@ to the accumulated output.
    tell w = RSST $ \_ (s, ow) -> return ((), (s, ow <> w))

    -- Run @rw@ with an empty accumulator so that exactly its own output can
    -- be observed, then append that output to what was collected before.
    -- The match is irrefutable, like the ones in 'fmap', '>>=' and 'mfix',
    -- so 'listen' does not force the inner result.
    listen rw = RSST $ \r (s, w) -> do
        ~(a, (ns, nw)) <- runRSST' rw r (s, mempty)
        return ((a, nw), (ns, w <> nw))

    -- Run @rw@ with an empty accumulator, apply the returned function to
    -- the output it produced, and append the transformed output to what was
    -- collected before.  The match is irrefutable for the same reason as in
    -- 'listen'.
    pass rw = RSST $ \r (s, w) -> do
        ~((a, fw), (s', w')) <- runRSST' rw r (s, mempty)
        return (a, (s', w <> fw w'))

-- 'RSST' is declared an instance of the combined 'MonadRWS' class.  The
-- declaration has no body, so no methods are defined here.
instance (Monoid w, Monad m) => MonadRWS r w s (RSST r w s m)

-- Errors belong to the inner monad: throwing is a plain 'lift', and
-- catching goes through 'liftCatch'.
instance (Monoid w, MonadError e m) => MonadError e (RSST r w s m) where
  throwError err = lift (throwError err)
  catchError action handler = liftCatch catchError action handler

-- | Lift a @catchE@ operation to the new monad.
--
-- The protected action and the handler are both started from the same
-- environment and the same (state, output) pair, so the handler does not
-- see any state or output changes made by the action.
liftCatch :: Catch e m (a,(s,w)) -> Catch e (RSST r w s m) a
liftCatch catchE m h = RSST guarded
  where
    guarded env st = catchE (runRSST' m env st) recover
      where
        recover err = runRSST' (h err) env st
{-# INLINE liftCatch #-}