#ifndef MIN_VERSION_base
#define MIN_VERSION_base(x,y,z) 1
#endif
#ifndef MIN_VERSION_mtl
#define MIN_VERSION_mtl(x,y,z) 1
#endif
module Control.Monad.Trans.Iter
(
IterT(..)
, Iter, iter, runIter
, delay
, hoistIterT
, liftIter
, cutoff
, never
, untilJust
, interleave, interleave_
, retract
, fold
, foldM
, MonadFree(..)
) where
import Control.Applicative
import Control.Monad.Catch (MonadCatch(..), MonadThrow(..))
import Control.Monad (ap, liftM, MonadPlus(..), join)
import Control.Monad.Fix
import Control.Monad.Trans.Class
import Control.Monad.Free.Class
import Control.Monad.State.Class
import Control.Monad.Error.Class
import Control.Monad.Reader.Class
import Control.Monad.Writer.Class
import Control.Monad.Cont.Class
import Control.Monad.IO.Class
import Data.Bifunctor
import Data.Bitraversable
import Data.Either
import Data.Functor.Bind hiding (join)
import Data.Functor.Identity
import Data.Function (on)
import Data.Monoid
import Data.Semigroup.Foldable
import Data.Semigroup.Traversable
import Data.Typeable
import Data.Data
import Prelude.Extras
#if !(MIN_VERSION_base(4,8,0))
import Data.Foldable hiding (fold)
import Data.Traversable hiding (mapM)
#endif
-- | A monad transformer that adds suspension steps to a base monad @m@.
--
-- Running one layer with 'runIterT' performs an action in @m@ that either
-- finishes with a result (@Left a@) or yields the rest of the computation
-- (@Right rest@), which can be resumed with another 'runIterT'.
newtype IterT m a = IterT { runIterT :: m (Either a (IterT m a)) }
#if __GLASGOW_HASKELL__ >= 707
  deriving (Typeable)
#endif
-- | 'IterT' over 'Identity': a pure computation that may take any number of
-- steps before producing a result.
type Iter = IterT Identity
-- | Build a pure iterative computation from a single step: either the final
-- value or the remaining computation.
iter :: Either a (Iter a) -> Iter a
iter step = IterT (Identity step)
-- | Run one step of a pure iterative computation, exposing either the final
-- value or the remaining computation.
runIter :: Iter a -> Either a (Iter a)
runIter (IterT (Identity step)) = step
-- | Lifted equality. Nested continuations are wrapped in 'Lift1' so the
-- comparison of the base layer can recurse through this same instance.
instance (Functor m, Eq1 m) => Eq1 (IterT m) where
  IterT x ==# IterT y = fmap (fmap Lift1) x ==# fmap (fmap Lift1) y
-- | Two computations are equal when their underlying base-monad layers are.
instance Eq (m (Either a (IterT m a))) => Eq (IterT m a) where
  (==) = (==) `on` runIterT
-- | Lifted ordering, comparing the base layers with continuations wrapped
-- in 'Lift1'.
instance (Functor m, Ord1 m) => Ord1 (IterT m) where
  compare1 (IterT x) (IterT y) =
    compare1 (fmap (fmap Lift1) x) (fmap (fmap Lift1) y)
-- | Ordering is inherited from the underlying base-monad layer.
instance Ord (m (Either a (IterT m a))) => Ord (IterT m a) where
  compare = compare `on` runIterT
-- | Lifted rendering as @IterT <layer>@, parenthesised above application
-- precedence.
instance (Functor m, Show1 m) => Show1 (IterT m) where
  showsPrec1 d (IterT inner) =
    showParen (d >= 11) (showString "IterT " . showsPrec1 11 wrapped)
    where
      wrapped = fmap (fmap Lift1) inner
-- | Renders as @IterT <layer>@, parenthesised above application precedence.
instance Show (m (Either a (IterT m a))) => Show (IterT m a) where
  showsPrec d (IterT inner) =
    showParen (d >= 11) (showString "IterT " . showsPrec 11 inner)
-- | Lifted parser for the @IterT <layer>@ form produced by 'Show1'.
instance (Functor m, Read1 m) => Read1 (IterT m) where
  readsPrec1 d = readParen (d > 10) $ \input -> do
    ("IterT", rest) <- lex input
    (layer, remaining) <- readsPrec1 11 rest
    return (IterT (fmap (fmap lower1) layer), remaining)
-- | Parser for the @IterT <layer>@ form produced by 'Show'.
instance Read (m (Either a (IterT m a))) => Read (IterT m a) where
  readsPrec d = readParen (d > 10) $ \input -> do
    ("IterT", rest) <- lex input
    (layer, remaining) <- readsPrec 11 rest
    return (IterT layer, remaining)
