#include "free-common.h"
module Control.Monad.Trans.Iter
(
IterT(..)
, Iter, iter, runIter
, delay
, hoistIterT
, liftIter
, cutoff
, never
, untilJust
, interleave, interleave_
, retract
, fold
, foldM
, MonadFree(..)
) where
import Control.Applicative
import Control.Monad.Catch (MonadCatch(..), MonadThrow(..))
import Control.Monad (ap, liftM, MonadPlus(..), join)
import Control.Monad.Fix
import Control.Monad.Trans.Class
import qualified Control.Monad.Fail as Fail
import Control.Monad.Free.Class
import Control.Monad.State.Class
import Control.Monad.Error.Class
import Control.Monad.Reader.Class
import Control.Monad.Writer.Class
import Control.Monad.Cont.Class
import Control.Monad.IO.Class
import Data.Bifunctor
import Data.Bitraversable
import Data.Either
import Data.Functor.Bind hiding (join)
import Data.Functor.Classes.Compat
import Data.Functor.Identity
import Data.Semigroup
import Data.Semigroup.Foldable
import Data.Semigroup.Traversable
import Data.Typeable
import Data.Data
#if !(MIN_VERSION_base(4,8,0))
import Data.Foldable hiding (fold)
import Data.Traversable hiding (mapM)
#endif
-- | An iterative computation over a base monad @m@. Running one layer in @m@
-- yields either a finished result ('Left') or the remainder of the
-- computation ('Right').
newtype IterT m a = IterT
  { runIterT :: m (Either a (IterT m a))
  }
#if __GLASGOW_HASKELL__ >= 707
  deriving (Typeable)
#endif
-- | 'IterT' specialised to the 'Identity' base monad.
type Iter = IterT Identity
-- | Build an 'Iter' from its first step: a finished value ('Left') or a
-- further 'Iter' ('Right').
iter :: Either a (Iter a) -> Iter a
iter step = IterT (Identity step)
-- | Expose the first step of an 'Iter' by unwrapping 'IterT' and 'Identity'.
runIter :: Iter a -> Either a (Iter a)
runIter it = runIdentity (runIterT it)
-- Equality lifted over the result type, comparing two computations layer by
-- layer through the base functor.
#ifdef LIFTED_FUNCTOR_CLASSES
instance (Eq1 m) => Eq1 (IterT m) where
  liftEq eq = go
    where
      -- Results are compared with @eq@, remainders recursively with @go@.
      go (IterT x) (IterT y) = liftEq (liftEq2 eq go) x y
#else
instance (Functor m, Eq1 m) => Eq1 (IterT m) where
  -- Written without 'on': this module has no import of "Data.Function",
  -- so the helper is applied to both sides explicitly.
  eq1 x y = eq1 (lifted x) (lifted y)
    where
      -- Wrap each nested 'IterT' in 'Lift1' before handing the layer to the
      -- base functor's 'eq1'.
      lifted (IterT m) = fmap (fmap Lift1) m
#endif
-- Equality on whole computations, delegated to the 'Eq1' instance.
#ifdef LIFTED_FUNCTOR_CLASSES
instance (Eq1 m, Eq a) => Eq (IterT m a) where
#else
instance (Functor m, Eq1 m, Eq a) => Eq (IterT m a) where
#endif
  lhs == rhs = eq1 lhs rhs
-- Ordering lifted over the result type, comparing two computations layer by
-- layer through the base functor.
#ifdef LIFTED_FUNCTOR_CLASSES
instance (Ord1 m) => Ord1 (IterT m) where
  liftCompare cmp = go
    where
      -- Results are compared with @cmp@, remainders recursively with @go@.
      go (IterT x) (IterT y) = liftCompare (liftCompare2 cmp go) x y
#else
instance (Functor m, Ord1 m) => Ord1 (IterT m) where
  -- Written without 'on': this module has no import of "Data.Function",
  -- so the helper is applied to both sides explicitly.
  compare1 x y = compare1 (lifted x) (lifted y)
    where
      -- Wrap each nested 'IterT' in 'Lift1' before handing the layer to the
      -- base functor's 'compare1'.
      lifted (IterT m) = fmap (fmap Lift1) m
#endif
-- Ordering on whole computations, delegated to the 'Ord1' instance.
#ifdef LIFTED_FUNCTOR_CLASSES
instance (Ord1 m, Ord a) => Ord (IterT m a) where
#else
instance (Functor m, Ord1 m, Ord a) => Ord (IterT m a) where
#endif
  compare lhs rhs = compare1 lhs rhs
