#include "free-common.h"
-- | A continuation-based ('FT') encoding of a free monad transformer,
-- its 'Identity'-specialised alias 'F', and conversions to and from
-- 'FreeT' and 'Free'.
module Control.Monad.Trans.Free.Church
(
-- * The transformer
FT(..)
-- * The Identity-specialised form
, F, free, runF
-- * Operations on the transformer
, improveT
, toFT, fromFT
, iterT
, iterTM
, hoistFT
, transFT
, joinFT
, cutoff
-- * Operations on the Identity-specialised form
, improve
, fromF, toF
, retract
, retractT
, iter
, iterM
-- * Re-exported class and helper
, MonadFree(..)
, liftF
) where
import Control.Applicative
import Control.Category ((<<<), (>>>))
import Control.Monad
import Control.Monad.Catch (MonadCatch(..), MonadThrow(..))
import Control.Monad.Identity
import Control.Monad.Trans.Class
import Control.Monad.IO.Class
import Control.Monad.Reader.Class
import Control.Monad.Writer.Class
import Control.Monad.State.Class
import Control.Monad.Error.Class
import Control.Monad.Cont.Class
import Control.Monad.Free.Class
import Control.Monad.Trans.Free (FreeT(..), FreeF(..), Free)
import qualified Control.Monad.Trans.Free as FreeT
import qualified Data.Foldable as F
import qualified Data.Traversable as T
import Data.Functor.Bind hiding (join)
import Data.Functor.Classes.Compat
#if !(MIN_VERSION_base(4,8,0))
import Data.Foldable (Foldable)
import Data.Traversable (Traversable)
#endif
-- | A computation over the functor @f@ and the base monad @m@ producing an
-- @a@, represented by the way it is consumed: to run it, supply a
-- continuation for a final value and a handler for a single @f@ layer.
newtype FT f m a = FT
  { -- | Run the computation with the two continuations described below.
    runFT
      :: forall r.
         (a -> m r)
         -- continuation invoked with a final value
      -> (forall x. (x -> m r) -> f x -> m r)
         -- handler for one @f@ layer; its first argument resumes the
         -- computation stored at each position of that layer
      -> m r
  }
#ifdef LIFTED_FUNCTOR_CLASSES
-- | Lifted equality: both operands are converted with 'fromFT' and compared
-- using the 'liftEq' of the resulting 'FreeT' values.
instance (Functor f, Monad m, Eq1 f, Eq1 m) => Eq1 (FT f m) where
  liftEq eq lhs rhs = liftEq eq (fromFT lhs) (fromFT rhs)
-- | Lifted ordering: both operands are converted with 'fromFT' and compared
-- using the 'liftCompare' of the resulting 'FreeT' values.
instance (Functor f, Monad m, Ord1 f, Ord1 m) => Ord1 (FT f m) where
  liftCompare cmp lhs rhs = liftCompare cmp (fromFT lhs) (fromFT rhs)
#else
-- | Equality for the non-lifted class interface: both operands are converted
-- with 'fromFT' and compared with the 'eq1' of the resulting 'FreeT' values.
-- The extra @Functor m@ constraint is only added for base older than 4.8.
instance ( Functor f, Monad m, Eq1 f, Eq1 m
# if !(MIN_VERSION_base(4,8,0))
         , Functor m
# endif
         ) => Eq1 (FT f m) where
  eq1 lhs rhs = fromFT lhs `eq1` fromFT rhs
-- | Ordering for the non-lifted class interface: both operands are converted
-- with 'fromFT' and compared with the 'compare1' of the resulting 'FreeT'
-- values. The extra @Functor m@ constraint is only added for base older
-- than 4.8.
instance ( Functor f, Monad m, Ord1 f, Ord1 m
# if !(MIN_VERSION_base(4,8,0))
         , Functor m
# endif
         ) => Ord1 (FT f m) where
  compare1 lhs rhs = fromFT lhs `compare1` fromFT rhs
#endif
-- | Equality on saturated 'FT' values, defined through the 'Eq1' instance.
instance (Eq1 (FT f m), Eq a) => Eq (FT f m a) where
  lhs == rhs = eq1 lhs rhs
-- | Ordering on saturated 'FT' values, defined through the 'Ord1' instance.
instance (Ord1 (FT f m), Ord a) => Ord (FT f m a) where
  compare lhs rhs = compare1 lhs rhs
-- | Maps over the final value by applying the function just before the
-- result continuation; the layer handler is passed through untouched.
instance Functor (FT f m) where
  fmap g (FT run) = FT (\kp kf -> run (\x -> kp (g x)) kf)
-- | 'Apply' reuses the 'Applicative' instance.
instance Apply (FT f m) where
  ff <.> fa = ff <*> fa