-- | Maps over the final result, pushing the function into every suspended
-- continuation.
instance Monad m => Functor (IterT m) where
  fmap f (IterT m) = IterT (liftM step m)
    where
      step (Left a)     = Left (f a)
      step (Right rest) = Right (fmap f rest)
-- | 'pure' finishes immediately; application is the monadic 'ap'.
instance Monad m => Applicative (IterT m) where
  pure a = IterT (return (Left a))
  mf <*> mx = ap mf mx
-- | Sequencing runs one layer of the first computation. A finished layer
-- continues straight into @k@; a suspended one stays suspended with @k@
-- bound onto its continuation.
instance Monad m => Monad (IterT m) where
  return a = IterT (return (Left a))
  IterT m >>= k = IterT $ do
    step <- m
    case step of
      Left a     -> runIterT (k a)
      Right rest -> return (Right (rest >>= k))
  -- The message is discarded; the result is 'never', which suspends forever.
  fail _ = never
-- | Same as the 'Applicative' application ('ap').
instance Monad m => Apply (IterT m) where
  mf <.> mx = ap mf mx
-- | Same as the monadic bind.
instance Monad m => Bind (IterT m) where
  m >>- k = m >>= k
-- | Fixed point taken in the base monad. The knot is tied only through a
-- finished first layer; demanding the fixed value when that layer
-- suspended raises an error.
instance MonadFix m => MonadFix (IterT m) where
  mfix f = IterT (mfix (runIterT . f . unwrapLeft))
    where
      unwrapLeft (Left a)  = a
      unwrapLeft (Right _) = error "mfix (IterT m): Right"
-- | Delegates to the 'MonadPlus' instance.
instance Monad m => Alternative (IterT m) where
  empty = mzero
  a <|> b = a `mplus` b
-- | Left-biased race.
--
-- 'mzero' never finishes. 'mplus' runs one step of the left operand; if it
-- finished, its result wins and the right operand is not run. Otherwise one
-- step of the right operand runs; if that finished, its result wins, else
-- both continuations are raced again on the next step.
instance Monad m => MonadPlus (IterT m) where
  mzero = never
  IterT x `mplus` IterT y = IterT $ do
    leftStep <- x
    case leftStep of
      Left a       -> return (Left a)
      Right xRest  -> liftM (either Left (Right . mplus xRest)) y
-- | A lifted action finishes in a single step.
instance MonadTrans IterT where
  lift m = IterT (liftM Left m)
-- | Folds over every final value reachable through the base layers and
-- their continuations.
instance Foldable m => Foldable (IterT m) where
  foldMap f (IterT m) = foldMap step m
    where
      step (Left a)     = f a
      step (Right rest) = foldMap f rest
-- | Non-empty fold, recursing into suspended continuations.
instance Foldable1 m => Foldable1 (IterT m) where
  foldMap1 f (IterT m) = foldMap1 step m
    where
      step (Left a)     = f a
      step (Right rest) = foldMap1 f rest
-- | Traverses each base layer, applying @f@ to finished values and
-- recursing into suspended continuations.
instance (Monad m, Traversable m) => Traversable (IterT m) where
  traverse f (IterT m) = fmap IterT (traverse step m)
    where
      step (Left a)     = fmap Left (f a)
      step (Right rest) = fmap Right (traverse f rest)
-- | Non-empty traversal, recursing into suspended continuations.
instance (Monad m, Traversable1 m) => Traversable1 (IterT m) where
  traverse1 f (IterT m) =
    fmap IterT (traverse1 (either (fmap Left . f) (fmap Right . traverse1 f)) m)
-- | 'ask' is lifted from the base monad; 'local' applies the base 'local'
-- to every layer of the computation.
instance MonadReader e m => MonadReader e (IterT m) where
  ask = lift ask
  local f m = hoistIterT (local f) m
-- | Lifts the base writer through 'IterT', threading output across
-- suspension steps.
instance MonadWriter w m => MonadWriter w (IterT m) where
  tell = lift . tell

  -- Each layer is run under the base listen. A finished layer pairs its
  -- result with that layer output; a suspended layer is itself listened to,
  -- and the output seen so far is prepended to whatever it later reports.
  listen (IterT m) = IterT $ liftM concat' $ listen (fmap listen `liftM` m)
    where
      concat' (Left x, w) = Left (x, w)
      concat' (Right y, w) = Right $ second (w <>) <$> y

  -- Strategy: listen to the whole computation, while clean suppresses the
  -- base output of every layer (replaced by mempty; the listened copy is
  -- kept in the value). When the final layer arrives, the user function f
  -- is applied to the accumulated output w and re-emitted with tell.
  pass m = IterT . pass' . runIterT . hoistIterT clean $ listen m
    where
      clean = pass . liftM (\x -> (x, const mempty))
      pass' = join . liftM g
      g (Left ((x, f), w)) = tell (f w) >> return (Left x)
      g (Right f) = return . Right . IterT . pass' . runIterT $ f
