{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
module Control.Monad.Search
(
Search
, runSearch
, runSearchBest
, SearchT
, runSearchT
, runSearchBestT
, MonadSearch
, cost
, cost'
, junction
, abandon
, seal
, collapse
, winner
) where
import Control.Applicative ( Alternative(..) )
import Control.Monad ( MonadPlus(..) )
import Control.Monad.Trans.Free ( FreeF(Free, Pure), FreeT
, runFreeT, wrap )
import Control.Monad.Trans.State ( evalStateT, gets, modify )
import Control.Monad.Trans.Class ( MonadTrans, lift )
import Control.Monad.IO.Class ( MonadIO )
import Control.Monad.Reader ( MonadReader, ReaderT(..)
, runReaderT )
import qualified Control.Monad.Writer.Lazy as Lazy ( MonadWriter, WriterT(..)
, runWriterT )
import qualified Control.Monad.Writer.Strict as Strict ( WriterT(..)
, runWriterT )
import qualified Control.Monad.State.Lazy as Lazy ( MonadState, StateT(..)
, runStateT )
import qualified Control.Monad.State.Strict as Strict ( StateT(..), runStateT )
import qualified Control.Monad.RWS.Lazy as Lazy ( MonadRWS, RWST(..)
, runRWST )
import qualified Control.Monad.RWS.Strict as Strict ( RWST(..), runRWST )
import Control.Monad.Except ( ExceptT(..), MonadError
, runExceptT )
import Control.Monad.Cont ( MonadCont )
import Data.Functor.Identity ( Identity, runIdentity )
import Data.Maybe ( catMaybes, listToMaybe )
import qualified Data.OrdPSQ as PSQ
-- | A search computation with cost type @c@ that runs over the 'Identity'
-- base monad, i.e. 'SearchT' without any underlying effects.
type Search c = SearchT c Identity
-- | Run a pure 'Search' and collect every solution paired with the cost
-- accumulated on the way to it.
--
-- This is 'runSearchT' specialised to the 'Identity' base monad.
runSearch :: (Ord c, Monoid c) => Search c a -> [(c, a)]
runSearch search = runIdentity (runSearchT search)
-- | Run a pure 'Search' and return only the first solution (with its
-- cost), or 'Nothing' when the search yields no solution at all.
--
-- This is 'runSearchBestT' specialised to the 'Identity' base monad.
runSearchBest :: (Ord c, Monoid c) => Search c a -> Maybe (c, a)
runSearchBest search = runIdentity (runSearchBestT search)
-- | Instruction functor of the search language; 'SearchT' is the free
-- monad transformer over it. The type parameter @a@ is the continuation.
data SearchF c a
    = Cost c c a
      -- ^ Add a cost (first field) to the candidate and re-queue it with
      -- an additional estimate (second field) folded into its priority.
    | Alt a a
      -- ^ Fork the candidate into two alternative continuations.
    | Enter a
      -- ^ Open a fresh scope and push it onto the candidate's scope stack.
    | Exit a
      -- ^ Pop the innermost scope from the candidate's scope stack.
    | Collapse a
      -- ^ Drop every queued candidate that belongs to the current
      -- innermost scope, then continue.
    | Abandon
      -- ^ Terminate this candidate without producing a result.
    deriving Functor
-- | The search monad transformer: a free monad transformer over the
-- 'SearchF' instruction set, layered on the base monad @m@.
--
-- The listed mtl-style instances are obtained by newtype deriving from
-- the underlying 'FreeT'.
newtype SearchT c m a = SearchT
    { unSearchT :: FreeT (SearchF c) m a
      -- ^ Unwrap to the underlying free monad transformer.
    }
    deriving ( Functor
             , Applicative
             , Monad
             , MonadTrans
             , MonadIO
             , MonadReader r
             , Lazy.MonadWriter w
             , Lazy.MonadState s
             , MonadError e
             , MonadCont
             )
-- | Choice in the search monad: 'empty' gives up on the current candidate
-- ('abandon'), and '<|>' forks into two alternatives ('junction').
instance (Ord c, Monoid c, Monad m) => Alternative (SearchT c m) where
    empty = abandon
    lhs <|> rhs = junction lhs rhs
-- | No methods are overridden, so 'mzero' and 'mplus' fall back to the
-- class defaults, which defer to the 'Alternative' instance.
instance (Ord c, Monoid c, Monad m) => MonadPlus (SearchT c m)
-- | Standalone-derived: 'SearchT' is a 'Lazy.MonadRWS' whenever its base
-- monad @m@ is one.
deriving instance Lazy.MonadRWS r w s m => Lazy.MonadRWS r w s (SearchT c m)
-- | A candidate: one partially explored branch of the search.
data Cand c m a = Cand
    { candCost  :: !c
      -- ^ Cost accumulated along this branch so far.
    , candScope :: ![Int]
      -- ^ Stack of scope identifiers this branch is inside of, innermost
      -- scope first.
    , candPath  :: FreeT (SearchF c) m a
      -- ^ The remaining computation of this branch.
    }
-- | Interpreter state used while running a search.
data St c m a = St
    { stNum   :: !Int
      -- ^ Most recently allocated candidate number (queue key).
    , stScope :: !Int
      -- ^ Most recently allocated scope identifier.
    , stQueue :: !(PSQ.OrdPSQ Int c (Cand c m a))
      -- ^ Pending candidates, keyed by candidate number and prioritised
      -- by a value of the cost type.
    }
-- | Run a search over the base monad @m@, returning every solution
-- together with the cost accumulated on the way to it.
--
-- Pending candidates live in a priority search queue keyed by a unique
-- candidate number. The interpreter repeatedly removes the candidate with
-- the minimal priority and advances it by one or more instructions:
--
-- * a finished candidate contributes @(cost, result)@ to the output;
-- * an abandoned candidate contributes nothing;
-- * a cost instruction adds to the candidate's cost and puts it back into
--   the queue with priority @cost <> estimate@;
-- * an alternative queues the right branch under a fresh number (same
--   priority) and keeps advancing the left branch;
-- * enter\/exit push and pop a scope identifier on the candidate;
-- * collapse removes all queued candidates carrying the current
--   innermost scope.
--
-- The scope stack is handled with total operations only: popping uses
-- @drop 1@ and collapse pattern-matches on the stack instead of calling
-- the partial 'head' \/ 'tail'.
runSearchT :: (Ord c, Monoid c, Monad m) => SearchT c m a -> m [(c, a)]
runSearchT m = catMaybes <$> evalStateT go state
  where
    -- Main loop: take the minimal-priority candidate out of the queue,
    -- advance it, and prepend its (optional) result to the remaining ones.
    go = do
        mmin <- gets (PSQ.minView . stQueue)
        case mmin of
            Nothing -> return []
            Just (num, prio, cand, q) -> do
                updateQueue $ const q
                (:) <$> step num prio cand <*> go

    -- Advance a single candidate. Yields 'Just' a solution when the
    -- candidate finishes, and 'Nothing' when it is abandoned or re-queued.
    step num prio cand@Cand{..} = do
        path' <- lift $ runFreeT candPath
        case path' of
            Pure a -> return $ Just (candCost, a)
            Free Abandon -> return Nothing
            Free (Cost c e p) -> do
                let newCost = candCost `mappend` c
                    -- The estimate only influences ordering; it is not
                    -- stored in the candidate's accumulated cost.
                    newPriority = newCost `mappend` e
                updateQueue $
                    PSQ.insert num
                               newPriority
                               cand { candCost = newCost, candPath = p }
                return Nothing
            Free (Alt lhs rhs) -> do
                num' <- nextNum
                updateQueue $ PSQ.insert num' prio cand { candPath = rhs }
                step num prio cand { candPath = lhs }
            Free (Enter p) -> do
                scope <- nextScope
                step num
                     prio
                     cand { candScope = scope : candScope, candPath = p }
            Free (Exit p) ->
                -- @drop 1@ equals @tail@ on a non-empty stack and is a
                -- no-op on an empty one.
                step num prio cand { candScope = drop 1 candScope, candPath = p }
            Free (Collapse p) -> do
                case candScope of
                    scope : _ ->
                        updateQueue $
                            PSQ.fromList
                                . filter (\(_, _, c) -> not $ hasScope scope c)
                                . PSQ.toList
                    -- The root candidate starts with scope 0 and 'seal'
                    -- pairs every exit with an enter, so this branch is
                    -- not expected to be taken; with no current scope
                    -- there is nothing to collapse.
                    [] -> return ()
                step num prio cand { candPath = p }

    -- Allocate the next candidate number.
    nextNum = do
        modify $ \s -> s { stNum = stNum s + 1 }
        gets stNum

    -- Allocate the next scope identifier.
    nextScope = do
        modify $ \s -> s { stScope = stScope s + 1 }
        gets stScope

    -- Does the candidate sit (at any depth) inside scope @s@?
    hasScope s Cand{..} = s `elem` candScope

    updateQueue f = modify $ \s -> s { stQueue = f (stQueue s) }

    -- Initial state: a single root candidate, numbered 0, in scope 0, with
    -- empty cost and priority.
    state = St 0 0 queue
    queue = PSQ.singleton 0 mempty (Cand mempty [ 0 ] (unSearchT m))
-- | Run a search over the base monad @m@ and return only its first
-- solution, if any.
--
-- A 'collapse' is sequenced after the computation, so the candidates still
-- queued in the outermost scope are discarded as soon as one candidate
-- completes.
runSearchBestT :: (Ord c, Monoid c, Monad m) => SearchT c m a -> m (Maybe (c, a))
runSearchBestT m = fmap listToMaybe . runSearchT $ m <* collapse
-- | Monads that support cost-guided search with cost type @c@. The cost
-- type is determined by the monad.
class (Ord c, Monoid c, Monad m) => MonadSearch c m | m -> c where
    -- | Charge a cost (first argument) to the current branch, together
    -- with an estimate (second argument) that only affects the order in
    -- which branches are explored.
    cost :: c -> c -> m ()

    -- | Fork the computation into two alternative branches.
    junction :: m a -> m a -> m a

    -- | Give up on the current branch without producing a result.
    abandon :: m a

    -- | Run the given computation inside its own scope, limiting the
    -- reach of any 'collapse' performed within it.
    seal :: m a -> m a

    -- | Discard the other pending branches of the current scope.
    collapse :: m ()
-- | The base instance: each operation emits the corresponding 'SearchF'
-- instruction into the underlying free monad transformer.
instance (Ord c, Monoid c, Monad m) => MonadSearch c (SearchT c m) where
    cost c e = SearchT (wrap (Cost c e (return ())))
    junction lhs rhs = SearchT (wrap (Alt (unSearchT lhs) (unSearchT rhs)))
    abandon = SearchT (wrap Abandon)
    -- Bracket the computation: enter a scope, run it, then leave the scope
    -- while handing back its result.
    seal m = SearchT . wrap . Enter $ do
        result <- unSearchT m
        wrap (Exit (return result))
    collapse = SearchT (wrap (Collapse (return ())))
-- | Lift the search operations through 'ReaderT'; both branches of a
-- 'junction' and the body of a 'seal' run in the same environment.
instance MonadSearch c m => MonadSearch c (ReaderT r m) where
    cost c e = lift (cost c e)
    junction lhs rhs = ReaderT $ \env ->
        runReaderT lhs env `junction` runReaderT rhs env
    abandon = lift abandon
    seal m = ReaderT (seal . runReaderT m)
    collapse = lift collapse
-- | Lift the search operations through the lazy 'Lazy.WriterT' by
-- unwrapping to the base monad and re-wrapping the result.
instance (Monoid w, MonadSearch c m) => MonadSearch c (Lazy.WriterT w m) where
    cost c e = lift (cost c e)
    junction lhs rhs =
        Lazy.WriterT (Lazy.runWriterT lhs `junction` Lazy.runWriterT rhs)
    abandon = lift abandon
    seal = Lazy.WriterT . seal . Lazy.runWriterT
    collapse = lift collapse
-- | Lift the search operations through the strict 'Strict.WriterT' by
-- unwrapping to the base monad and re-wrapping the result.
instance (Monoid w, MonadSearch c m) => MonadSearch c (Strict.WriterT w m) where
    cost c e = lift (cost c e)
    junction lhs rhs =
        Strict.WriterT (Strict.runWriterT lhs `junction` Strict.runWriterT rhs)
    abandon = lift abandon
    seal = Strict.WriterT . seal . Strict.runWriterT
    collapse = lift collapse
-- | Lift the search operations through the lazy 'Lazy.StateT'; both
-- branches of a 'junction' start from the same incoming state.
instance MonadSearch c m => MonadSearch c (Lazy.StateT s m) where
    cost c e = lift (cost c e)
    junction lhs rhs = Lazy.StateT $ \st ->
        Lazy.runStateT lhs st `junction` Lazy.runStateT rhs st
    abandon = lift abandon
    seal m = Lazy.StateT (seal . Lazy.runStateT m)
    collapse = lift collapse
-- | Lift the search operations through the strict 'Strict.StateT'; both
-- branches of a 'junction' start from the same incoming state.
instance MonadSearch c m => MonadSearch c (Strict.StateT s m) where
    cost c e = lift (cost c e)
    junction lhs rhs = Strict.StateT $ \st ->
        Strict.runStateT lhs st `junction` Strict.runStateT rhs st
    abandon = lift abandon
    seal m = Strict.StateT (seal . Strict.runStateT m)
    collapse = lift collapse
-- | Lift the search operations through the lazy 'Lazy.RWST'; both branches
-- of a 'junction' receive the same environment and incoming state.
instance (Monoid w, MonadSearch c m) => MonadSearch c (Lazy.RWST r w s m) where
    cost c e = lift (cost c e)
    junction lhs rhs = Lazy.RWST $ \env st ->
        Lazy.runRWST lhs env st `junction` Lazy.runRWST rhs env st
    abandon = lift abandon
    seal m = Lazy.RWST $ \env st -> seal $ Lazy.runRWST m env st
    collapse = lift collapse
-- | Lift the search operations through the strict 'Strict.RWST'; both
-- branches of a 'junction' receive the same environment and incoming state.
instance (Monoid w, MonadSearch c m) => MonadSearch c (Strict.RWST r w s m) where
    cost c e = lift (cost c e)
    junction lhs rhs = Strict.RWST $ \env st ->
        Strict.runRWST lhs env st `junction` Strict.runRWST rhs env st
    abandon = lift abandon
    seal m = Strict.RWST $ \env st -> seal $ Strict.runRWST m env st
    collapse = lift collapse
-- | Lift the search operations through 'ExceptT' by unwrapping to the base
-- monad and re-wrapping the result.
instance MonadSearch c m => MonadSearch c (ExceptT e m) where
    cost c e = lift (cost c e)
    junction lhs rhs = ExceptT (runExceptT lhs `junction` runExceptT rhs)
    abandon = lift abandon
    seal = ExceptT . seal . runExceptT
    collapse = lift collapse
-- | Charge a cost to the current branch without supplying an estimate;
-- the estimate passed to 'cost' is 'mempty'.
cost' :: MonadSearch c m => c -> m ()
cost' c = c `cost` mempty
-- | Run the given computation in its own scope and 'collapse' that scope
-- once it has produced a result, keeping the result.
winner :: MonadSearch c m => m a -> m a
winner m = seal (m <* collapse)