------------------------------------------------------------------------
-- |
-- Module           : Lang.Crucible.Utils.StateContT
-- Description      : A monad providing continuations and state.
-- Copyright        : (c) Galois, Inc 2013-2014
-- License          : BSD3
-- Maintainer       : Joe Hendrix <jhendrix@galois.com>
-- Stability        : provisional
--
-- This module defines a monad transformer with continuations and state.  By
-- using it instead of an MTL StateT and ContT transformer stack, one can have
-- a continuation that implements MonadCont and MonadState, yet never returns
-- the final state.  It also lifts MonadFail, MonadIO, MonadST, MonadReader,
-- MonadThrow and MonadCatch from the underlying monad.
------------------------------------------------------------------------
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
module Lang.Crucible.Utils.StateContT
  ( StateContT(..)
    -- * Re-exports
  , Control.Monad.Cont.Class.MonadCont(..)
  , Control.Monad.State.Class.MonadState(..)
  ) where

import Control.Monad.Cont.Class   (MonadCont(..))
import Control.Monad.IO.Class     (MonadIO(..))
import Control.Monad.Reader.Class (MonadReader(..))
import Control.Monad.State.Class  (MonadState(..))
import Control.Monad.Trans (MonadTrans(..))
import Control.Monad.Catch ( MonadThrow(..), MonadCatch(..) )

import What4.Utils.MonadST

-- | A monad transformer that provides @MonadCont@ and @MonadState@.
--
-- A computation is stored in continuation-passing style. 'runStateContT'
-- first takes the continuation, which consumes the result together with
-- the current state. It then takes the initial state and produces the
-- final answer of type @m r@.
newtype StateContT s r m a = StateContT
  { runStateContT :: (a -> s -> m r) -> s -> m r
  }