-- Rendering lifted over the result type; output is prefixed with the
-- constructor name "IterT".
#ifdef LIFTED_FUNCTOR_CLASSES
instance (Show1 m) => Show1 (IterT m) where
  liftShowsPrec sp sl = go
    where
      goList = liftShowList sp sl
      -- Show one layer through the base functor, using the helpers below
      -- for the 'Either' it contains.
      go d (IterT x) = showsUnaryWith (liftShowsPrec showStep showSteps) "IterT" d x
      -- Show a single step: results via @sp@/@sl@, remainders via @go@.
      showStep = liftShowsPrec2 sp sl go goList
      showSteps = liftShowList2 sp sl go goList
#else
instance (Functor m, Show1 m) => Show1 (IterT m) where
  showsPrec1 d (IterT m) = showParen (d > 10) (showString "IterT " . body)
    where
      -- Nested 'IterT' values are wrapped in 'Lift1' before being shown.
      body = showsPrec1 11 (fmap (fmap Lift1) m)
#endif
-- Rendering of whole computations, delegated to the 'Show1' instance.
#ifdef LIFTED_FUNCTOR_CLASSES
instance (Show1 m, Show a) => Show (IterT m a) where
#else
instance (Functor m, Show1 m, Show a) => Show (IterT m a) where
#endif
  showsPrec d it = showsPrec1 d it
-- Parsing lifted over the result type; input is expected to start with the
-- constructor name "IterT".
#ifdef LIFTED_FUNCTOR_CLASSES
instance (Read1 m) => Read1 (IterT m) where
  liftReadsPrec rp rl = go
    where
      goList = liftReadList rp rl
      -- Read one layer through the base functor, using the helpers below
      -- for the 'Either' it contains.
      go = readsData (readsUnaryWith (liftReadsPrec readStep readSteps) "IterT" IterT)
      -- Read a single step: results via @rp@/@rl@, remainders via @go@.
      readStep = liftReadsPrec2 rp rl go goList
      readSteps = liftReadList2 rp rl go goList
#else
instance (Functor m, Read1 m) => Read1 (IterT m) where
  readsPrec1 d = readParen (d > 10) parse
    where
      -- Nested values are read as 'Lift1' wrappers and unwrapped with
      -- 'lower1' afterwards.
      parse input =
        [ (IterT (fmap (fmap lower1) layer), rest)
        | ("IterT", afterName) <- lex input
        , (layer, rest) <- readsPrec1 11 afterName
        ]
#endif
-- Parsing of whole computations, delegated to the 'Read1' instance.
#ifdef LIFTED_FUNCTOR_CLASSES
instance (Read1 m, Read a) => Read (IterT m a) where
#else
instance (Functor m, Read1 m, Read a) => Read (IterT m a) where
#endif
  readsPrec d = readsPrec1 d
-- Map over the eventual result. Each layer is rewritten with the base
-- monad's 'liftM', which is why only @Monad m@ is required.
instance Monad m => Functor (IterT m) where
  fmap f (IterT layer) = IterT (liftM (bimap f (fmap f)) layer)
-- 'pure' finishes immediately with a 'Left'; application is derived from
-- the 'Monad' instance via 'ap'.
instance Monad m => Applicative (IterT m) where
  pure a = IterT (return (Left a))
  ff <*> fa = ap ff fa
-- Sequencing of iterative computations.
instance Monad m => Monad (IterT m) where
  return = pure
  -- Run one layer of the left operand: a finished value is passed to @k@
  -- and its first layer is run in the same base action, while an unfinished
  -- remainder keeps its 'Right' wrapper and has @k@ bound underneath.
  IterT m >>= k = IterT $ m >>= either (runIterT . k) (return . Right . (>>= k))
#if !MIN_VERSION_base(4,13,0)
  -- 'fail' stopped being a 'Monad' method in base 4.13, so it is defined
  -- only on older versions, where it defers to the 'Fail.MonadFail'
  -- instance.
  fail = Fail.fail
#endif
-- Failure discards its message and becomes 'never'.
instance Monad m => Fail.MonadFail (IterT m) where
  fail = const never
-- 'Apply' reuses the monadic application 'ap'.
instance Monad m => Apply (IterT m) where
  ff <.> fa = ap ff fa
-- 'Bind' reuses the monadic bind.
instance Monad m => Bind (IterT m) where
  m >>- k = m >>= k