-- | 'pure' hands its value straight to the result continuation and never
-- touches the layer handler. '<*>' runs the function computation first and,
-- for each function it yields, runs the argument computation; both share the
-- same layer handler.
instance Applicative (FT f m) where
  pure x = FT (\kp _ -> kp x)
  FT runFun <*> FT runArg =
    FT (\kp kf -> runFun (\g -> runArg (kp . g) kf) kf)
-- | 'Bind' reuses the 'Monad' instance.
instance Bind (FT f m) where
  ft >>- next = ft >>= next
-- | Bind runs the first computation with a result continuation that builds
-- the follow-up computation and runs it with the original continuations.
instance Monad (FT f m) where
  return = pure
  FT run >>= next =
    FT (\kp kf -> run (\x -> runFT (next x) kp kf) kf)
-- | 'wrap' passes the given layer to the layer handler, together with a
-- resumption that runs each nested computation with the same pair of
-- continuations.
instance MonadFree f (FT f m) where
  wrap layer = FT $ \kp kf -> kf (\inner -> runFT inner kp kf) layer
-- | 'lift' runs the base action and feeds its result to the result
-- continuation; the layer handler is not used.
instance MonadTrans (FT f) where
  lift action = FT $ \kp _ -> action >>= kp
-- | Alternatives are delegated to the base functor: 'empty' ignores both
-- continuations, and '<|>' runs each side with the same continuations and
-- combines the two results with the '<|>' of @m@.
instance Alternative m => Alternative (FT f m) where
  empty = FT $ \_ _ -> empty
  FT runL <|> FT runR = FT (\kp kf -> runL kp kf <|> runR kp kf)
-- | 'mzero' and 'mplus' are delegated to the base monad: 'mzero' ignores
-- both continuations, and 'mplus' runs each side with the same continuations
-- and combines the two results with the 'mplus' of @m@.
instance MonadPlus m => MonadPlus (FT f m) where
  mzero = FT $ \_ _ -> mzero
  FT runL `mplus` FT runR = FT (\kp kf -> mplus (runL kp kf) (runR kp kf))
-- | Folding runs the computation with continuations that build, inside @m@,
-- an accumulator-to-accumulator function. Those functions are then combined
-- by folding over the @m@ structure and finally applied to the seed.
instance (Foldable f, Foldable m, Monad m) => Foldable (FT f m) where
  foldr step seed ft = F.foldr (<<<) id composed seed
    where
      -- Each final value contributes @step value@; each layer contributes the
      -- right-to-left composition of what its positions contribute.
      composed =
        runFT ft
          (return . step)
          (\xg layer -> F.foldr (liftM2 (<<<) . xg) (return id) layer)
#if MIN_VERSION_base(4,6,0)
  foldl' step seed ft = F.foldl' strictThen id composed seed
    where
      -- Left-to-right composition that forces the intermediate accumulator
      -- (to weak head normal form) before passing it on.
      strictThen h g = \acc -> g $! h acc
      composed =
        runFT ft
          (return . flip step)
          (\xg layer -> F.foldr (liftM2 (>>>) . xg) (return id) layer)
#endif
-- | Traversal runs the computation with continuations that produce, inside
-- @m@, an applicative action yielding a rebuilt 'FT'. The outer @m@ is then
-- sequenced through the applicative and flattened with @join . lift@.
instance (Monad m, Traversable m, Traversable f) => Traversable (FT f m) where
  traverse visit (FT run) =
    join . lift <$> T.sequenceA (run onPure onFree)
    where
      -- A final value: visit it and re-embed the result with 'return'.
      onPure = return . fmap return . visit
      -- A layer: traverse its positions, then rebuild the layer with 'wrap'.
      onFree xg =
        return
          . fmap (wrap . fmap (join . lift))
          . T.traverse (T.sequenceA . xg)
-- | 'liftIO' lifts the action into the base monad first and then into 'FT'.
instance (MonadIO m) => MonadIO (FT f m) where
  liftIO io = lift (liftIO io)
-- | Errors are thrown in the base monad. Catching converts the computation
-- and the handler to 'FreeT' with 'fromFT', uses the 'catchError' available
-- there, and converts the result back with 'toFT'.
instance (Functor f, MonadError e m) => MonadError e (FT f m) where
  throwError err = lift (throwError err)
  catchError body handler =
    toFT (catchError (fromFT body) (fromFT . handler))
-- | 'callCC' is implemented with the 'callCC' of the base monad: the base
-- computation returns an 'FT' value, which is then flattened with
-- @join . lift@. The escape function handed to the body lifts a call to the
-- base-level continuation.
instance MonadCont m => MonadCont (FT f m) where
  callCC body =
    join (lift (callCC (\exit -> return (body (lift . exit . return)))))
