{-# LANGUAGE CPP                        #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Monad.StateStack
-- Copyright   :  (c) 2011 Brent Yorgey
-- License     :  BSD-style (see LICENSE)
-- Maintainer  :  byorgey@cis.upenn.edu
--
-- A state monad which allows the state to be saved and restored on a
-- stack.
--
-- [Computation type:] Computations with implicit access to a
-- read/write state, with additional operations for pushing the
-- current state on a stack and later restoring the state from the top
-- of the stack.
--
-- [Binding strategy:] Same as for the usual state monad; the state
-- and accompanying stack of saved states are threaded through
-- computations.
--
-- [Useful for:] Remembering state while emitting commands for some
-- system which itself has saveable/restorable state, such as OpenGL
-- or Cairo.
--
-- Simple example:
--
-- > ghci> let p = get >>= liftIO . print
-- > ghci> evalStateStackT (put 2 >> p >> save >> put 3 >> p >> restore >> p) 0
-- > 2
-- > 3
-- > 2
--
-----------------------------------------------------------------------------

module Control.Monad.StateStack
       (
         -- * The @MonadStateStack@ class

         MonadStateStack(..)

         -- * The @StateStackT@ transformer

       , StateStackT(..), StateStack

         -- * Running @StateStackT@ and @StateStack@ computations

       , runStateStackT, evalStateStackT, execStateStackT
       , runStateStack,  evalStateStack,  execStateStack

       , liftState

       ) where

import           Control.Arrow                     (second)
import           Control.Arrow                     (first, (&&&))
import qualified Control.Monad.State               as St

import           Control.Monad
import           Control.Monad.Identity
import           Control.Monad.Trans
import           Control.Monad.Trans.Cont
import           Control.Monad.Trans.Except
import           Control.Monad.Trans.Maybe
import           Control.Monad.Trans.Reader        (ReaderT)
import           Control.Monad.Trans.State.Lazy    as Lazy
import           Control.Monad.Trans.State.Strict  as Strict
import           Control.Monad.Trans.Writer.Lazy   as Lazy
import           Control.Monad.Trans.Writer.Strict as Strict

import qualified Control.Monad.Cont.Class          as CC
import qualified Control.Monad.IO.Class            as IC


------------------------------------------------------------
--  Implementation
------------------------------------------------------------

-- | A monad transformer which adds a save/restorable state to an
--   existing monad.
--
--   Internally this is an ordinary 'St.StateT' whose state is a pair:
--   the first component is the current state and the second is the
--   stack (most recently saved state first) of saved states.
newtype StateStackT s m a = StateStackT
  { unStateStackT :: St.StateT (s,[s]) m a
    -- ^ Unwrap to the underlying 'St.StateT' computation over the
    --   (current state, saved-state stack) pair.
  }
  deriving (Functor, Applicative, Monad, MonadTrans, IC.MonadIO)

-- | Monads which, on top of the 'St.MonadState' interface, provide a
--   stack on which states can be saved and from which they can later
--   be restored.
class St.MonadState s m => MonadStateStack s m where
  -- | Save the current state on the stack.
  save :: m ()

  -- | Restore the top state from the stack.
  restore :: m ()

-- | The ordinary state operations only see the current state (the
--   first component of the underlying pair); the stack of saved
--   states is left untouched by both 'St.get' and 'St.put'.
instance Monad m => St.MonadState s (StateStackT s m) where
  -- Read the current state, ignoring the saved-state stack.
  get   = StateStackT $ St.gets fst
  -- Overwrite only the current state; 'first' keeps the stack as is.
  put s = StateStackT $ (St.modify . first) (const s)

-- | 'save' pushes the current state onto the stack (the current state
--   itself is unchanged); 'restore' pops the top of the stack and
--   makes it the current state.  Calling 'restore' with an empty
--   stack is a no-op.
instance Monad m => MonadStateStack s (StateStackT s m) where
  -- @fst &&& uncurry (:)@ maps @(cur, hist)@ to @(cur, cur : hist)@.
  save    = StateStackT $ St.modify (fst &&& uncurry (:))
  restore = StateStackT . St.modify $ \(cur,hist) ->
              case hist of
                -- Nothing saved: keep the current state as it is.
                []        -> (cur,hist)
                -- Pop the most recently saved state and make it current.
                (r:hist') -> (r,hist')

-- | Run a @StateStackT@ computation from an initial state, resulting
--   in a computation of the underlying monad which yields the return
--   value and final state.
--
--   The computation starts with an empty stack of saved states, and
--   whatever is left on the stack at the end is discarded: only the
--   final /current/ state is returned.
runStateStackT :: Monad m => StateStackT s m a -> s -> m (a, s)
runStateStackT m s =
  -- @(liftM . second) fst@ drops the leftover stack from the final
  -- @(a, (state, stack))@ result.
  (liftM . second) fst . flip St.runStateT (s,[]) . unStateStackT $ m

-- | Like 'runStateStackT', but discard the final state and yield only
--   the return value.
evalStateStackT :: Monad m => StateStackT s m a -> s -> m a
evalStateStackT m s = liftM fst $ runStateStackT m s

-- | Like 'runStateStackT', but discard the return value and yield
--   only the final state.
execStateStackT :: Monad m => StateStackT s m a -> s -> m s
execStateStackT m s = liftM snd $ runStateStackT m s

-- | A pure save/restorable state monad: 'StateStackT' applied to the
--   'Identity' monad.
--
--   The synonym is declared without its trailing result parameter so
--   that it may also be used partially applied (for example
--   @StateStack s@ where a monad is expected); @StateStack s a@ still
--   expands to @StateStackT s Identity a@ exactly as before.
type StateStack s = StateStackT s Identity

-- | Run a @StateStack@ computation from an initial state, resulting
--   in a pair of the final return value and final state.
runStateStack :: StateStack s a -> s -> (a,s)
runStateStack m s = runIdentity $ runStateStackT m s

-- | Like 'runStateStack', but discard the final state and yield only
--   the return value.
evalStateStack :: StateStack s a -> s -> a
evalStateStack m s = runIdentity $ evalStateStackT m s

-- | Like 'runStateStack', but discard the return value and yield
--   only the final state.
execStateStack :: StateStack s a -> s -> s
execStateStack m s = runIdentity $ execStateStackT m s

-- | @StateT@ computations can always be lifted to @StateStackT@
--   computations which do not manipulate the state stack.
--
--   The given computation is run on the current state only; the stack
--   of saved states is carried through unchanged and paired back up
--   with the new current state afterwards.
liftState :: Monad m => St.StateT s m a -> StateStackT s m a
liftState st = StateStackT . St.StateT $ \(s,ss) ->
  -- Run @st@ on the current state @s@, then re-attach the untouched
  -- stack @ss@ to the resulting state.
  (liftM . second) (flip (,) ss) (St.runStateT st s)

------------------------------------------------------------
--  Applying monad transformers to MonadStateStack monads
------------------------------------------------------------

-- | 'save' and 'restore' are forwarded to the underlying monad via
--   'lift'.
instance MonadStateStack s m => MonadStateStack s (ContT r m) where
  save    = lift save
  restore = lift restore

-- | 'save' and 'restore' are forwarded to the underlying monad via
--   'lift'.
instance MonadStateStack s m => MonadStateStack s (ExceptT e m) where
  save    = lift save
  restore = lift restore

-- | 'save' and 'restore' are forwarded to the underlying monad via
--   'lift'.
instance MonadStateStack s m => MonadStateStack s (IdentityT m) where
  save    = lift save
  restore = lift restore

-- | 'save' and 'restore' are forwarded to the underlying monad via
--   'lift'.
instance MonadStateStack s m => MonadStateStack s (MaybeT m) where
  save    = lift save
  restore = lift restore

-- | 'save' and 'restore' are forwarded to the underlying monad via
--   'lift'.
instance MonadStateStack s m => MonadStateStack s (ReaderT r m) where
  save    = lift save
  restore = lift restore

-- | 'save' and 'restore' are forwarded to the underlying monad via
--   'lift', so they act on the state stack of @m@.
--
--   NOTE(review): presumably the lazy @StateT@ layer's own state is
--   not part of what gets saved or restored here -- confirm that this
--   is the intended behaviour.
instance MonadStateStack s m => MonadStateStack s (Lazy.StateT s m) where
  save    = lift save
  restore = lift restore

-- | 'save' and 'restore' are forwarded to the underlying monad via
--   'lift', so they act on the state stack of @m@.
--
--   NOTE(review): presumably the strict @StateT@ layer's own state is
--   not part of what gets saved or restored here -- confirm that this
--   is the intended behaviour.
instance MonadStateStack s m => MonadStateStack s (Strict.StateT s m) where
  save    = lift save
  restore = lift restore

-- | 'save' and 'restore' are forwarded to the underlying monad via
--   'lift'.
instance (Monoid w, MonadStateStack s m) => MonadStateStack s (Lazy.WriterT w m) where
  save    = lift save
  restore = lift restore

-- | 'save' and 'restore' are forwarded to the underlying monad via
--   'lift'.
instance (Monoid w, MonadStateStack s m) => MonadStateStack s (Strict.WriterT w m) where
  save    = lift save
  restore = lift restore

------------------------------------------------------------
--  Applying StateStackT to other monads
------------------------------------------------------------

-- | 'CC.callCC' is delegated to the 'CC.MonadCont' instance of the
--   wrapped state transformer: the escape continuation @k@ obtained
--   there is re-wrapped in 'StateStackT' before being handed to the
--   user-supplied function, and the result is unwrapped again.
instance CC.MonadCont m => CC.MonadCont (StateStackT s m) where
  callCC c = StateStackT $ CC.callCC (unStateStackT . (\k -> c (StateStackT . k)))

{-  -- These require UndecidableInstances =(
instance EC.MonadError e m => EC.MonadError e (StateStackT s m) where
  throwError     = lift . EC.throwError
  catchError m h = StateStackT $ EC.catchError (unStateStackT m) (unStateStackT . h)

instance RC.MonadReader r m => RC.MonadReader r (StateStackT s m) where
  ask     = lift RC.ask
  local f = StateStackT . RC.local f . unStateStackT
-}