{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE CPP #-}

#if __GLASGOW_HASKELL__ >= 806
{-# LANGUAGE QuantifiedConstraints #-}
#endif

{-| Composition of monad transformers. A higher-order version of
    "Data.Functor.Compose".
-}

module Control.Monad.Trans.Compose (
    -- * ComposeT
    ComposeT(ComposeT, getComposeT),
    mapComposeT
   ) where

import Control.Applicative (
    Applicative(pure, (<*>), (*>), (<*)), Alternative(empty, (<|>)) )
import Control.Monad (MonadPlus(mzero, mplus), liftM)
import Control.Monad.Cont.Class (MonadCont(callCC))
import Control.Monad.Error.Class (MonadError(throwError, catchError))
import Control.Monad.Morph (MFunctor(hoist))
import Control.Monad.RWS.Class (MonadRWS)
import Control.Monad.Reader.Class (MonadReader(ask, local, reader))
import Control.Monad.State.Class (MonadState(get, put, state))
import Control.Monad.Trans.Class (MonadTrans(lift))
import Control.Monad.Writer.Class (MonadWriter(writer, tell, listen, pass))
import Control.Monad.IO.Class (MonadIO(liftIO))
import Data.Foldable (Foldable(fold, foldMap, foldr, foldl, foldr1, foldl1))
import Data.Traversable (Traversable(traverse, sequenceA, mapM, sequence))
import Prelude hiding (foldr, foldl, foldr1, foldl1, mapM, sequence)

#if __GLASGOW_HASKELL__ >= 800
import Control.Monad.Fail (MonadFail(..))
import qualified Control.Monad.Fail
#endif

infixr 9 `ComposeT`

-- | Composition of monad transformers.
--
-- A value of type @'ComposeT' f g m a@ is a computation of type
-- @f (g m) a@: the outer transformer @f@ runs over the base monad @g m@,
-- which is itself the transformer @g@ applied to @m@. The name is declared
-- @infixr 9@, so chained infix uses of it nest to the right.
newtype ComposeT (f :: (* -> *) -> * -> *) (g :: (* -> *) -> * -> *) m a
    = ComposeT
    { getComposeT :: f (g m) a
      -- ^ The underlying @f (g m) a@ computation.
    }
    deriving (Eq, Ord, Read, Show)

-- | Lifting happens in two steps. First, @f@'s 'lift' moves the action from
-- @m@ into @f m@. Then 'hoist' (from @f@'s 'MFunctor' instance) rewrites that
-- computation's base monad from @m@ to @g m@, using @g@'s 'lift' as the
-- morphism.
instance (MFunctor f, MonadTrans f, MonadTrans g)
    => MonadTrans (ComposeT f g) where
    lift action = ComposeT (hoist lift (lift action))

#if __GLASGOW_HASKELL__ >= 806
-- | Hoists a morphism on the innermost monad through both transformers.
-- @g@'s 'hoist' turns it into a morphism on @g m@, and @f@'s 'hoist' then
-- applies that underneath @f@. The quantified constraint provides the
-- @Monad (g m)@ evidence needed while @f@ is hoisted over base monad @g m@.
instance (MFunctor f, MFunctor g, forall m. Monad m => Monad (g m))
    => MFunctor (ComposeT f g) where
    hoist phi = ComposeT . hoist (hoist phi) . getComposeT
#endif

-- | Maps over the result of the wrapped @f (g m)@ computation.
instance Functor (f (g m)) => Functor (ComposeT f g m) where
    fmap h = ComposeT . fmap h . getComposeT

-- | Each operation strips the 'ComposeT' wrapper from its arguments, uses
-- the corresponding operation of the wrapped @f (g m)@, and rewraps the
-- result.
instance Applicative (f (g m)) => Applicative (ComposeT f g m) where
    pure = ComposeT . pure
    mf <*> mx = ComposeT (getComposeT mf <*> getComposeT mx)
    ma *> mb = ComposeT (getComposeT ma *> getComposeT mb)
    ma <* mb = ComposeT (getComposeT ma <* getComposeT mb)

-- | Failure and choice come from the wrapped @f (g m)@.
instance Alternative (f (g m)) => Alternative (ComposeT f g m) where
    empty = ComposeT empty
    x <|> y = ComposeT (getComposeT x <|> getComposeT y)

-- | Sequencing is delegated to the wrapped @f (g m)@ monad. Each step is
-- unwrapped, bound in @f (g m)@, and the result is rewrapped.
instance Monad (f (g m)) => Monad (ComposeT f g m) where
#if __GLASGOW_HASKELL__ >= 710
    -- Since the Applicative-Monad Proposal (GHC 7.10), 'Applicative' is a
    -- superclass of 'Monad', so the canonical definition is @return = pure@.
    -- Any other definition triggers GHC's @-Wnoncanonical-monad-instances@
    -- warning. Older compilers keep the direct delegation in the #else branch.
    return   = pure
#else
    return a = ComposeT (return a)
#endif
    m >>= f  = ComposeT (getComposeT m >>= \x -> getComposeT (f x))
#if __GLASGOW_HASKELL__ < 808
    -- Before GHC 8.8, 'fail' was still a method of 'Monad'.
    fail e   = ComposeT (Prelude.fail e)
#endif

#if __GLASGOW_HASKELL__ >= 800
-- | Failure is raised in the wrapped @f (g m)@ via its 'MonadFail' instance.
instance MonadFail (f (g m)) => MonadFail (ComposeT f g m) where
    fail = ComposeT . Control.Monad.Fail.fail
#endif

-- | 'mzero' and 'mplus' come from the wrapped @f (g m)@.
instance MonadPlus (f (g m)) => MonadPlus (ComposeT f g m) where
    mzero = ComposeT mzero
    mplus x y = ComposeT (mplus (getComposeT x) (getComposeT y))

-- | 'IO' actions are lifted by the wrapped @f (g m)@'s 'liftIO'.
instance MonadIO (f (g m)) => MonadIO (ComposeT f g m) where
    liftIO = ComposeT . liftIO

-- | Every fold unwraps the 'ComposeT' and folds the underlying @f (g m)@
-- value with the same method.
instance Foldable (f (g m)) => Foldable (ComposeT f g m) where
    fold        = fold . getComposeT
    foldMap h   = foldMap h . getComposeT
    foldr h z   = foldr h z . getComposeT
    foldl h z   = foldl h z . getComposeT
    foldr1 h    = foldr1 h . getComposeT
    foldl1 h    = foldl1 h . getComposeT

-- | Traverses the underlying @f (g m)@ value, then rewraps the result inside
-- the effect. The 'Applicative' methods use 'fmap' for the rewrapping and
-- the 'Monad' methods use 'liftM'.
instance Traversable (f (g m)) => Traversable (ComposeT f g m) where
    traverse h = fmap ComposeT . traverse h . getComposeT
    sequenceA  = fmap ComposeT . sequenceA . getComposeT
    mapM h     = liftM ComposeT . mapM h . getComposeT
    sequence   = liftM ComposeT . sequence . getComposeT

-- | Captures the continuation in the wrapped @f (g m)@. The escape function
-- given to the body wraps its result back into 'ComposeT'.
instance MonadCont (f (g m)) => MonadCont (ComposeT f g m) where
    callCC body = ComposeT . callCC $ \exit ->
        getComposeT (body (\x -> ComposeT (exit x)))

-- | Errors are thrown and caught in the wrapped @f (g m)@. The handler's
-- 'ComposeT' result is unwrapped before being handed to the inner
-- 'catchError'.
instance MonadError e (f (g m)) => MonadError e (ComposeT f g m) where
    throwError err = ComposeT (throwError err)
    catchError (ComposeT action) handler =
        ComposeT (catchError action (\err -> getComposeT (handler err)))

-- | No methods are defined here. The instance only requires the wrapped
-- @f (g m)@ to be a 'MonadRWS' with the same @r@, @w@ and @s@.
instance MonadRWS r w s (f (g m))
    => MonadRWS r w s (ComposeT f g m)

-- | The environment is read from, and locally modified in, the wrapped
-- @f (g m)@.
instance MonadReader r (f (g m)) => MonadReader r (ComposeT f g m) where
    ask              = ComposeT ask
    local h action   = ComposeT (local h (getComposeT action))
    reader h         = ComposeT (reader h)

-- | State is read and written in the wrapped @f (g m)@.
instance MonadState s (f (g m)) => MonadState s (ComposeT f g m) where
    get        = ComposeT get
    put new    = ComposeT (put new)
    state step = ComposeT (state step)

-- | Output is written to, and observed in, the wrapped @f (g m)@. 'listen'
-- and 'pass' unwrap their argument computation, run the inner operation,
-- and rewrap the result.
instance MonadWriter w (f (g m)) => MonadWriter w (ComposeT f g m) where
    writer entry  = ComposeT (writer entry)
    tell out      = ComposeT (tell out)
    listen action = ComposeT (listen (getComposeT action))
    pass action   = ComposeT (pass (getComposeT action))

-- | Apply a function to the computation inside a 'ComposeT' and rewrap the
-- result.
--
-- The function sees the raw @f (g m) a@ value, so it may change the
-- transformers, the base monad and the result type.
mapComposeT :: (f (g m) a -> p (q n) b) -> ComposeT f g m a -> ComposeT p q n b
mapComposeT transform (ComposeT inner) = ComposeT (transform inner)