#if MIN_VERSION_mtl(2,1,1)
  writer w = lift (writer w)
#endif
-- | Every state operation is lifted from the base monad.
instance MonadState s m => MonadState s (IterT m) where
  get = lift get
  put = lift . put
#if MIN_VERSION_mtl(2,1,1)
  state = lift . state
#endif
-- | Errors are thrown in the base monad. The handler guards the current
-- layer and is also attached to every suspended continuation.
instance MonadError e m => MonadError e (IterT m) where
  throwError e = lift (throwError e)
  catchError (IterT m) handler =
    IterT (catchError (liftM (fmap (\rest -> catchError rest handler)) m)
                      (runIterT . handler))
-- | IO actions run through the base monad in a single step.
instance MonadIO m => MonadIO (IterT m) where
  liftIO io = lift (liftIO io)
-- | The base callCC is used. The escape continuation handed to @f@ exits
-- with a finished value.
instance MonadCont m => MonadCont (IterT m) where
  callCC f = IterT $ callCC $ \k ->
    runIterT (f (\a -> lift (k (Left a))))
-- | Wrapping inserts one suspension step in front of the given computation.
instance Monad m => MonadFree Identity (IterT m) where
  wrap (Identity next) = IterT (return (Right next))
-- | Exceptions are thrown in the base monad.
instance MonadThrow m => MonadThrow (IterT m) where
  throwM e = lift (throwM e)
-- | The handler guards the current base layer and is re-attached to every
-- suspended continuation. The method name is fully qualified at call sites
-- to avoid a clash with any other catch in scope.
instance MonadCatch m => MonadCatch (IterT m) where
  catch (IterT m) handler =
    IterT (guarded `Control.Monad.Catch.catch` (runIterT . handler))
    where
      guarded = liftM (fmap (`Control.Monad.Catch.catch` handler)) m
-- | Insert one extra suspension step before running the given computation.
delay :: (Monad f, MonadFree f m) => m a -> m a
delay m = wrap (return m)
-- | Run every step in the base monad until a final value is produced.
-- Loops forever if the computation never finishes.
retract :: Monad m => IterT m a -> m a
retract m = do
  step <- runIterT m
  case step of
    Left a     -> return a
    Right rest -> retract rest
-- | Tear down a computation with the algebra @phi@, recursively folding every
-- suspended continuation into a plain value first.
fold :: Monad m => (m a -> a) -> IterT m a -> a
fold phi = phi . liftM collapse . runIterT
  where
    collapse (Left a)     = a
    collapse (Right rest) = fold phi rest
-- | Like 'fold', but the algebra produces results in another monad @n@.
-- Finished values are injected with 'return'.
foldM :: (Monad m, Monad n) => (m (n a) -> n a) -> IterT m a -> n a
foldM phi = phi . liftM descend . runIterT
  where
    descend (Left a)     = return a
    descend (Right rest) = foldM phi rest
-- | Change the base monad by applying @nat@ to every layer, including those
-- of suspended continuations. Presumably @nat@ should be a monad morphism
-- for the result to be well behaved; this is not checked.
hoistIterT :: Monad n => (forall a. m a -> n a) -> IterT m b -> IterT n b
hoistIterT nat (IterT layer) = IterT $ do
  step <- nat layer
  return (fmap (hoistIterT nat) step)
-- | Embed a pure iterative computation into any base monad, keeping its
-- step structure.
liftIter :: (Monad m) => Iter a -> IterT m a
liftIter = hoistIterT (\(Identity a) -> return a)
-- | A computation that suspends forever and never produces a value.
never :: (Monad f, MonadFree f m) => m a
never = wrap (return never)
-- | Run the action repeatedly, suspending one step between attempts, until
-- it yields 'Just'; the wrapped value is the result.
untilJust :: (Monad m) => m (Maybe a) -> IterT m a
untilJust f = do
  result <- lift f
  case result of
    Nothing -> delay (untilJust f)
    Just a  -> return a
-- | Bound the number of steps a computation may take.
--
-- @cutoff n m@ runs at most @n@ layers of @m@. If @m@ finishes within that
-- budget its result is returned in 'Just'. Otherwise the remaining work is
-- abandoned and the result is 'Nothing'. A non-positive budget yields
-- 'Nothing' without running any layer of @m@.
cutoff :: (Monad m) => Integer -> IterT m a -> IterT m (Maybe a)
cutoff n
  | n <= 0    = const (return Nothing)
  -- Each suspension consumes one unit of budget for the continuation.
  | otherwise = IterT . liftM (either (Left . Just) (Right . cutoff (n - 1))) . runIterT