-- Fixed point taken through the base monad's 'mfix' on the first layer.
instance MonadFix m => MonadFix (IterT m) where
  mfix f = IterT (mfix (runIterT . f . finished))
    where
      -- The value fed back must come from a finished ('Left') step; a
      -- 'Right' step is reported with 'error'.
      finished (Left a) = a
      finished (Right _) = error "mfix (IterT m): Right"
-- 'Alternative' mirrors the 'MonadPlus' instance.
instance Monad m => Alternative (IterT m) where
  empty = mzero
  lhs <|> rhs = mplus lhs rhs
-- Choice between two computations, advanced one layer at a time; 'mzero'
-- is 'never'.
instance Monad m => MonadPlus (IterT m) where
  mzero = never
  mplus (IterT x) (IterT y) = IterT $ do
    stepX <- x
    case stepX of
      -- The left operand finished: its result wins and @y@ is not run.
      Left a -> return (Left a)
      -- Otherwise run one layer of @y@; a finished @y@ is returned as is,
      -- an unfinished one is combined with the remainder of @x@.
      Right restX -> liftM (second (mplus restX)) y
-- Lift a base action into a computation that finishes in its first layer.
instance MonadTrans IterT where
  lift action = IterT (liftM Left action)
-- Fold over every result reachable through the base functor's 'foldMap'.
instance Foldable m => Foldable (IterT m) where
  foldMap f (IterT layer) = foldMap step layer
    where
      step (Left a) = f a
      step (Right rest) = foldMap f rest
-- Non-empty fold over every result reachable through the base functor's
-- 'foldMap1'.
instance Foldable1 m => Foldable1 (IterT m) where
  foldMap1 f (IterT layer) = foldMap1 step layer
    where
      step (Left a) = f a
      step (Right rest) = foldMap1 f rest
-- Traverse every result, rebuilding the same layer structure.
instance (Monad m, Traversable m) => Traversable (IterT m) where
  traverse f (IterT layer) = fmap IterT (traverse step layer)
    where
      step (Left a) = Left <$> f a
      step (Right rest) = Right <$> traverse f rest
-- Non-empty traversal of every result, rebuilding the same layer structure.
instance (Monad m, Traversable1 m) => Traversable1 (IterT m) where
  traverse1 f (IterT layer) = IterT <$> traverse1 step layer
    where
      step = either (fmap Left . f) (fmap Right . traverse1 f)
-- Reader operations: 'ask' is lifted from the base monad, and 'local' is
-- applied to the base action of every layer through 'hoistIterT'.
instance MonadReader e m => MonadReader e (IterT m) where
  ask = lift ask
  local f it = hoistIterT (local f) it
-- Writer operations threaded through every layer of the iteration.
instance MonadWriter w m => MonadWriter w (IterT m) where
  tell w = lift (tell w)

  -- Listen to each layer separately and prepend the output of the current
  -- layer to whatever the remaining layers report.
  listen (IterT m) = IterT $ liftM merge $ listen (liftM (fmap listen) m)
    where
      merge (Left x, w) = Left (x, w)
      merge (Right rest, w) = Right (fmap (second (mappend w)) rest)

  -- Collect the output with 'listen' while passing @const mempty@ to the
  -- base monad's 'pass' in every layer, then 'tell' the transformed output
  -- once the final result is reached.
  pass m = IterT (replay (runIterT (hoistIterT silence (listen m))))
    where
      silence = pass . liftM (\x -> (x, const mempty))
      replay = join . liftM emit
      emit (Left ((x, f), w)) = tell (f w) >> return (Left x)
      emit (Right rest) = return (Right (IterT (replay (runIterT rest))))
#if MIN_VERSION_mtl(2,1,1)
  writer = lift . writer
#endif
-- State operations are lifted unchanged from the base monad.
instance MonadState s m => MonadState s (IterT m) where
  get = lift get
  put = lift . put
#if MIN_VERSION_mtl(2,1,1)
  state = lift . state
#endif
-- Error handling installed on the current layer and, recursively, on the
-- remainder of the computation.
instance MonadError e m => MonadError e (IterT m) where
  throwError e = lift (throwError e)
  catchError (IterT m) handler = IterT (catchError guarded (runIterT . handler))
    where
      -- Attach the handler to the remainder before running this layer.
      guarded = liftM (fmap (\rest -> catchError rest handler)) m
-- 'IO' actions are lifted through the base monad.
instance MonadIO m => MonadIO (IterT m) where
  liftIO io = lift (liftIO io)
