{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# language Safe #-}
-- |
-- Module       : Control.Monad.Trans.Smash
-- Copyright    : (c) 2020-2022 Emily Pillmore
-- License      : BSD-3-Clause
--
-- Maintainer   : Emily Pillmore <emilypi@cohomolo.gy>
-- Stability    : Experimental
-- Portability  : Non-portable
--
-- This module contains utilities for the monad transformer
-- for the smash product.
--
module Control.Monad.Trans.Smash
( -- * Monad transformer
  SmashT(..)
  -- ** Combinators
, mapSmashT
) where


import Data.Smash

import Control.Applicative (liftA2)
import Control.Monad.Writer
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.RWS


-- | A monad transformer for the smash product.
--
-- The type parameters, in declaration order, are:
--
--   * @a@ - the value on the left
--   * @m@ - the underlying monad, whose results are pointed products
--     (see: 'Smash')
--   * @b@ - the value on the right
--
-- 'runSmashT' unwraps a computation into the underlying @m ('Smash' a b)@.
--
newtype SmashT a m b = SmashT { runSmashT :: m (Smash a b) }

-- | Transform the computation inside a 'SmashT'. The given function may
-- change the underlying monad as well as both the left and right values.
--
-- * @'runSmashT' ('mapSmashT' f m) = f ('runSmashT' m)@
--
mapSmashT :: (m (Smash a b) -> n (Smash c d)) -> SmashT a m b -> SmashT c n d
mapSmashT f = SmashT . f . runSmashT

-- | Maps over the right-hand value by lifting the function through both the
-- underlying functor and the 'Functor' instance of @'Smash' a@.
instance Functor f => Functor (SmashT a f) where
  fmap g = SmashT . fmap (fmap g) . runSmashT

-- | Runs the underlying applicative effects and combines the resulting
-- 'Smash' values with the 'Applicative' instance of @'Smash' a@.
instance (Monoid a, Applicative f) => Applicative (SmashT a f) where
  -- Wrap first in 'Smash', then in the underlying applicative.
  pure = SmashT . pure . pure

  -- Apply inside both layers: the outer 'liftA2' sequences @f@'s effects,
  -- the inner '<*>' combines the two 'Smash' values.
  SmashT ff <*> SmashT fx = SmashT (liftA2 (<*>) ff fx)

-- | Sequencing short-circuits on 'Nada'.
--
-- If the first computation yields 'Nada', the continuation is never run.
-- If the continuation yields 'Nada', the whole result is 'Nada'. Otherwise
-- the left values of the two steps are combined with '<>', with the first
-- step's value on the left.
--
-- 'return' is left to its default, 'pure' from the 'Applicative' instance.
instance (Monoid a, Monad m) => Monad (SmashT a m) where
  SmashT m >>= k = SmashT $ m >>= \case
    Nada -> pure Nada
    Smash l b -> runSmashT (k b) >>= \case
      Nada -> pure Nada
      Smash l' b' -> pure (Smash (l <> l') b')

-- | Lifts the reader operations of the underlying monad. 'local' changes the
-- environment only for the wrapped computation.
instance (Monoid a, MonadReader r m) => MonadReader r (SmashT a m) where
  ask = lift ask

  local f (SmashT m) = SmashT (local f m)

  -- Delegate to the base monad's 'reader' rather than the default
  -- @fmap f ask@, matching how mtl's own transformers lift this method.
  reader = lift . reader

-- | Lifts the writer operations of the underlying monad through 'SmashT'.
instance (Monoid a, MonadWriter w m) => MonadWriter w (SmashT a m) where
  -- Delegate to the base monad's 'writer' rather than the default
  -- @tell w >> return x@, matching how mtl's own transformers lift it.
  writer = lift . writer

  tell = lift . tell

  listen (SmashT m) = SmashT (attach <$> listen m)
    where
      -- Pair the captured output with the right value, if there is one.
      attach (Nada, _) = Nada
      attach (Smash l b, out) = Smash l (b, out)

  pass (SmashT m) = SmashT (pass (unpair <$> m))
    where
      -- Expose the output-modifying function to the underlying 'pass'.
      -- With no value there is no function, so the output is left as-is.
      unpair = \case
        Nada -> (Nada, id)
        Smash l (b, f) -> (Smash l b, f)

-- | Lifts the state operations of the underlying monad through 'SmashT'.
instance (Monoid a, MonadState s m) => MonadState s (SmashT a m) where
  get = lift get

  put = lift . put

  -- Delegate to the base monad's 'state' so a state update is one lifted
  -- step instead of the default get/put round trip through '>>='.
  state = lift . state

-- | 'MonadRWS' has no methods of its own. This instance relies on the
-- 'MonadReader', 'MonadWriter' and 'MonadState' instances for 'SmashT'.
instance (Monoid a, MonadRWS r w s m) => MonadRWS r w s (SmashT a m)

-- | Lifts a computation by pairing each of its results with 'mempty' on the
-- left, so a lifted action always yields 'Smash' and never 'Nada'.
instance Monoid a => MonadTrans (SmashT a) where
  lift = SmashT . fmap (Smash mempty)