{-# LANGUAGE CPP, MultiParamTypeClasses, FunctionalDependencies,
UndecidableInstances, FlexibleInstances #-}
{-# LANGUAGE DataKinds, TypeFamilies, TypeOperators #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Safe #-}
module MonadLib (
Id, Lift, IdT, ReaderT, WriterT,
StateT,
ExceptionT,
ChoiceT, ContT,
MonadT(..), BaseM(..),
ReaderM(..), WriterM(..), StateM(..), ExceptionM(..), ContM(..), AbortM(..),
Label, labelCC, labelCC_, jump, labelC, callCC,
runId, runLift,
runIdT, runReaderT, runWriterT,
runStateT, runExceptionT, runContT,
runChoiceT, findOne, findAll,
RunM(..),
RunReaderM(..), RunWriterM(..), RunExceptionM(..),
asks, puts, sets, sets_, raises,
mapReader, mapWriter, mapException,
handle,
WithBase,
module Control.Monad
) where
#if __GLASGOW_HASKELL__ < 800
import Data.Monoid
#endif
import Control.Applicative
import Control.Monad
import Control.Monad.Fix
import Control.Monad.ST (ST)
import qualified Control.Exception as IO (throwIO,try)
#ifdef USE_BASE3
import qualified Control.Exception as IO (Exception)
#else
import qualified Control.Exception as IO (SomeException)
#endif
import System.Exit(ExitCode,exitWith)
import Data.Kind(Type)
import Prelude hiding (Ordering(..))
#if __GLASGOW_HASKELL__ >= 800
import qualified Control.Monad.Fail as MF
import Control.Monad.Fail(MonadFail)
#endif
-- | The identity monad: a computation that simply yields its value.
newtype Id a = I a

-- | Like 'Id', but a @data@ type (with a lazy field), so pattern matching
-- on it forces the 'L' constructor.
data Lift a = L a

-- | The identity transformer: adds nothing to the underlying monad.
newtype IdT m a = IT (m a)

-- | Reader transformer: computations that read an environment of type @i@.
newtype ReaderT i m a = R (i -> m a)

-- | Writer transformer: computations that accumulate output of type @i@.
newtype WriterT i m a = W { unW :: m (P a i) }

-- | A result paired with its output.  The output field is strict, so it is
-- evaluated to WHNF whenever a 'P' is constructed.
data P a i = P a !i

-- | State transformer: threads a state of type @i@ through the computation.
newtype StateT i m a = S (i -> m (a,i))

-- | Exception transformer: computations that may end with a value of type
-- @i@ instead of a result.
newtype ExceptionT i m a = X (m (Either i a))

-- | Backtracking transformer: a tree of answers whose branches may perform
-- effects in @m@.
data ChoiceT m a
  = NoAnswer                           -- ^ no answers
  | Answer a                           -- ^ exactly one answer
  | Choice (ChoiceT m a) (ChoiceT m a) -- ^ answers of the left branch, then the right
  | ChoiceEff (m (ChoiceT m a))        -- ^ run an effect to obtain the rest of the tree

-- | Continuation transformer whose final answer has type @i@.
newtype ContT i m a = C ((a -> m i) -> m i)
-- | Extract the value of an 'Id' computation.
runId :: Id a -> a
runId m = case m of I a -> a

-- | Extract the value of a 'Lift' computation (forces the 'L' constructor).
runLift :: Lift a -> a
runLift m = case m of L a -> a

-- | Unwrap an 'IdT' computation into the underlying monad.
runIdT :: IdT m a -> m a
runIdT m = case m of IT inner -> inner

-- | Run a reader computation with the given environment.
runReaderT :: i -> ReaderT i m a -> m a
runReaderT env (R f) = f env

-- | Run a writer computation, pairing its result with the collected output.
-- The lazy pattern means the inner 'P' is only inspected on demand.
runWriterT :: (Monad m) => WriterT i m a -> m (a,i)
runWriterT (W m) = liftM (\ ~(P a w) -> (a, w)) m

-- | Run a state computation from the given initial state, returning the
-- result together with the final state.
runStateT :: i -> StateT i m a -> m (a,i)
runStateT s0 (S step) = step s0

-- | Run an exception computation: 'Left' holds a raised exception, 'Right'
-- a normal result.
runExceptionT :: ExceptionT i m a -> m (Either i a)
runExceptionT (X inner) = inner
-- | Search for the first answer of a 'ChoiceT' computation.  On success the
-- answer is returned together with a computation producing the remaining
-- answers; 'Nothing' means there are no answers.  Left branches of a
-- 'Choice' are explored before right branches.
runChoiceT :: (Monad m) => ChoiceT m a -> m (Maybe (a,ChoiceT m a))
runChoiceT tree =
  case tree of
    NoAnswer    -> return Nothing
    Answer a    -> return (Just (a, NoAnswer))
    ChoiceEff m -> m >>= runChoiceT
    Choice l r  -> do
      res <- runChoiceT l
      case res of
        Just (a, rest) -> return (Just (a, Choice rest r))
        Nothing        -> runChoiceT r

-- | The first answer of a 'ChoiceT' computation, if there is one.
findOne :: (Monad m) => ChoiceT m a -> m (Maybe a)
findOne m = liftM (fmap fst) (runChoiceT m)

-- | All answers of a 'ChoiceT' computation, in search order.
findAll :: (Monad m) => ChoiceT m a -> m [a]
findAll m =
  runChoiceT m >>= maybe (return []) (\(a, rest) -> liftM (a :) (findAll rest))

-- | Run a continuation computation with the given final continuation.
runContT :: (a -> m i) -> ContT i m a -> m i
runContT k (C body) = body k
-- | Monads whose computations can be run.  The runner type @r@ is
-- determined by the monad and the result type (@m a -> r@).  Each layer
-- contributes its own extra parameters, such as an environment, an initial
-- state or a final continuation, in front of the runner of the layer below.
class Monad m => RunM m a r | m a -> r where
  runM :: m a -> r

instance RunM Id a a where
  runM m = runId m

instance RunM Lift a a where
  runM m = runLift m

-- | 'IO' computations are returned unchanged.
instance RunM IO a (IO a) where
  runM m = m

instance RunM m a r => RunM (IdT m) a r where
  runM m = runM (runIdT m)

-- | Takes the environment as an extra argument.
instance RunM m a r => RunM (ReaderT i m) a (i -> r) where
  runM m env = runM (runReaderT env m)

-- | The collected output is paired with the result.
instance (Monoid i, RunM m (a,i) r) => RunM (WriterT i m) a r where
  runM m = runM (runWriterT m)

-- | Takes the initial state; the final state is paired with the result.
instance RunM m (a,i) r => RunM (StateT i m) a (i -> r) where
  runM m s0 = runM (runStateT s0 m)

-- | The result is wrapped in 'Either'.
instance RunM m (Either i a) r => RunM (ExceptionT i m) a r where
  runM m = runM (runExceptionT m)

-- | Takes the final continuation as an extra argument.
instance RunM m i r => RunM (ContT i m) a ((a -> m i) -> r) where
  runM m k = runM (runContT k m)

-- | Produces the first answer (if any) together with the remaining search.
instance RunM m (Maybe (a,ChoiceT m a)) r => RunM (ChoiceT m) a r where
  runM m = runM (runChoiceT m)
-- | Monad transformers: 'lift' embeds a computation of the underlying monad.
class MonadT t where
  lift :: (Monad m) => m a -> t m a

instance MonadT IdT where
  lift = IT

-- | The environment is ignored.
instance MonadT (ReaderT i) where
  lift m = R (const m)

-- | The state is passed through unchanged.
instance MonadT (StateT i) where
  lift m = S (\s -> liftM (flip (,) s) m)

-- | The lifted computation produces no output.
instance (Monoid i) => MonadT (WriterT i) where
  lift m = W (liftM (`P` mempty) m)

-- | The lifted computation never raises.
instance MonadT (ExceptionT i) where
  lift m = X (liftM Right m)

-- | The lifted computation yields exactly one answer.
instance MonadT ChoiceT where
  lift m = ChoiceEff (liftM Answer m)

-- | The lifted computation's result is passed to the continuation.
instance MonadT (ContT i) where
  lift m = C (m >>=)
-- Helpers that lift an operation of the underlying monad through a single
-- transformer layer @t@.  Each one is 'lift' applied to the corresponding
-- operation of the inner monad; instances use them to pass operations down
-- the stack.

-- | Lift 'inBase' through one layer.
t_inBase :: (MonadT t, BaseM m n) => n a -> t m a
t_inBase m = lift (inBase m)

-- | Lift 'return' through one layer.
t_return :: (MonadT t, Monad m) => a -> t m a
t_return x = lift (return x)

#if __GLASGOW_HASKELL__ < 808
-- | Lift 'Monad'\'s @fail@ through one layer.  Only compiled while @fail@ is
-- still a 'Monad' method (base < 4.13, GHC < 8.8).  On newer compilers
-- Prelude's @fail@ requires 'MonadFail', so this definition would not
-- type-check.
t_fail :: (MonadT t, Monad m) => String -> t m a
t_fail x = lift (fail x)
#endif

#if __GLASGOW_HASKELL__ >= 800
-- | Lift 'MF.fail' through one layer.  Guarded by the same condition as the
-- "Control.Monad.Fail" imports it depends on.
t_fail' :: (MonadT t, MonadFail m) => String -> t m a
t_fail' x = lift (MF.fail x)
#endif

-- | Lift 'mzero' through one layer.
t_mzero :: (MonadT t, MonadPlus m) => t m a
t_mzero = lift mzero

-- | Lift 'ask' through one layer.
t_ask :: (MonadT t, ReaderM m i) => t m i
t_ask = lift ask

-- | Lift 'put' through one layer.
t_put :: (MonadT t, WriterM m i) => i -> t m ()
t_put x = lift (put x)

-- | Lift 'get' through one layer.
t_get :: (MonadT t, StateM m i) => t m i
t_get = lift get

-- | Lift 'set' through one layer.
t_set :: (MonadT t, StateM m i) => i -> t m ()
t_set i = lift (set i)

-- | Lift 'raise' through one layer.
t_raise :: (MonadT t, ExceptionM m i) => i -> t m a
t_raise i = lift (raise i)

-- | Lift 'abort' through one layer.
t_abort :: (MonadT t, AbortM m i) => i -> t m a
t_abort i = lift (abort i)
-- | Monads built on top of a base monad @n@; 'inBase' runs a computation of
-- the base monad inside @m@.  The base is determined by @m@.
class (Monad m, Monad n) => BaseM m n | m -> n where
  inBase :: n a -> m a

-- Base monads are their own base.
instance BaseM IO     IO     where inBase m = m
instance BaseM Maybe  Maybe  where inBase m = m
instance BaseM []     []     where inBase m = m
instance BaseM Id     Id     where inBase m = m
instance BaseM Lift   Lift   where inBase m = m
instance BaseM (ST s) (ST s) where inBase m = m

-- Transformers have the same base as the monad they transform.
instance (BaseM m n) => BaseM (IdT m) n where
  inBase m = t_inBase m
instance (BaseM m n) => BaseM (ReaderT i m) n where
  inBase m = t_inBase m
instance (BaseM m n) => BaseM (StateT i m) n where
  inBase m = t_inBase m
instance (BaseM m n, Monoid i) => BaseM (WriterT i m) n where
  inBase m = t_inBase m
instance (BaseM m n) => BaseM (ExceptionT i m) n where
  inBase m = t_inBase m
instance (BaseM m n) => BaseM (ChoiceT m) n where
  inBase m = t_inBase m
instance (BaseM m n) => BaseM (ContT i m) n where
  inBase m = t_inBase m
-- Monad instances.
--
-- NOTE: from base-4.13 (GHC 8.8) onwards 'fail' is no longer a method of
-- 'Monad', and binding it in a 'Monad' instance is a compile error.  The
-- 'fail' bindings below are therefore only compiled by older compilers.

instance Monad Id where
  return x = I x
#if __GLASGOW_HASKELL__ < 808
  fail x = error x
#endif
  m >>= k = k (runId m)

-- | Binding pattern-matches on 'L', so the left computation is forced.
instance Monad Lift where
  return x = L x
#if __GLASGOW_HASKELL__ < 808
  fail x = error x
#endif
  L x >>= k = k x

instance (Monad m) => Monad (IdT m) where
  return = t_return
#if __GLASGOW_HASKELL__ < 808
  fail = t_fail
#endif
  m >>= k = IT (runIdT m >>= (runIdT . k))

-- | Both computations see the same environment.
instance (Monad m) => Monad (ReaderT i m) where
  return = t_return
#if __GLASGOW_HASKELL__ < 808
  fail = t_fail
#endif
  m >>= k = R (\r -> runReaderT r m >>= \a -> runReaderT r (k a))

-- | The final state of the first computation is the initial state of the
-- second.  The lazy pattern means '>>=' does not force the result pair.
instance (Monad m) => Monad (StateT i m) where
  return = t_return
#if __GLASGOW_HASKELL__ < 808
  fail = t_fail
#endif
  m >>= k = S (\s -> runStateT s m >>= \ ~(a,s') -> runStateT s' (k a))

-- | The outputs of the two computations are combined with 'mappend', left
-- output first.
instance (Monad m,Monoid i) => Monad (WriterT i m) where
  return = t_return
#if __GLASGOW_HASKELL__ < 808
  fail = t_fail
#endif
  m >>= k = W $ unW m >>= \ ~(P a w1) ->
                unW (k a) >>= \ ~(P b w2) ->
                return (P b (mappend w1 w2))

-- | A raised exception ('Left') skips the rest of the computation.
instance (Monad m) => Monad (ExceptionT i m) where
  return = t_return
#if __GLASGOW_HASKELL__ < 808
  fail = t_fail
#endif
  m >>= k = X $ runExceptionT m >>= \e ->
              case e of
                Left x  -> return (Left x)
                Right a -> runExceptionT (k a)

-- | Binding applies the continuation to every answer in the search tree.
instance (Monad m) => Monad (ChoiceT m) where
  return x = Answer x
#if __GLASGOW_HASKELL__ < 808
  fail x = lift (fail x)
#endif
  Answer a     >>= k = k a
  NoAnswer     >>= _ = NoAnswer
  Choice m1 m2 >>= k = Choice (m1 >>= k) (m2 >>= k)
  ChoiceEff m  >>= k = ChoiceEff (liftM (>>= k) m)

instance (Monad m) => Monad (ContT i m) where
  return = t_return
#if __GLASGOW_HASKELL__ < 808
  fail = t_fail
#endif
  m >>= k = C $ \c -> runContT (\a -> runContT c (k a)) m
-- 'Functor' instances, derived from the 'Monad' instances: map over the
-- result by binding it and returning the transformed value.
instance Functor Id where
  fmap f m = m >>= return . f
instance Functor Lift where
  fmap f m = m >>= return . f
instance (Monad m) => Functor (IdT m) where
  fmap f m = m >>= return . f
instance (Monad m) => Functor (ReaderT i m) where
  fmap f m = m >>= return . f
instance (Monad m) => Functor (StateT i m) where
  fmap f m = m >>= return . f
instance (Monad m, Monoid i) => Functor (WriterT i m) where
  fmap f m = m >>= return . f
instance (Monad m) => Functor (ExceptionT i m) where
  fmap f m = m >>= return . f
instance (Monad m) => Functor (ChoiceT m) where
  fmap f m = m >>= return . f
instance (Monad m) => Functor (ContT i m) where
  fmap f m = m >>= return . f
-- 'Applicative' instances, derived from the 'Monad' instances.
instance Applicative Id where
  pure x    = return x
  mf <*> mx = ap mf mx
instance Applicative Lift where
  pure x    = return x
  mf <*> mx = ap mf mx
instance (Monad m) => Applicative (IdT m) where
  pure x    = return x
  mf <*> mx = ap mf mx
instance (Monad m) => Applicative (ReaderT i m) where
  pure x    = return x
  mf <*> mx = ap mf mx
instance (Monad m) => Applicative (StateT i m) where
  pure x    = return x
  mf <*> mx = ap mf mx
instance (Monad m, Monoid i) => Applicative (WriterT i m) where
  pure x    = return x
  mf <*> mx = ap mf mx
instance (Monad m) => Applicative (ExceptionT i m) where
  pure x    = return x
  mf <*> mx = ap mf mx
instance (Monad m) => Applicative (ChoiceT m) where
  pure x    = return x
  mf <*> mx = ap mf mx
instance (Monad m) => Applicative (ContT i m) where
  pure x    = return x
  mf <*> mx = ap mf mx
-- 'Alternative' instances, derived from the 'MonadPlus' instances.
instance (MonadPlus m) => Alternative (IdT m) where
  empty   = mzero
  a <|> b = mplus a b
instance (MonadPlus m) => Alternative (ReaderT i m) where
  empty   = mzero
  a <|> b = mplus a b
instance (MonadPlus m) => Alternative (StateT i m) where
  empty   = mzero
  a <|> b = mplus a b
instance (MonadPlus m, Monoid i) => Alternative (WriterT i m) where
  empty   = mzero
  a <|> b = mplus a b
instance (MonadPlus m) => Alternative (ExceptionT i m) where
  empty   = mzero
  a <|> b = mplus a b
-- | 'ChoiceT' provides its own choice, so only 'Monad' is needed of @m@.
instance (Monad m) => Alternative (ChoiceT m) where
  empty   = mzero
  a <|> b = mplus a b
instance (MonadPlus m) => Alternative (ContT i m) where
  empty   = mzero
  a <|> b = mplus a b
-- 'MonadFix' instances.  Each transformer ties the knot in the underlying
-- monad and feeds back only the result component, selected lazily.
instance MonadFix Id where
  mfix f = fix (f . runId)

instance MonadFix Lift where
  mfix f = fix (f . runLift)

instance (MonadFix m) => MonadFix (IdT m) where
  mfix f = IT (mfix (\a -> runIdT (f a)))

instance (MonadFix m) => MonadFix (ReaderT i m) where
  mfix f = R (\r -> mfix (\a -> runReaderT r (f a)))

instance (MonadFix m) => MonadFix (StateT i m) where
  mfix f = S (\s -> mfix (\ ~(a, _) -> runStateT s (f a)))

instance (MonadFix m, Monoid i) => MonadFix (WriterT i m) where
  mfix f = W (mfix (\ ~(P a _) -> unW (f a)))

-- | Demanding the fed-back value when the computation raised an exception
-- is an error, because there is no value to feed back.
instance (MonadFix m) => MonadFix (ExceptionT i m) where
  mfix f = X (mfix (\r -> runExceptionT (f (either looped id r))))
    where looped _ = error "ExceptionT: mfix looped."
-- 'MonadPlus' instances.  'mzero' is lifted from the underlying monad.
-- 'mplus' combines the two alternatives in the underlying monad, giving
-- both the same environment, state or continuation where the layer has one.
instance (MonadPlus m) => MonadPlus (IdT m) where
  mzero = t_mzero
  IT a `mplus` IT b = IT (a `mplus` b)

instance (MonadPlus m) => MonadPlus (ReaderT i m) where
  mzero = t_mzero
  R f `mplus` R g = R (\env -> f env `mplus` g env)

instance (MonadPlus m) => MonadPlus (StateT i m) where
  mzero = t_mzero
  S f `mplus` S g = S (\s -> f s `mplus` g s)

instance (MonadPlus m, Monoid i) => MonadPlus (WriterT i m) where
  mzero = t_mzero
  W a `mplus` W b = W (a `mplus` b)

instance (MonadPlus m) => MonadPlus (ExceptionT i m) where
  mzero = t_mzero
  X a `mplus` X b = X (a `mplus` b)

-- | 'ChoiceT' implements choice itself: 'mzero' has no answers and 'mplus'
-- builds a choice point.
instance (Monad m) => MonadPlus (ChoiceT m) where
  mzero = NoAnswer
  mplus = Choice

instance (MonadPlus m) => MonadPlus (ContT i m) where
  mzero = t_mzero
  C f `mplus` C g = C (\k -> f k `mplus` g k)
-- | Monads with a read-only environment of type @i@.
class (Monad m) => ReaderM m i | m -> i where
  -- | The current environment.
  ask :: m i

instance (Monad m) => ReaderM (ReaderT i m) i where
  ask = R (\env -> return env)

-- Other transformers pass 'ask' through to the monad they transform.
instance (ReaderM m j) => ReaderM (IdT m) j where
  ask = t_ask
instance (ReaderM m j, Monoid i) => ReaderM (WriterT i m) j where
  ask = t_ask
instance (ReaderM m j) => ReaderM (StateT i m) j where
  ask = t_ask
instance (ReaderM m j) => ReaderM (ExceptionT i m) j where
  ask = t_ask
instance (ReaderM m j) => ReaderM (ChoiceT m) j where
  ask = t_ask
instance (ReaderM m j) => ReaderM (ContT i m) j where
  ask = t_ask
-- | Monads that accumulate output of type @i@.
class (Monad m) => WriterM m i | m -> i where
  -- | Emit some output.
  put :: i -> m ()

-- | The output is recorded next to a unit result.
instance (Monad m, Monoid i) => WriterM (WriterT i m) i where
  put w = W (return (P () w))

-- Other transformers pass 'put' through to the monad they transform.
instance (WriterM m j) => WriterM (IdT m) j where
  put w = t_put w
instance (WriterM m j) => WriterM (ReaderT i m) j where
  put w = t_put w
instance (WriterM m j) => WriterM (StateT i m) j where
  put w = t_put w
instance (WriterM m j) => WriterM (ExceptionT i m) j where
  put w = t_put w
instance (WriterM m j) => WriterM (ChoiceT m) j where
  put w = t_put w
instance (WriterM m j) => WriterM (ContT i m) j where
  put w = t_put w
-- | Monads with a mutable state of type @i@.
class (Monad m) => StateM m i | m -> i where
  -- | Read the current state.
  get :: m i
  -- | Replace the current state.
  set :: i -> m ()

instance (Monad m) => StateM (StateT i m) i where
  get     = S (\s -> return (s, s))
  set new = S (const (return ((), new)))

-- Other transformers pass 'get' and 'set' through to the monad they
-- transform.
instance (StateM m j) => StateM (IdT m) j where
  get = t_get
  set = t_set
instance (StateM m j) => StateM (ReaderT i m) j where
  get = t_get
  set = t_set
instance (StateM m j, Monoid i) => StateM (WriterT i m) j where
  get = t_get
  set = t_set
instance (StateM m j) => StateM (ExceptionT i m) j where
  get = t_get
  set = t_set
instance (StateM m j) => StateM (ChoiceT m) j where
  get = t_get
  set = t_set
instance (StateM m j) => StateM (ContT i m) j where
  get = t_get
  set = t_set
-- | Monads that can raise exceptions of type @i@.
class (Monad m) => ExceptionM m i | m -> i where
  -- | Raise an exception; the rest of the computation is skipped.
  raise :: i -> m a

-- 'IO' raises exceptions with 'IO.throwIO'.
#ifdef USE_BASE3
instance ExceptionM IO IO.Exception where
  raise exc = IO.throwIO exc
#else
instance ExceptionM IO IO.SomeException where
  raise exc = IO.throwIO exc
#endif

instance (Monad m) => ExceptionM (ExceptionT i m) i where
  raise err = X (return (Left err))

-- Other transformers pass 'raise' through to the monad they transform.
instance (ExceptionM m j) => ExceptionM (IdT m) j where
  raise err = t_raise err
instance (ExceptionM m j) => ExceptionM (ReaderT i m) j where
  raise err = t_raise err
instance (ExceptionM m j, Monoid i) => ExceptionM (WriterT i m) j where
  raise err = t_raise err
instance (ExceptionM m j) => ExceptionM (StateT i m) j where
  raise err = t_raise err
instance (ExceptionM m j) => ExceptionM (ChoiceT m) j where
  raise err = t_raise err
instance (ExceptionM m j) => ExceptionM (ContT i m) j where
  raise err = t_raise err
-- | Monads that can capture the point where 'callWithCC' returns as a
-- 'Label' that can later be jumped to.
class Monad m => ContM m where
  -- | @callWithCC f@ passes @f@ a function that turns a value @a@ into a
  -- label.  Jumping to that label makes this call of 'callWithCC' return
  -- @a@.
  callWithCC :: ((a -> Label m) -> m a) -> m a

-- | Helper for lifting 'callWithCC' through a transformer @t@.
--
-- @ans@ injects a value into the result type @b@ expected by the
-- underlying monad's labels, for example by pairing it with a state.  The
-- result turns a function over labels of @t m@ into one over labels of
-- @m@.  Jumping to the lifted label for @a@ lifts a jump to the underlying
-- label for @ans a@.
liftJump :: (ContM m, MonadT t) =>
  (a -> b) ->
  ((a -> Label (t m)) -> t m a) ->
  ((b -> Label m ) -> t m a)
liftJump ans f l = f $ \a -> Lab (lift $ jump $ l $ ans a)

instance (ContM m) => ContM (IdT m) where
  callWithCC f = IT $ callWithCC $ \k -> runIdT $ liftJump id f k

instance (ContM m) => ContM (ReaderT i m) where
  callWithCC f = R $ \r -> callWithCC $ \k -> runReaderT r $ liftJump id f k

-- | A jump delivers its value together with the state @s@ that was current
-- when 'callWithCC' was entered.  State changes made between capture and
-- jump are therefore discarded.
instance (ContM m) => ContM (StateT i m) where
  callWithCC f = S $ \s -> callWithCC $ \k -> runStateT s $ liftJump (ans s) f k
    where ans s a = (a,s)

-- | A jump delivers its value with 'mempty' output.
instance (ContM m,Monoid i) => ContM (WriterT i m) where
  callWithCC f = W $ callWithCC $ \k -> unW $ liftJump (`P` mempty) f k

-- | A jump delivers its value as a normal ('Right') result.
instance (ContM m) => ContM (ExceptionT i m) where
  callWithCC f = X $ callWithCC $ \k -> runExceptionT $ liftJump Right f k

-- | A jump delivers its value as a single 'Answer'.
instance (ContM m) => ContM (ChoiceT m) where
  callWithCC f = ChoiceEff $ callWithCC $ \k -> return $ liftJump Answer f k

-- | A label ignores the continuation at the jump site.  It resumes the
-- continuation @k@ that was current when 'callWithCC' was called.
instance (Monad m) => ContM (ContT i m) where
  callWithCC f = C $ \k -> runContT k $ f $ \a -> Lab (C $ \_ -> k a)
-- | Reader monads whose environment can be replaced for a sub-computation.
class (ReaderM m i) => RunReaderM m i | m -> i where
  -- | Run a sub-computation with the given environment.
  local :: i -> m a -> m a

instance (Monad m) => RunReaderM (ReaderT i m) i where
  local env m = lift (runReaderT env m)

instance (RunReaderM m j) => RunReaderM (IdT m) j where
  local env (IT inner) = IT (local env inner)

instance (RunReaderM m j, Monoid i) => RunReaderM (WriterT i m) j where
  local env (W inner) = W (local env inner)

instance (RunReaderM m j) => RunReaderM (StateT i m) j where
  local env (S f) = S (\s -> local env (f s))

instance (RunReaderM m j) => RunReaderM (ExceptionT i m) j where
  local env (X inner) = X (local env inner)

-- | The new environment is also in effect for the continuation passed to
-- the sub-computation.
instance (RunReaderM m j) => RunReaderM (ContT i m) j where
  local env (C f) = C (\k -> local env (f k))
-- | Writer monads whose output can be captured locally.
class WriterM m i => RunWriterM m i | m -> i where
  -- | Run a computation and return its output together with its result.
  collect :: m a -> m (a,i)

-- | The captured output is returned instead of being added to the
-- enclosing output, because 'lift' for 'WriterT' contributes 'mempty'.
instance (Monad m,Monoid i) => RunWriterM (WriterT i m) i where
  collect m = lift (runWriterT m)

instance (RunWriterM m j) => RunWriterM (IdT m) j where
  collect (IT m) = IT (collect m)

instance (RunWriterM m j) => RunWriterM (ReaderT i m) j where
  collect (R m) = R (collect . m)

-- | Rearranges ((result, state), output) into ((result, output), state).
-- The lazy pattern leaves the inner pair unforced.
instance (RunWriterM m j) => RunWriterM (StateT i m) j where
  collect (S m) = S (liftM swap . collect . m)
    where swap (~(a,s),w) = ((a,w),s)

-- | If the computation raised an exception, the exception is propagated
-- and the captured output is dropped.
instance (RunWriterM m j) => RunWriterM (ExceptionT i m) j where
  collect (X m) = X (liftM swap (collect m))
    where swap (Right a,w) = Right (a,w)
          swap (Left x,_) = Left x

-- | Runs the computation together with its continuation @k@ under
-- 'collect', and uses 'mfix' to hand the resulting output @w@ back to @k@.
-- NOTE(review): @w@ therefore also contains whatever the continuation
-- itself writes, and that output is not re-emitted.  Confirm that this is
-- the intended semantics.
instance (RunWriterM m j, MonadFix m) => RunWriterM (ContT i m) j where
  collect (C m) = C $ \k -> fst `liftM`
    mfix (\ ~(_,w) -> collect (m (\a -> k (a,w))))
-- | Exception monads in which raised exceptions can be caught.
class ExceptionM m i => RunExceptionM m i | m -> i where
  -- | Run a computation, returning 'Left' if it raised an exception and
  -- 'Right' otherwise.
  try :: m a -> m (Either i a)

#ifdef USE_BASE3
instance RunExceptionM IO IO.Exception where
  try m = IO.try m
#else
instance RunExceptionM IO IO.SomeException where
  try m = IO.try m
#endif

instance (Monad m) => RunExceptionM (ExceptionT i m) i where
  try m = lift (runExceptionT m)

instance (RunExceptionM m i) => RunExceptionM (IdT m) i where
  try (IT inner) = IT (try inner)

instance (RunExceptionM m i) => RunExceptionM (ReaderT j m) i where
  try (R f) = R (\env -> try (f env))

-- | When an exception is caught, the output of the failed computation is
-- replaced by 'mempty'.
instance (RunExceptionM m i, Monoid j) => RunExceptionM (WriterT j m) i where
  try (W m) = W (liftM (either failed ok) (try m))
    where
      failed e   = P (Left e) mempty
      ok (P a w) = P (Right a) w

-- | When an exception is caught, the state from before 'try' is restored.
instance (RunExceptionM m i) => RunExceptionM (StateT j m) i where
  try (S f) = S (\s -> liftM (either (\e -> (Left e, s))
                                     (\ ~(a, s') -> (Right a, s')))
                             (try (f s)))
-- | Monads that can abort the whole computation with a value of type @i@.
class Monad m => AbortM m i where
  abort :: i -> m a

-- | Discards the current continuation and returns the value as the final
-- answer.
instance Monad m => AbortM (ContT i m) i where
  abort r = C (const (return r))

-- | Exits the program with the given exit code.
instance AbortM IO ExitCode where
  abort code = exitWith code

-- Other transformers pass 'abort' through to the monad they transform.
instance AbortM m i => AbortM (IdT m) i where
  abort r = t_abort r
instance AbortM m i => AbortM (ReaderT j m) i where
  abort r = t_abort r
instance (AbortM m i, Monoid j) => AbortM (WriterT j m) i where
  abort r = t_abort r
instance AbortM m i => AbortM (StateT j m) i where
  abort r = t_abort r
instance AbortM m i => AbortM (ExceptionT j m) i where
  abort r = t_abort r
instance AbortM m i => AbortM (ChoiceT m) i where
  abort r = t_abort r
-- | A label: a computation that is polymorphic in its result type.  It is
-- used for computations that do not return normally, such as the jumps
-- handed to a 'callWithCC' argument.
newtype Label m = Lab (forall b. m b)

-- | Capture a label for the current point together with a value.  The
-- first time through, this returns @(x, label)@.  Jumping to @label v@
-- later returns to this point again, this time with @(v, label)@.
labelCC :: (ContM m) => a -> m (a, a -> Label m)
labelCC x = callWithCC (\l ->
  let label a = Lab (jump (l (a, label)))
  in return (x, label))

-- | Capture a label for the current point.  Jumping to it resumes just
-- after 'labelCC_', which again returns the same label.
labelCC_ :: forall m. (ContM m) => m (Label m)
labelCC_ = callWithCC $ \k ->
  let x :: m a
      x = jump (k (Lab x))
  in x

-- | Call-with-current-continuation.  Applying the escape function passed to
-- @f@ to a value makes 'callCC' return that value.
callCC :: ContM m => ((a -> m b) -> m a) -> m a
callCC f = callWithCC $ \l -> f $ \a -> jump $ l a

-- | Wrap a result-polymorphic computation as a 'Label'.
labelC :: (forall b. m b) -> Label m
labelC k = Lab k

-- | Run a label's computation, at any result type.
jump :: Label m -> m a
jump (Lab k) = k
-- | Read the environment through a projection function.
asks :: ReaderM m r => (r -> a) -> m a
asks f = ask >>= return . f

-- | Emit the second component of a pair and return the first.  The pair
-- itself is not forced.
puts :: WriterM m w => (a,w) -> m a
puts pair = put (snd pair) >> return (fst pair)

-- | Update the state with a function that also produces a result.
sets :: StateM m s => (s -> (a,s)) -> m a
sets f = get >>= \s ->
  let (a, s1) = f s
  in set s1 >> return a

-- | Update the state with a function.
sets_ :: StateM m s => (s -> s) -> m ()
sets_ f = get >>= set . f

-- | Raise the exception in a 'Left'; return the value in a 'Right'.
raises :: ExceptionM m x => Either x a -> m a
raises = either raise return

-- | Run a computation with an environment derived from the current one.
mapReader :: RunReaderM m r => (r -> r) -> m a -> m a
mapReader f m = ask >>= \r -> local (f r) m

-- | Run a computation and transform the output it produces.
mapWriter :: RunWriterM m w => (w -> w) -> m a -> m a
mapWriter f m = collect m >>= \ ~(a, w) -> put (f w) >> return a

-- | Run a computation and transform any exception it raises.
mapException :: RunExceptionM m x => (x -> x) -> m a -> m a
mapException f m = try m >>= either (raise . f) return

-- | Run a computation, passing any exception it raises to the handler.
handle :: RunExceptionM m x => m a -> (x -> m a) -> m a
handle m f = try m >>= either f return
-- | @WithBase base layers@ stacks the transformers in @layers@ on top of
-- @base@, with the first list element outermost.  For example,
-- @WithBase IO '[ReaderT r, StateT s]@ is @ReaderT r (StateT s IO)@.
type family WithBase base layers :: Type -> Type where
  WithBase b '[] = b
  WithBase b (f ': fs) = f (WithBase b fs)
#if __GLASGOW_HASKELL__ >= 800
-- 'MonadFail' instances: every transformer delegates 'fail' to the monad it
-- transforms.
instance MonadFail m => MonadFail (IdT m) where
  fail msg = t_fail' msg
instance MonadFail m => MonadFail (ReaderT i m) where
  fail msg = t_fail' msg
instance (Monoid i, MonadFail m) => MonadFail (WriterT i m) where
  fail msg = t_fail' msg
instance MonadFail m => MonadFail (StateT i m) where
  fail msg = t_fail' msg
instance MonadFail m => MonadFail (ExceptionT i m) where
  fail msg = t_fail' msg
instance MonadFail m => MonadFail (ChoiceT m) where
  fail msg = t_fail' msg
instance MonadFail m => MonadFail (ContT i m) where
  fail msg = t_fail' msg
#endif