-- 'callCC' via the base monad: invoking the escape continuation passes a
-- finished ('Left') step to the base continuation.
instance MonadCont m => MonadCont (IterT m) where
  callCC f = IterT (callCC (\exit -> runIterT (f (\a -> lift (exit (Left a))))))
-- Wrapping an 'Identity' layer adds one 'Right' step in front of the
-- wrapped computation.
instance Monad m => MonadFree Identity (IterT m) where
  wrap (Identity next) = IterT (return (Right next))
-- Exceptions are thrown in the base monad.
instance MonadThrow m => MonadThrow (IterT m) where
  throwM e = lift (throwM e)
-- Exception handling installed on the current layer and, recursively, on
-- the remainder of the computation. The right-hand sides spell out the
-- qualified name @Control.Monad.Catch.catch@.
instance MonadCatch m => MonadCatch (IterT m) where
  catch (IterT m) handler =
    IterT (Control.Monad.Catch.catch guarded (runIterT . handler))
    where
      -- Attach the handler to the remainder before running this layer.
      guarded = liftM (fmap (\rest -> Control.Monad.Catch.catch rest handler)) m
-- | Put one extra layer in front of a computation: the argument is injected
-- into @f@ with 'return' and then wrapped with 'wrap'.
delay :: (Monad f, MonadFree f m) => m a -> m a
delay next = wrap (return next)
-- | Run a computation to completion in the base monad, layer after layer,
-- until a 'Left' result is produced.
retract :: Monad m => IterT m a -> m a
retract it = do
  step <- runIterT it
  case step of
    Left a -> return a
    Right rest -> retract rest
-- | Tear down a computation with an algebra @phi@ over the base monad:
-- finished layers contribute their value, unfinished layers are folded
-- recursively first.
fold :: Monad m => (m a -> a) -> IterT m a -> a
fold phi (IterT layer) = phi (liftM collapse layer)
  where
    collapse (Left a) = a
    collapse (Right rest) = fold phi rest
-- | Like 'fold', but the result lives in a second monad @n@: finished
-- layers are injected with 'return', unfinished layers are folded
-- recursively.
foldM :: (Monad m, Monad n) => (m (n a) -> n a) -> IterT m a -> n a
foldM phi (IterT layer) = phi (liftM collapse layer)
  where
    collapse (Left a) = return a
    collapse (Right rest) = foldM phi rest
-- | Change the base monad by applying a monad-polymorphic function to the
-- base action of every layer.
hoistIterT :: Monad n => (forall a. m a -> n a) -> IterT m b -> IterT n b
hoistIterT nat (IterT layer) = IterT (liftM (fmap (hoistIterT nat)) (nat layer))
-- | Embed a pure 'Iter' into 'IterT' over any monad by replacing each
-- 'Identity' layer with 'return'.
liftIter :: (Monad m) => Iter a -> IterT m a
liftIter it = hoistIterT (return . runIdentity) it
-- | A computation that only ever wraps itself in another layer and so
-- never produces a result.
never :: (Monad f, MonadFree f m) => m a
never = wrap (return never)
-- | Run the base action repeatedly, inserting a 'delay' after every
-- 'Nothing', and finish with the first 'Just' value.
untilJust :: (Monad m) => m (Maybe a) -> IterT m a
untilJust poll = do
  outcome <- lift poll
  case outcome of
    Nothing -> delay (untilJust poll)
    Just a -> return a
-- | Limit a computation to at most @n@ layers.
--
-- With a budget of zero or less the computation is not run at all and the
-- result is 'Nothing'. Otherwise one layer is run: a finished value is
-- returned as 'Just', and an unfinished remainder is cut off with the
-- budget reduced by one.
cutoff :: (Monad m) => Integer -> IterT m a -> IterT m (Maybe a)
cutoff n | n <= 0 = const $ return Nothing
cutoff n = IterT . liftM (either (Left . Just)
                                 -- Each layer consumed costs one unit of budget.
                                 (Right . cutoff (n - 1))) . runIterT
-- | Advance all computations one layer at a time, in list order, until
-- every one has finished; the results are returned in list order.
interleave :: Monad m => [IterT m a] -> IterT m [a]
interleave computations = IterT $ do
  steps <- mapM runIterT computations
  case rights steps of
    -- Nothing is left unfinished: collect the results.
    [] -> return (Left (lefts steps))
    -- Otherwise go round again, re-wrapping finished values with 'return'.
    _ : _ -> return (Right (interleave (map resume steps)))
  where
    resume (Left a) = return a
    resume (Right rest) = rest
