{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE UndecidableInstances #-}
module Control.Monad.Random.Class (
MonadRandom(..),
MonadSplit(..),
MonadInterleave(..),
fromList,
fromListMay,
uniform,
uniformMay,
weighted,
weightedMay
) where
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Cont
import Control.Monad.Trans.Error
import Control.Monad.Trans.Except
import Control.Monad.Trans.Identity
import Control.Monad.Trans.List
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.Reader
import qualified Control.Monad.Trans.RWS.Lazy as LazyRWS
import qualified Control.Monad.Trans.RWS.Strict as StrictRWS
import qualified Control.Monad.Trans.State.Lazy as LazyState
import qualified Control.Monad.Trans.State.Strict as StrictState
import qualified Control.Monad.Trans.Writer.Lazy as LazyWriter
import qualified Control.Monad.Trans.Writer.Strict as StrictWriter
import System.Random
import qualified Data.Foldable as F
#if MIN_VERSION_base(4,8,0)
#else
import Data.Monoid (Monoid)
#endif
-- | Monads that can produce random values of any 'Random' type.
--
-- The @R@ variants take a @(lo, hi)@ range; the 'IO' instance passes it
-- straight to 'randomRIO' / 'randomRs'. The plural variants return a list
-- of values instead of a single one.
class Monad m => MonadRandom m where
  -- | A single random value within the given range.
  getRandomR :: Random a => (a, a) -> m a
  -- | A single random value over the type's default range.
  getRandom :: Random a => m a
  -- | A list of random values within the given range.
  getRandomRs :: Random a => (a, a) -> m [a]
  -- | A list of random values over the type's default range.
  getRandoms :: Random a => m [a]

-- | Single values come from 'randomIO' / 'randomRIO'; list results are
-- generated lazily from the generator returned by 'newStdGen'.
instance MonadRandom IO where
  getRandom = randomIO
  getRandomR range = randomRIO range
  getRandoms = fmap randoms newStdGen
  getRandomRs range = fmap (randomRs range) newStdGen
-- Every standard transformer takes its randomness from the monad it wraps:
-- each method runs the corresponding action of the underlying monad and
-- lifts the result with 'lift'.

instance (MonadRandom m) => MonadRandom (ContT r m) where
  getRandom = lift getRandom
  getRandomR range = lift (getRandomR range)
  getRandoms = lift getRandoms
  getRandomRs range = lift (getRandomRs range)

instance (Error e, MonadRandom m) => MonadRandom (ErrorT e m) where
  getRandom = lift getRandom
  getRandomR range = lift (getRandomR range)
  getRandoms = lift getRandoms
  getRandomRs range = lift (getRandomRs range)

instance (MonadRandom m) => MonadRandom (ExceptT e m) where
  getRandom = lift getRandom
  getRandomR range = lift (getRandomR range)
  getRandoms = lift getRandoms
  getRandomRs range = lift (getRandomRs range)

instance (MonadRandom m) => MonadRandom (IdentityT m) where
  getRandom = lift getRandom
  getRandomR range = lift (getRandomR range)
  getRandoms = lift getRandoms
  getRandomRs range = lift (getRandomRs range)

instance (MonadRandom m) => MonadRandom (ListT m) where
  getRandom = lift getRandom
  getRandomR range = lift (getRandomR range)
  getRandoms = lift getRandoms
  getRandomRs range = lift (getRandomRs range)

instance (MonadRandom m) => MonadRandom (MaybeT m) where
  getRandom = lift getRandom
  getRandomR range = lift (getRandomR range)
  getRandoms = lift getRandoms
  getRandomRs range = lift (getRandomRs range)

instance (Monoid w, MonadRandom m) => MonadRandom (LazyRWS.RWST r w s m) where
  getRandom = lift getRandom
  getRandomR range = lift (getRandomR range)
  getRandoms = lift getRandoms
  getRandomRs range = lift (getRandomRs range)

instance (Monoid w, MonadRandom m) => MonadRandom (StrictRWS.RWST r w s m) where
  getRandom = lift getRandom
  getRandomR range = lift (getRandomR range)
  getRandoms = lift getRandoms
  getRandomRs range = lift (getRandomRs range)

instance (MonadRandom m) => MonadRandom (ReaderT r m) where
  getRandom = lift getRandom
  getRandomR range = lift (getRandomR range)
  getRandoms = lift getRandoms
  getRandomRs range = lift (getRandomRs range)

instance (MonadRandom m) => MonadRandom (LazyState.StateT s m) where
  getRandom = lift getRandom
  getRandomR range = lift (getRandomR range)
  getRandoms = lift getRandoms
  getRandomRs range = lift (getRandomRs range)

instance (MonadRandom m) => MonadRandom (StrictState.StateT s m) where
  getRandom = lift getRandom
  getRandomR range = lift (getRandomR range)
  getRandoms = lift getRandoms
  getRandomRs range = lift (getRandomRs range)

instance (MonadRandom m, Monoid w) => MonadRandom (LazyWriter.WriterT w m) where
  getRandom = lift getRandom
  getRandomR range = lift (getRandomR range)
  getRandoms = lift getRandoms
  getRandomRs range = lift (getRandomRs range)

instance (MonadRandom m, Monoid w) => MonadRandom (StrictWriter.WriterT w m) where
  getRandom = lift getRandom
  getRandomR range = lift (getRandomR range)
  getRandoms = lift getRandoms
  getRandomRs range = lift (getRandomRs range)
-- | Monads that can hand out a new random generator of type @g@. The
-- functional dependency @m -> g@ means each monad determines a single
-- generator type.
class Monad m => MonadSplit g m | m -> g where
  -- | Obtain a fresh generator from the monad's source of randomness.
  getSplit :: m g

-- | 'IO' hands out the generator returned by 'newStdGen'.
instance MonadSplit StdGen IO where
  getSplit = newStdGen

-- Shared implementation for every transformer instance below: run the
-- wrapped monad's 'getSplit' and lift its result.
liftGetSplit :: (MonadTrans t, MonadSplit g m) => t m g
liftGetSplit = lift getSplit

instance (MonadSplit g m) => MonadSplit g (ContT r m) where
  getSplit = liftGetSplit

instance (Error e, MonadSplit g m) => MonadSplit g (ErrorT e m) where
  getSplit = liftGetSplit

instance (MonadSplit g m) => MonadSplit g (ExceptT e m) where
  getSplit = liftGetSplit

instance (MonadSplit g m) => MonadSplit g (IdentityT m) where
  getSplit = liftGetSplit

instance (MonadSplit g m) => MonadSplit g (ListT m) where
  getSplit = liftGetSplit

instance (MonadSplit g m) => MonadSplit g (MaybeT m) where
  getSplit = liftGetSplit

instance (Monoid w, MonadSplit g m) => MonadSplit g (LazyRWS.RWST r w s m) where
  getSplit = liftGetSplit

instance (Monoid w, MonadSplit g m) => MonadSplit g (StrictRWS.RWST r w s m) where
  getSplit = liftGetSplit

instance (MonadSplit g m) => MonadSplit g (ReaderT r m) where
  getSplit = liftGetSplit

instance (MonadSplit g m) => MonadSplit g (LazyState.StateT s m) where
  getSplit = liftGetSplit

instance (MonadSplit g m) => MonadSplit g (StrictState.StateT s m) where
  getSplit = liftGetSplit

instance (Monoid w, MonadSplit g m) => MonadSplit g (LazyWriter.WriterT w m) where
  getSplit = liftGetSplit

instance (Monoid w, MonadSplit g m) => MonadSplit g (StrictWriter.WriterT w m) where
  getSplit = liftGetSplit
-- | Random monads that support 'interleave'.
--
-- This module defines no base instance. The transformer instances here only
-- push 'interleave' down into the wrapped monad via each transformer's
-- @map...T@ function. (Presumably a base instance runs the action on a
-- split-off generator — confirm against the module that defines it.)
class MonadRandom m => MonadInterleave m where
  interleave :: m a -> m a

instance (MonadInterleave m) => MonadInterleave (ContT r m) where
  interleave action = mapContT interleave action

instance (Error e, MonadInterleave m) => MonadInterleave (ErrorT e m) where
  interleave action = mapErrorT interleave action

instance (MonadInterleave m) => MonadInterleave (ExceptT e m) where
  interleave action = mapExceptT interleave action

instance (MonadInterleave m) => MonadInterleave (IdentityT m) where
  interleave action = mapIdentityT interleave action

instance (MonadInterleave m) => MonadInterleave (ListT m) where
  interleave action = mapListT interleave action

instance (MonadInterleave m) => MonadInterleave (MaybeT m) where
  interleave action = mapMaybeT interleave action

instance (Monoid w, MonadInterleave m) => MonadInterleave (LazyRWS.RWST r w s m) where
  interleave action = LazyRWS.mapRWST interleave action

instance (Monoid w, MonadInterleave m) => MonadInterleave (StrictRWS.RWST r w s m) where
  interleave action = StrictRWS.mapRWST interleave action

instance (MonadInterleave m) => MonadInterleave (ReaderT r m) where
  interleave action = mapReaderT interleave action

instance (MonadInterleave m) => MonadInterleave (LazyState.StateT s m) where
  interleave action = LazyState.mapStateT interleave action

instance (MonadInterleave m) => MonadInterleave (StrictState.StateT s m) where
  interleave action = StrictState.mapStateT interleave action

instance (Monoid w, MonadInterleave m) => MonadInterleave (LazyWriter.WriterT w m) where
  interleave action = LazyWriter.mapWriterT interleave action

instance (Monoid w, MonadInterleave m) => MonadInterleave (StrictWriter.WriterT w m) where
  interleave action = StrictWriter.mapWriterT interleave action
-- | Sample a value from a collection of @(value, weight)@ pairs, each value
-- chosen with probability proportional to its weight.
--
-- Calls 'error' whenever 'weightedMay' yields 'Nothing' for the collection.
weighted :: (F.Foldable t, MonadRandom m) => t (a, Rational) -> m a
weighted t = weightedMay t >>= maybe
  (error "Control.Monad.Random.Class.weighted: empty collection, or total weight = 0")
  return
-- | Total variant of 'weighted': flattens the collection to a list and
-- delegates to 'fromListMay', returning whatever 'Maybe' it produces.
weightedMay :: (F.Foldable t, MonadRandom m) => t (a, Rational) -> m (Maybe a)
weightedMay t = fromListMay (F.toList t)
-- | Sample a value from a list of @(value, weight)@ pairs.
--
-- Calls 'error' whenever 'fromListMay' yields 'Nothing' for the list.
fromList :: (MonadRandom m) => [(a, Rational)] -> m a
fromList ws = fromListMay ws >>= orFail
  where
    -- Unwrap the sampled value, or abort with the module's error message.
    orFail (Just a) = return a
    orFail Nothing =
      error "Control.Monad.Random.Class.fromList: empty list, or total weight = 0"
-- | Sample a value from a list of @(value, weight)@ pairs. Each value is
-- chosen with probability proportional to its weight.
--
-- Returns 'Nothing' when the list is empty or its total weight is not
-- positive. Weights are presumably meant to be non-negative. Negative
-- weights cannot make this function crash, but they do not describe a
-- meaningful distribution.
--
-- The total and the cumulative weights stay exact 'Rational's. Only a
-- fraction @u@ in @[0, 1]@ is drawn as a 'Double', and it is then scaled by
-- the exact total. As long as the draw lies in @[0, 1]@, the target cannot
-- exceed the last cumulative weight, so a matching entry always exists.
--
-- Rounding the total itself to 'Double' would cause several failures:
-- * a total that rounds up lets the target overshoot, crashing in @head []@;
-- * tiny totals collapse to 0 and yield 'Nothing';
-- * huge totals become Infinity.
fromListMay :: (MonadRandom m) => [(a, Rational)] -> m (Maybe a)
fromListMay xs
  | total <= 0 = return Nothing
  | otherwise = do
      u <- getRandomR (0, 1 :: Double)
      let target = toRational u * total
      -- First entry whose cumulative weight reaches the target. 'F.find'
      -- keeps this total instead of relying on the partial 'head'.
      return (fmap fst (F.find ((>= target) . snd) cumulative))
  where
    -- A zero-weight entry has no probability mass. Dropping such entries
    -- also stops a draw of exactly 0 from picking a leading zero-weight entry.
    weightedItems = filter ((/= 0) . snd) xs
    total = sum (map snd weightedItems)
    -- Running totals: each value is paired with the sum of the weights up to
    -- and including its own.
    cumulative = zip (map fst weightedItems) (scanl1 (+) (map snd weightedItems))
-- | Pick an element of a non-empty collection, every element equally likely.
--
-- Calls 'error' whenever 'uniformMay' yields 'Nothing' (an empty collection).
uniform :: (F.Foldable t, MonadRandom m) => t a -> m a
uniform t = do
  picked <- uniformMay t
  maybe (error "Control.Monad.Random.Class.uniform: empty collection") return picked
-- | Total variant of 'uniform': gives every element the same weight of 1
-- and samples with 'fromListMay', which yields 'Nothing' when the
-- collection is empty.
uniformMay :: (F.Foldable t, MonadRandom m) => t a -> m (Maybe a)
uniformMay t = fromListMay [ (x, 1) | x <- F.toList t ]