-- | 'ask' is lifted from the base monad; 'local' is applied to the base
-- monad by way of 'hoistFT'.
instance MonadReader r m => MonadReader r (FT f m) where
  ask = lift ask
  -- Kept fully applied: 'hoistFT' takes a polymorphic argument.
  local f ft = hoistFT (local f) ft
-- | 'tell' (and 'writer', when mtl provides it) are lifted from the base
-- monad. 'listen' and 'pass' convert to 'FreeT' with 'fromFT', use the
-- operation available there, and convert back with 'toFT'.
instance (Functor f, MonadWriter w m) => MonadWriter w (FT f m) where
  tell w = lift (tell w)
  listen ft = toFT (listen (fromFT ft))
  pass ft = toFT (pass (fromFT ft))
#if MIN_VERSION_mtl(2,1,1)
  writer = lift . writer
#endif
-- | All state operations are lifted unchanged from the base monad.
instance MonadState s m => MonadState s (FT f m) where
  get = lift get
  put s = lift (put s)
#if MIN_VERSION_mtl(2,1,1)
  state = lift . state
#endif
-- | Exceptions are thrown in the base monad and lifted into 'FT'.
instance MonadThrow m => MonadThrow (FT f m) where
  throwM exc = lift (throwM exc)
-- | Catching converts the computation and the handler to 'FreeT' with
-- 'fromFT', uses the @catch@ available there, and converts the result back
-- with 'toFT'.
instance (Functor f, MonadCatch m) => MonadCatch (FT f m) where
  -- The qualified name selects the @catch@ of "Control.Monad.Catch".
  catch body handler =
    toFT (Control.Monad.Catch.catch (fromFT body) (fromFT . handler))
-- | Convert a 'FreeT' computation into its 'FT' representation.
--
-- The first step of the 'FreeT' value is run in the base monad. A 'Pure'
-- result goes to the result continuation; a 'Free' layer goes to the layer
-- handler, with every nested 'FreeT' converted recursively on resumption.
toFT :: Monad m => FreeT f m a -> FT f m a
toFT (FreeT step) = FT $ \kp kf ->
  let dispatch (Pure x)     = kp x
      dispatch (Free layer) = kf (\rest -> runFT (toFT rest) kp kf) layer
  in step >>= dispatch
-- | Convert an 'FT' computation into a 'FreeT' value.
--
-- The computation is run with a result continuation that returns 'Pure' and
-- a layer handler that rebuilds the layer with 'wrap' after turning each
-- resumption into a 'FreeT'.
fromFT :: (Monad m, Functor f) => FT f m a -> FreeT f m a
fromFT (FT run) =
  FreeT (run (return . Pure)
             (\xg layer -> runFreeT (wrap (fmap (FreeT . xg) layer))))
-- | 'FT' specialised to the 'Identity' base monad. The synonym is left
-- partially applied (no result parameter) so that @F f@ can be used where a
-- type constructor is expected.
type F f = FT f Identity
-- | Run an 'F' computation with plain (non-monadic) continuations: one for a
-- final value and one that collapses a layer of already-computed results.
-- The 'Identity' wrapper used internally is added and removed here.
runF :: Functor f => F f a -> (forall r. (a -> r) -> (f r -> r) -> r)
runF (FT run) = \kp kf ->
  runIdentity
    (run (Identity . kp)
         (\xg layer -> Identity (kf (fmap (runIdentity . xg) layer))))
-- | Build an 'F' computation from a function of two plain continuations.
-- The 'Identity'-based continuations of 'FT' are unwrapped before being
-- passed to the given function, and its result is wrapped again.
free :: (forall r. (a -> r) -> (f r -> r) -> r) -> F f a
free church = FT $ \kp kf ->
  Identity (church (runIdentity . kp) (\layer -> runIdentity (kf Identity layer)))
-- | Collapse an 'FT' computation into the base monad. Final values are
-- returned with 'return'; each layer has its positions resumed and is then
-- handed to the supplied function.
iterT :: (Functor f, Monad m) => (f (m a) -> m a) -> FT f m a -> m a
iterT alg (FT run) = run return (\xg layer -> alg (fmap xg layer))
-- | Collapse an 'FT' computation into a monad transformer stacked on the
-- base monad. The computation is run in @m@ producing a @t m a@, which is
-- then flattened with @join . lift@; each layer is handed to the supplied
-- function after its positions have been resumed the same way.
iterTM :: (Functor f, Monad m, MonadTrans t, Monad (t m)) => (f (t m a) -> t m a) -> FT f m a -> t m a
iterTM alg (FT run) =
  join (lift (run (return . return)
                  (\xg layer -> return (alg (fmap (join . lift . xg) layer)))))