-- | Run all computations in lockstep, one layer of each per step, and finish
-- with their results (in input order) once every one of them has finished.
-- Already-finished computations are carried forward as 'return' values.
interleave :: Monad m => [IterT m a] -> IterT m [a]
interleave ms = IterT $ do
  steps <- mapM runIterT ms
  case rights steps of
    [] -> return (Left (lefts steps))
    _  -> return (Right (interleave (map (either return id) steps)))
-- | Like 'interleave' but discards results. Finished computations are
-- dropped. A non-empty input always suspends at least once, even when every
-- element finishes in the first step.
interleave_ :: (Monad m) => [IterT m a] -> IterT m ()
interleave_ [] = return ()
interleave_ xs = IterT $ do
  steps <- mapM runIterT xs
  return (Right (interleave_ (rights steps)))
-- | Combines results, running the operands in lockstep.
--
-- 'mappend' runs one layer of each operand per step. When both have finished,
-- their results are combined. When only one has finished, its value is folded
-- into the other continuation. Otherwise both continuations are combined
-- again on the next step.
instance (Monad m, Monoid a) => Monoid (IterT m a) where
  mempty = return mempty

  x `mappend` y = IterT $ do
    x' <- runIterT x
    y' <- runIterT y
    case (x', y') of
      (Left a,  Left b)  -> return . Left  $ a `mappend` b
      (Left a,  Right b) -> return . Right $ liftM (a `mappend`) b
      (Right a, Left b)  -> return . Right $ liftM (`mappend` b) a
      (Right a, Right b) -> return . Right $ a `mappend` b

  -- Runs every element one layer per step, merging adjacent finished
  -- results in order, until a single finished value remains.
  mconcat = mconcat' . map Right
    where
      mconcat' :: (Monad m, Monoid a) => [Either a (IterT m a)] -> IterT m a
      mconcat' ms = IterT $ do
        xs <- mapM (either (return . Left) runIterT) ms
        case compact xs of
          -- An empty input has nothing to run: finish with the unit at once,
          -- so that mconcat [] agrees with mempty instead of suspending
          -- forever.
          []           -> return (Left mempty)
          [l@(Left _)] -> return l
          xs'          -> return . Right $ mconcat' xs'

      -- Merge each run of adjacent finished values, preserving order.
      compact :: (Monoid a) => [Either a b] -> [Either a b]
      compact []               = []
      compact (r@(Right _):xs) = r : compact xs
      compact (Left a     :xs) = compact' a xs

      compact' a []               = [Left a]
      compact' a (r@(Right _):xs) = Left a : r : compact xs
      compact' a (Left a'    :xs) = compact' (a <> a') xs
#if __GLASGOW_HASKELL__ < 707
-- Manual Typeable1 instance for GHC versions below 7.7, where the newtype
-- declaration does not derive Typeable (its deriving clause is CPP-guarded).
instance Typeable1 m => Typeable1 (IterT m) where
  typeOf1 t = mkTyConApp freeTyCon [typeOf1 (f t)] where
    -- Used only for its type, to recover the base monad type constructor.
    -- Relies on typeOf1 not inspecting its argument.
    f :: IterT m a -> m a
    f = undefined

-- | Type constructor representation used by the manual instance above.
-- NOTE(review): the module path recorded here (Control.Monad.Iter) differs
-- from the name of this module (Control.Monad.Trans.Iter); confirm whether
-- that is intended.
freeTyCon :: TyCon
#if __GLASGOW_HASKELL__ < 704
freeTyCon = mkTyCon "Control.Monad.Iter.IterT"
#else
freeTyCon = mkTyCon3 "free" "Control.Monad.Iter" "IterT"
#endif
#else
-- On newer compilers the derived Typeable instance is used, and the macro
-- below lets the Data instance context be written once for both cases.
#define Typeable1 Typeable
#endif
-- | Generic-programming instance treating 'IterT' as a single-constructor
-- type with one field, the base-monad layer.
instance
  ( Typeable1 m, Typeable a
  , Data (m (Either a (IterT m a)))
  , Data a
  ) => Data (IterT m a) where
  gfoldl k z (IterT inner) = k (z IterT) inner
  toConstr _ = iterConstr
  gunfold k z c
    | constrIndex c == 1 = k (z IterT)
    | otherwise          = error "gunfold"
  dataTypeOf _ = iterDataType
  -- Kept eta-expanded: the argument has a rank-2 type.
  dataCast1 f = gcast1 f
-- | The single prefix constructor @IterT@, with no record field labels.
iterConstr :: Constr
iterConstr = mkConstr iterDataType "IterT" [] Prefix

-- | Data type description for 'IterT', listing its only constructor.
-- NOTE(review): the qualified name uses Control.Monad.Iter rather than this
-- module name (Control.Monad.Trans.Iter); confirm whether that is intended.
iterDataType :: DataType
iterDataType = mkDataType "Control.Monad.Iter.IterT" [iterConstr]