-- | Advance all computations one layer at a time, in list order, dropping
-- each one as soon as it finishes and discarding its result.
interleave_ :: (Monad m) => [IterT m a] -> IterT m ()
interleave_ [] = return ()
interleave_ pending = IterT (liftM advance (mapM runIterT pending))
  where
    -- Keep only the computations that still have layers left.
    advance steps = Right (interleave_ (rights steps))
-- Monoid on computations whose results form a monoid. 'mappend' is the
-- 'Semigroup' operation; 'mconcat' advances all computations together.
instance (Monad m, Semigroup a, Monoid a) => Monoid (IterT m a) where
  mempty = return mempty
  mappend = (<>)
  mconcat = mconcat' . map Right
    where
      -- Run one layer of every unfinished computation, merge adjacent
      -- finished results, and repeat until a single result is left.
      mconcat' :: (Monad m, Monoid a) => [Either a (IterT m a)] -> IterT m a
      mconcat' ms = IterT $ do
        xs <- mapM (either (return . Left) runIterT) ms
        case compact xs of
          -- An empty list must yield 'mempty' (so that @mconcat []@ equals
          -- 'mempty'); without this case the empty list would recurse on
          -- itself forever.
          [] -> return (Left mempty)
          [l@(Left _)] -> return l
          xs' -> return . Right $ mconcat' xs'

      -- Merge each run of consecutive 'Left' values into a single 'Left',
      -- keeping 'Right' entries in place.
      compact :: (Monoid a) => [Either a b] -> [Either a b]
      compact [] = []
      compact (r@(Right _):xs) = r:(compact xs)
      compact (Left a:xs) = compact' a xs

      -- Accumulate a run of 'Left' values that started with @a@.
      compact' a [] = [Left a]
      compact' a (r@(Right _):xs) = (Left a):(r:(compact xs))
      compact' a ((Left a'):xs) = compact' (a `mappend` a') xs
-- Combine two computations by running one layer of each (left first) and
-- combining the results once both have finished.
instance (Monad m, Semigroup a) => Semigroup (IterT m a) where
  lhs <> rhs = IterT $ do
    stepL <- runIterT lhs
    stepR <- runIterT rhs
    case stepL of
      Left a -> case stepR of
        Left b -> return (Left (a <> b))
        -- Left side finished: prepend its result to the right remainder.
        Right restR -> return (Right (liftM (a <>) restR))
      Right restL -> case stepR of
        -- Right side finished: append its result to the left remainder.
        Left b -> return (Right (liftM (<> b) restL))
        Right restR -> return (Right (restL <> restR))
#if __GLASGOW_HASKELL__ < 707
-- Hand-written 'Typeable1' instance for compilers older than GHC 7.7.
instance Typeable1 m => Typeable1 (IterT m) where
  typeOf1 t = mkTyConApp freeTyCon [typeOf1 (baseOf t)]
    where
      -- Defined as 'undefined': it only supplies a value of type @m a@ for
      -- the inner 'typeOf1' call.
      baseOf :: IterT m a -> m a
      baseOf = undefined

-- Type constructor representation used by the instance above.
freeTyCon :: TyCon
#if __GLASGOW_HASKELL__ < 704
freeTyCon = mkTyCon "Control.Monad.Iter.IterT"
#else
freeTyCon = mkTyCon3 "free" "Control.Monad.Iter" "IterT"
#endif
#else
-- On newer compilers later uses of Typeable1 expand to Typeable.
#define Typeable1 Typeable
#endif
-- Generic programming support: 'IterT' is presented as a data type with
-- the single constructor described by 'iterConstr'.
instance
  ( Typeable1 m, Typeable a
  , Data (m (Either a (IterT m a)))
  , Data a
  ) => Data (IterT m a) where
  gfoldl app inj (IterT layer) = inj IterT `app` layer
  toConstr _ = iterConstr
  -- Only constructor index 1 is accepted; anything else is an 'error'.
  gunfold k z c
    | constrIndex c == 1 = k (z IterT)
    | otherwise = error "gunfold"
  dataTypeOf _ = iterDataType
  dataCast1 f = gcast1 f
-- | Constructor representation for 'IterT': a prefix constructor named
-- @"IterT"@, registered with an empty list of field names.
iterConstr :: Constr
iterConstr = mkConstr iterDataType "IterT" [] Prefix
-- | Data type representation for 'IterT', whose only constructor is
-- 'iterConstr'.
iterDataType :: DataType
iterDataType = mkDataType "Control.Monad.Iter.IterT" [iterConstr]