-- | Change the base monad of an 'FT' computation using a transformation from
-- @m@ to @n@.
--
-- The computation is run in @m@ with continuations that merely return the
-- corresponding @n@ action; the transformation moves the outer @m@ into @n@
-- and 'join' flattens the result. NOTE(review): the transformation is
-- presumably expected to be a monad morphism; nothing here checks that.
hoistFT :: (Monad m, Monad n) => (forall a. m a -> n a) -> FT f m b -> FT f n b
hoistFT nat (FT run) = FT $ \kp kf ->
  join (nat (run (return . kp)
                 (\xg layer -> return (kf (join . nat . xg) layer))))
-- | Change the functor of an 'FT' computation: every layer is passed through
-- the given transformation before it reaches the layer handler.
transFT :: (forall a. f a -> g a) -> FT f m b -> FT g m b
transFT nat (FT run) = FT $ \kp kf ->
  run kp (\xg layer -> kf xg (nat layer))
-- | Run all base-monad effects of an 'FT' computation, producing an 'F'
-- computation inside @m@. Each layer is traversed with 'T.mapM' and rebuilt
-- with 'wrap'.
joinFT :: (Monad m, Traversable f) => FT f m a -> m (F f a)
joinFT (FT run) =
  run (return . return) (\xg layer -> liftM wrap (T.mapM xg layer))
-- | Apply the @cutoff@ of "Control.Monad.Trans.Free" to an 'FT' computation
-- by converting to 'FreeT' with 'fromFT' and back with 'toFT'. The meaning
-- of the depth argument and of the 'Maybe' result is whatever that function
-- defines.
cutoff :: (Functor f, Monad m) => Integer -> FT f m a -> FT f m (Maybe a)
cutoff n ft = toFT (FreeT.cutoff n (fromFT ft))
-- | Collapse an 'F' computation over a monad @f@ into @f@ itself: final
-- values are returned with 'return' and layers are flattened with 'join'.
-- Compilers older than GHC 7.10 get an explicit extra @Functor f@
-- constraint.
#if __GLASGOW_HASKELL__ < 710
retract :: (Functor f, Monad f) => F f a -> f a
#else
retract :: Monad f => F f a -> f a
#endif
retract ft = runF ft return join
-- | Collapse an 'FT' computation whose functor is itself a transformed
-- monad @t m@ into @t m@. The computation is run in @m@ producing a
-- @t m a@, which is flattened with @join . lift@; each layer is bound in
-- @t m@ and its continuation is resumed the same way.
retractT :: (MonadTrans t, Monad (t m), Monad m) => FT (t m) m a -> t m a
retractT (FT run) =
  join (lift (run (return . return)
                  (\xg layer -> return (layer >>= \x -> join (lift (xg x))))))
-- | Collapse an 'F' computation to a plain value using a function that
-- collapses one layer. Implemented with 'iterT' in the 'Identity' monad.
iter :: Functor f => (f a -> a) -> F f a -> a
iter alg ft =
  runIdentity (iterT (\layer -> Identity (alg (fmap runIdentity layer))) ft)
-- | Collapse an 'F' computation into a monad @m@. The 'Identity' base is
-- first replaced by @m@ with 'hoistFT', then 'iterT' does the collapsing.
iterM :: (Functor f, Monad m) => (f (m a) -> m a) -> F f a -> m a
iterM alg ft = iterT alg (hoistFT (return . runIdentity) ft)
-- | Interpret an 'F' computation in any monad with a 'MonadFree' instance
-- for the same functor: final values use 'return', layers use 'wrap'.
fromF :: (Functor f, MonadFree f m) => F f a -> m a
fromF ft = runF ft return wrap
-- | Convert a 'Free' value into an 'F' computation; this is 'toFT' used at
-- the 'Identity' base monad.
toF :: Free f a -> F f a
toF freeValue = toFT freeValue
-- | Take a computation written against any 'MonadFree' instance, run it at
-- the 'F' representation, and convert the result to 'Free' with 'fromF'.
--
-- The argument is kept explicit rather than eta-reduced because it has a
-- higher-rank type.
improve :: Functor f => (forall m. MonadFree f m => m a) -> Free f a
improve m = fromF m
-- | Take a computation written against any 'MonadFree' transformer over
-- @m@, run it at the 'FT' representation, and convert the result to 'FreeT'
-- with 'fromFT'.
--
-- The argument is kept explicit rather than eta-reduced because it has a
-- higher-rank type.
improveT :: (Functor f, Monad m) => (forall t. MonadFree f (t m) => t m a) -> FreeT f m a
improveT m = fromFT m