{-# LANGUAGE CPP #-}
#if __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Safe #-}
{-# LANGUAGE DeriveGeneric #-}
#endif
#if __GLASGOW_HASKELL__ >= 706
{-# LANGUAGE PolyKinds #-}
#endif
#if __GLASGOW_HASKELL__ >= 710 && __GLASGOW_HASKELL__ < 802
{-# LANGUAGE AutoDeriveTypeable #-}
#endif
-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Monad.Trans.Select
-- Copyright   :  (c) Ross Paterson 2017
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  R.Paterson@city.ac.uk
-- Stability   :  experimental
-- Portability :  portable
--
-- Selection monad transformer, modelling search algorithms.
--
-- * Martin Escardo and Paulo Oliva.
--   "Selection functions, bar recursion and backward induction",
--   /Mathematical Structures in Computer Science/ 20:2 (2010), pp. 127-168.
--   <https://www.cs.bham.ac.uk/~mhe/papers/selection-escardo-oliva.pdf>
--
-- * Jules Hedges. "Monad transformers for backtracking search".
--   In /Proceedings of MSFP 2014/. <https://arxiv.org/abs/1406.2058>
-----------------------------------------------------------------------------

module Control.Monad.Trans.Select (
    -- * The Select monad
    Select,
    select,
    runSelect,
    mapSelect,
    -- * The SelectT monad transformer
    SelectT(SelectT),
    runSelectT,
    mapSelectT,
    -- * Monad transformation
    selectToContT,
    ) where

import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Cont

import Control.Applicative
import Control.Monad
#if MIN_VERSION_base(4,9,0)
import qualified Control.Monad.Fail as Fail
#endif
import Data.Functor.Identity
#if __GLASGOW_HASKELL__ >= 704
import GHC.Generics
#endif

-- | Selection monad: 'SelectT' specialised to the 'Identity' base monad,
-- so selections are pure functions of type @(a -> r) -> a@ up to wrapping.
type Select r = SelectT r Identity

-- | Constructor for computations in the selection monad.
--
-- The pure selection function @f@ receives the answer-evaluation function
-- with its 'Identity' wrapper stripped off, and its chosen value is
-- rewrapped in 'Identity'.
select :: ((a -> r) -> a) -> Select r a
select f = SelectT $ \ k -> Identity (f (runIdentity . k))
{-# INLINE select #-}

-- | Runs a @Select@ computation with a function for evaluating answers
-- to select a particular answer.  (The inverse of 'select'.)
--
-- The pure evaluation function @k@ is lifted into 'Identity' so it can be
-- passed to 'runSelectT', and the resulting 'Identity' is unwrapped.
runSelect :: Select r a -> (a -> r) -> a
runSelect m k = runIdentity (runSelectT m (Identity . k))
{-# INLINE runSelect #-}

-- | Apply a function to transform the result of a selection computation.
--
-- * @'runSelect' ('mapSelect' f m) = f . 'runSelect' m@
--
-- Implemented via 'mapSelectT' by unwrapping the 'Identity' result,
-- applying @f@, and rewrapping it.
mapSelect :: (a -> a) -> Select r a -> Select r a
mapSelect f = mapSelectT (Identity . f . runIdentity)
{-# INLINE mapSelect #-}

-- | Selection monad transformer.
--
-- A value wraps a selection function: given a way to evaluate each
-- candidate @a@ as an answer in @m r@, it yields an @m a@.
--
-- 'SelectT' is not a functor on the category of monads, and many operations
-- cannot be lifted through it.
newtype SelectT r m a = SelectT ((a -> m r) -> m a)
#if __GLASGOW_HASKELL__ >= 704
    deriving (Generic)
#endif

-- | Runs a @SelectT@ computation with a function for evaluating answers
-- to select a particular answer.  (The inverse of 'SelectT'.)
--
-- This simply unwraps the newtype and returns the underlying selection
-- function.
runSelectT :: SelectT r m a -> (a -> m r) -> m a
runSelectT (SelectT g) = g
{-# INLINE runSelectT #-}

-- | Apply a function to transform the result of a selection computation.
-- This has a more restricted type than the @map@ operations for other
-- monad transformers, because 'SelectT' does not define a functor in
-- the category of monads.
--
-- * @'runSelectT' ('mapSelectT' f m) = f . 'runSelectT' m@
mapSelectT :: (m a -> m a) -> SelectT r m a -> SelectT r m a
mapSelectT f m = SelectT $ f . runSelectT m
{-# INLINE mapSelectT #-}

-- Mapping @f@ over a selection precomposes the evaluation function with @f@
-- (so each candidate is judged by its image under @f@), runs the original
-- selection @g@ with that evaluation, and maps @f@ over the chosen value.
instance (Functor m) => Functor (SelectT r m) where
    fmap f (SelectT g) = SelectT (fmap f . g . (. f))
    {-# INLINE fmap #-}

-- 'pure' ignores the evaluation function and returns the value in @m@.
-- '<*>' lets the function selection @gf@ judge each candidate function by
-- completing it with an argument selection (@h@) and evaluating the result
-- with the final evaluation @k@; the chosen function is then completed once
-- more with @h@ to produce the result.
instance (Functor m, Monad m) => Applicative (SelectT r m) where
    pure = lift . return
    {-# INLINE pure #-}
    SelectT gf <*> SelectT gx = SelectT $ \ k -> do
        -- h f: select an argument for f, judging each argument x by
        -- k (f x), and apply f to the selected argument.
        let h f = liftM f (gx (k . f))
        f <- gf ((>>= k) . h)
        h f
    {-# INLINE (<*>) #-}
    m *> k = m >>= \_ -> k
    {-# INLINE (*>) #-}

-- 'Alternative' simply reuses this type's 'MonadPlus' instance.
instance (Functor m, MonadPlus m) => Alternative (SelectT r m) where
    empty = mzero
    {-# INLINE empty #-}
    (<|>) = mplus
    {-# INLINE (<|>) #-}

-- Binding: the first selection @g@ judges each candidate @x@ by running the
-- continuation's selection @f x@ under the final evaluation @k@ (via @h@)
-- and then evaluating its result with @k@; the chosen @y@ is then fed to
-- @h@ once more to produce the final value.
instance (Monad m) => Monad (SelectT r m) where
#if !(MIN_VERSION_base(4,8,0))
    return = lift . return
    {-# INLINE return #-}
#endif
    SelectT g >>= f = SelectT $ \ k -> do
        -- h x: run the selection produced by f x against k.
        let h x = runSelectT (f x) k
        y <- g ((>>= k) . h)
        h y
    {-# INLINE (>>=) #-}

#if MIN_VERSION_base(4,9,0)
-- Failure is delegated to the base monad and lifted, ignoring the
-- evaluation function.
instance (Fail.MonadFail m) => Fail.MonadFail (SelectT r m) where
    fail msg = lift (Fail.fail msg)
    {-# INLINE fail #-}
#endif

-- 'mzero' ignores the evaluation function and yields the base monad's
-- 'mzero'; 'mplus' runs both selections with the same evaluation function
-- and combines their results with the base monad's 'mplus'.
instance (MonadPlus m) => MonadPlus (SelectT r m) where
    mzero = SelectT (const mzero)
    {-# INLINE mzero #-}
    SelectT f `mplus` SelectT g = SelectT $ \ k -> f k `mplus` g k
    {-# INLINE mplus #-}

-- Lifting a base computation produces a selection that ignores the
-- evaluation function.
instance MonadTrans (SelectT r) where
    lift = SelectT . const
    {-# INLINE lift #-}

-- IO actions are lifted into the base monad first, then through 'SelectT'.
instance (MonadIO m) => MonadIO (SelectT r m) where
    liftIO = lift . liftIO
    {-# INLINE liftIO #-}

-- | Convert a selection computation to a continuation-passing computation.
--
-- The continuation @k@ serves both as the evaluation function for the
-- selection and as the final consumer of the selected value.
selectToContT :: (Monad m) => SelectT r m a -> ContT r m a
selectToContT (SelectT g) = ContT $ \ k -> g k >>= k
{-# INLINE selectToContT #-}