module Control.NumericalMonad.State.Strict where
import Control.Applicative
import Control.Monad
import qualified Control.Monad.Fail as Fail
import Control.Monad.Fix
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Data.Foldable (Foldable(foldMap))
import Data.Traversable (Traversable(traverse))
-- | The trivial wrapper monad: a value with no additional effect.
-- Used below as the base monad for 'State'.
newtype Identity a = Identity
    { runIdentity :: a  -- ^ unwrap the value
    }
-- | Apply the function directly to the wrapped value.
instance Functor Identity where
    fmap f (Identity x) = Identity (f x)
    {-# INLINE fmap #-}
-- | An 'Identity' holds exactly one element, so 'foldMap' maps that element.
instance Foldable Identity where
    foldMap f = f . runIdentity
    {-# INLINE foldMap #-}
-- | Run the effectful function on the single element and re-wrap the result.
instance Traversable Identity where
    traverse f = fmap Identity . f . runIdentity
    {-# INLINE traverse #-}
-- | 'pure' wraps a value; '<*>' applies the wrapped function to the
-- wrapped argument.
instance Applicative Identity where
    pure = Identity
    {-# INLINE pure #-}
    mf <*> mx = Identity (runIdentity mf (runIdentity mx))
    {-# INLINE (<*>) #-}
-- | Binding unwraps the value and hands it straight to the continuation.
instance Monad Identity where
    -- Canonical definition: 'return' is 'pure' (the Applicative instance
    -- defines 'pure' directly), which keeps -Wnoncanonical-monad-instances
    -- quiet and stays valid if 'return' ever becomes a plain alias.
    return = pure
    {-# INLINE return #-}
    m >>= k = k (runIdentity m)
    {-# INLINE (>>=) #-}
-- | Value recursion: the result is defined in terms of itself by feeding the
-- unwrapped output of @f@ back in as its own input.
instance MonadFix Identity where
    mfix f = Identity result
      where
        result = runIdentity (f result)
    {-# INLINE mfix #-}
-- | A pure state computation: 'StateT' layered over the local 'Identity'
-- monad. Only @s@ is a parameter, so @State s@ can be used partially applied.
type State s = StateT s Identity
-- | Embed a pure state transition in 'StateT'; the transition's result pair
-- is returned via the underlying monad's 'return'.
state :: Monad m
      => (s -> (a, s))  -- ^ maps the current state to a result and new state
      -> StateT s m a
state f = StateT (\ s -> return (f s))
{-# INLINE state #-}
-- | Run a 'State' computation from an initial state, returning the result
-- together with the final state.
runState :: State s a  -- ^ computation to run
         -> s          -- ^ initial state
         -> (a, s)     -- ^ result and final state
runState m s = runIdentity (runStateT m s)
{-# INLINE runState #-}
-- | Run a 'State' computation and keep only its result (first component
-- of the pair produced by 'runState').
evalState :: State s a  -- ^ computation to run
          -> s          -- ^ initial state
          -> a          -- ^ result
evalState m = fst . runState m
{-# INLINE evalState #-}
-- | Run a 'State' computation and keep only its final state (second
-- component of the pair produced by 'runState').
execState :: State s a  -- ^ computation to run
          -> s          -- ^ initial state
          -> s          -- ^ final state
execState m = snd . runState m
{-# INLINE execState #-}
-- | Post-process both the result and final state of a 'State' computation
-- with a pure function on the pair.
mapState :: ((a, s) -> (b, s)) -> State s a -> State s b
mapState f = mapStateT (\ (Identity pair) -> Identity (f pair))
{-# INLINE mapState #-}
-- | Apply @f@ to the incoming state before running the computation;
-- the 'State'-specialised form of 'withStateT'.
withState :: (s -> s) -> State s a -> State s a
withState = withStateT
{-# INLINE withState #-}
-- | A state transformer: a function from an initial state to an action in
-- the base monad @m@ that yields a result and a new state.
newtype StateT s m a = StateT
    { runStateT :: s -> m (a, s)  -- ^ run the computation from a given state
    }
-- | Run a 'StateT' computation and return only its result.
-- The @(a, _)@ pattern is strict, so the result pair is forced when the
-- underlying monad passes it to the continuation.
evalStateT :: (Monad m) => StateT s m a -> s -> m a
evalStateT m s = runStateT m s >>= \ (a, _) -> return a
{-# INLINE evalStateT #-}
-- | Run a 'StateT' computation and return only its final state.
-- The @(_, s')@ pattern is strict, so the result pair is forced when the
-- underlying monad passes it to the continuation.
execStateT :: (Monad m) => StateT s m a -> s -> m s
execStateT m s = runStateT m s >>= \ (_, s') -> return s'
{-# INLINE execStateT #-}
-- | Transform the underlying action of a 'StateT' computation, possibly
-- changing the base monad and the result type; @f@ sees the action produced
-- for each incoming state.
mapStateT :: (m (a, s) -> n (b, s)) -> StateT s m a -> StateT s n b
mapStateT f m = StateT $ f . runStateT m
-- INLINE for consistency with every other combinator in this module, so the
-- wrapper can be eliminated at call sites in other modules.
{-# INLINE mapStateT #-}
-- | Apply @f@ to the incoming state before running the computation; the
-- result and the final state are left as the computation produces them.
withStateT :: (s -> s) -> StateT s m a -> StateT s m a
withStateT f m = StateT $ runStateT m . f
-- INLINE for consistency with every other combinator in this module.
{-# INLINE withStateT #-}
-- | Map over the result while passing the new state through unchanged.
-- The pair is matched with a strict pattern.
instance (Functor m) => Functor (StateT s m) where
    fmap f m = StateT $ \ s -> fmap onResult (runStateT m s)
      where
        onResult (a, s') = (f a, s')
    {-# INLINE fmap #-}
-- | Applicative structure for 'StateT': the state is threaded left to right,
-- first through the function computation, then through the argument.
instance (Functor m, Monad m) => Applicative (StateT s m) where
    -- Defined directly rather than as @pure = return@: that direction is
    -- non-canonical (flagged by -Wnoncanonical-monad-instances) and would
    -- loop if 'return' were ever defaulted to 'pure'.
    pure a = StateT $ \ s -> return (a, s)
    {-# INLINE pure #-}
    -- Same semantics as 'ap' (strict pair matches, left-to-right state),
    -- without routing through the Monad instance of 'StateT'.
    StateT mf <*> StateT mx = StateT $ \ s -> do
        (f, s') <- mf s
        (x, s'') <- mx s'
        return (f x, s'')
    {-# INLINE (<*>) #-}
-- | Choice for 'StateT', delegated to its 'MonadPlus' operations.
instance (Functor m, MonadPlus m) => Alternative (StateT s m) where
    empty = mzero
    {-# INLINE empty #-}
    (<|>) = mplus
    {-# INLINE (<|>) #-}
-- | Sequencing for 'StateT'. Each bind matches the @(result, state)@ pair
-- with a strict (non-lazy) pattern, so the pair produced by the first action
-- is forced when the continuation is entered.
instance (Monad m) => Monad (StateT s m) where
    -- Defined directly (not as 'pure') so this instance never depends on how
    -- the Applicative instance defines 'pure'.
    return a = StateT $ \ s -> return (a, s)
    {-# INLINE return #-}
    m >>= k = StateT $ \ s -> do
        (a, s') <- runStateT m s
        runStateT (k a) s'
    {-# INLINE (>>=) #-}

-- | Failure is delegated to the underlying monad; the incoming state is
-- discarded.
--
-- 'fail' is no longer a method of 'Monad' in base >= 4.13 (GHC 8.8), so
-- defining it inside the 'Monad' instance does not compile there; it lives
-- in its own 'Fail.MonadFail' instance instead.
instance (Fail.MonadFail m) => Fail.MonadFail (StateT s m) where
    fail str = StateT $ \ _ -> Fail.fail str
    {-# INLINE fail #-}
-- | 'mzero' ignores the state and fails in the base monad; 'mplus' runs
-- both alternatives from the same starting state and combines them with
-- the base monad's 'mplus'.
instance (MonadPlus m) => MonadPlus (StateT s m) where
    mzero = StateT (const mzero)
    {-# INLINE mzero #-}
    mplus m n = StateT $ \ s -> mplus (runStateT m s) (runStateT n s)
    {-# INLINE mplus #-}
-- | Value recursion through the base monad's 'mfix'. The fed-back pair is
-- only projected lazily (via 'fst'), so @f@ can refer to its own result
-- without forcing the pair early.
instance (MonadFix m) => MonadFix (StateT s m) where
    mfix f = StateT $ \ s -> mfix (\ pair -> runStateT (f (fst pair)) s)
    {-# INLINE mfix #-}
-- | Lift a base-monad action, pairing its result with the unchanged state.
instance MonadTrans (StateT s) where
    lift m = StateT $ \ s -> m >>= \ a -> return (a, s)
    {-# INLINE lift #-}
-- | Lift an 'IO' action via the base monad's 'liftIO' and then 'lift',
-- which pairs the result with the unchanged state.
instance (MonadIO m) => MonadIO (StateT s m) where
    liftIO = lift . liftIO
    -- INLINE for consistency with the other instance methods in this module.
    {-# INLINE liftIO #-}
-- | Return the current state as the result, leaving it unchanged.
get :: (Monad m) => StateT s m s
get = StateT $ \ s -> return (s, s)
{-# INLINE get #-}
-- | Replace the state with the given value, ignoring the old one.
put :: (Monad m) => s -> StateT s m ()
put s = state (const ((), s))
{-# INLINE put #-}
-- | Apply @f@ to the state. The new state is not forced; use 'modify'' to
-- evaluate it to WHNF immediately.
modify :: (Monad m) => (s -> s) -> StateT s m ()
modify f = state (\ old -> ((), f old))
{-# INLINE modify #-}
-- | Apply @f@ to the state, forcing the new state to WHNF (via '$!') before
-- storing it, which avoids accumulating thunks across repeated updates.
modify' :: (Monad m) => (s -> s) -> StateT s m ()
modify' f = get >>= \ old -> put $! f old
{-# INLINE modify' #-}
-- | Return a projection of the current state, leaving the state unchanged.
gets :: (Monad m) => (s -> a) -> StateT s m a
gets f = StateT $ \ s -> return (f s, s)
{-# INLINE gets #-}