-- | Apply a pure function to the result of a 'StateContT' computation.
--
-- The mapped value is forced to weak head normal form before it is passed
-- to the continuation. The state is passed through untouched.
fmapStateContT :: (a -> b) -> StateContT s r m a -> StateContT s r m b
fmapStateContT = \g act ->
  StateContT $ \k ->
    let mapped x st = let y = g x in y `seq` k y st
     in runStateContT act mapped
{-# INLINE fmapStateContT #-}

-- | Sequential application. The function-producing computation runs first,
-- then the argument-producing one.
--
-- The applied result is forced to weak head normal form before it reaches
-- the continuation.
applyStateContT :: StateContT s r m (a -> b) -> StateContT s r m a -> StateContT s r m b
applyStateContT = \actF actX ->
  StateContT $ \k ->
    let withFun g =
          runStateContT actX (\x st -> let y = g x in y `seq` k y st)
     in runStateContT actF withFun
{-# INLINE applyStateContT #-}

-- | Inject a value into 'StateContT'.
--
-- The value is forced to weak head normal form before the computation is
-- built. When run, the computation hands the value straight to the
-- continuation.
returnStateContT :: a -> StateContT s r m a
returnStateContT = \x -> x `seq` StateContT (\k -> k x)
{-# INLINE returnStateContT #-}

-- | Monadic bind. Runs the first computation, and its continuation runs the
-- computation selected by the result under the outer continuation.
bindStateContT :: StateContT s r m a -> (a -> StateContT s r m b) -> StateContT s r m b
bindStateContT = \act next ->
  StateContT $ \k ->
    runStateContT act (flip runStateContT k . next)
{-# INLINE bindStateContT #-}

-- | 'fmap' is 'fmapStateContT'.
instance Functor (StateContT s r m) where
  fmap g act = fmapStateContT g act

-- | 'pure' is 'returnStateContT' and '<*>' is 'applyStateContT'.
instance Applicative (StateContT s r m) where
  pure x = returnStateContT x
  actF <*> actX = applyStateContT actF actX

-- | '>>=' is 'bindStateContT'.
instance Monad (StateContT s r m) where
  act >>= next = bindStateContT act next

-- | Failure aborts the whole computation. Both the continuation and the
-- current state are discarded, and the underlying monad's 'fail' is used.
instance MonadFail m => MonadFail (StateContT s r m) where
  fail msg = StateContT (\_ _ -> fail msg)

-- | Call with current continuation.
--
-- The escape function forces its argument to weak head normal form. It then
-- drops whatever continuation is active at the point of escape and resumes
-- the continuation captured by 'callCC'. The state current at the moment of
-- escape is carried along; the state is not rolled back to its value at the
-- 'callCC' call.
instance MonadCont (StateContT s r m) where
  callCC f = StateContT $ \k ->
    let escape x = x `seq` StateContT (\_ st -> k x st)
     in runStateContT (f escape) k

-- | State access threaded through the continuation.
--
-- 'put' forces the new state to weak head normal form. 'state' forces the
-- new state and then the result before invoking the continuation.
instance MonadState s (StateContT s r m) where
  -- Hand the current state to the continuation as the result; leave it unchanged.
  get = StateContT (\k st -> k st st)
  -- Replace the state; the previous state is ignored.
  put st = st `seq` StateContT (\k _ -> k () st)
  state f = StateContT $ \k st0 ->
    let (x, st1) = f st0
     in st1 `seq` x `seq` k x st1

-- | Lift an action from the underlying monad.
--
-- The action's result is forced to weak head normal form, then passed to
-- the continuation with the state unchanged.
instance MonadTrans (StateContT s r) where
  lift act = StateContT $ \k st -> do
    x <- act
    x `seq` k x st

-- | Lift an 'IO' action into the underlying monad, then into 'StateContT'
-- via 'lift'.
instance MonadIO m => MonadIO (StateContT s r m) where
  liftIO io = lift (liftIO io)

-- | Lift an 'ST' action into the underlying monad, then into 'StateContT'
-- via 'lift'. The 'ST' thread parameter @s@ is unrelated to this
-- transformer's own state type @t@.
instance MonadST s m => MonadST s (StateContT t r m) where
  liftST st = lift (liftST st)

-- | Reader access is delegated to the underlying monad.
--
-- 'local' runs the wrapped computation under the modified environment. The
-- continuation (everything sequenced after the 'local' block) is run under
-- the environment that was in effect when 'local' was entered. Without this
-- restoration, the modification would leak into the rest of the program,
-- and @local f (pure ()) >> ask@ would return the modified environment.
-- The same technique is used by @liftLocal@ for 'ContT' in transformers.
--
-- NOTE(review): an escape continuation captured by 'callCC' outside the
-- 'local' block and invoked inside it still runs under the modified
-- environment. @liftLocal@ has the same limitation.
instance MonadReader v m => MonadReader v (StateContT s r m) where
  ask = lift ask
  local f act = StateContT $ \k st -> do
    outer <- ask
    -- Continue the rest of the computation with the caller's environment.
    let restore x st' = local (const outer) (k x st')
    local f (runStateContT act restore st)

-- | Throwing aborts the computation. The continuation and state are
-- discarded, and the exception is raised in the underlying monad.
instance MonadThrow m => MonadThrow (StateContT s r m) where
  throwM ex = StateContT $ \_ _ -> throwM ex

-- | Catch exceptions using the underlying monad's 'catch'.
--
-- Two consequences of the implementation:
--
-- * The continuation @k@ runs inside the 'catch' scope. Exceptions raised
--   by code sequenced *after* this block are therefore also delivered to
--   the handler, and the handler then invokes @k@ again.
--
-- * The handler restarts from the state at entry to this block, so state
--   changes made by the protected computation before it threw are dropped.
instance MonadCatch m => MonadCatch (StateContT s r m) where
  catch act handler = StateContT $ \k st ->
    let attempt = runStateContT act k st
        recover e = runStateContT (handler e) k st
     in catch attempt recover