{-
 -      ``Data/Random/RVar''
 -}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- |Random variables.  An 'RVar' is a sampleable random variable.  Because
-- probability distributions form a monad, they are quite easy to work with
-- in the standard Haskell monadic styles.  For examples, see the source for
-- any of the 'Distribution' instances - they all are defined in terms of
-- 'RVar's.
{-# LANGUAGE FlexibleContexts #-}

module Data.RVar
    ( RVar
    , runRVar, sampleReaderRVar, sampleStateRVar
    , pureRVar

    , RVarT
    , runRVarT, sampleReaderRVarT, sampleStateRVarT
    , runRVarTWith, sampleReaderRVarTWith, sampleStateRVarTWith

    , RGen(..)
    , uniformRVarT
    , uniformRangeRVarT

    , Prim(..)
    ) where


import qualified Control.Monad.IO.Class as T
import Control.Monad.Prompt (MonadPrompt(..), PromptT, runPromptT)
import Control.Monad.Reader as MTL
import Control.Monad.State as MTL
import qualified Control.Monad.Trans.Class as T
import qualified Data.Functor.Identity as T
import Data.RVar.Prim
import System.Random.Stateful
import Control.Monad (ap, liftM)

-- |An opaque type modeling a \"random variable\" - a value
-- which depends on the outcome of some random event.  'RVar's
-- can be conveniently defined by an imperative-looking style:
--
-- > normalPair =  do
-- >     u <- stdUniform
-- >     t <- stdUniform
-- >     let r = sqrt (-2 * log u)
-- >         theta = (2 * pi) * t
-- >
-- >         x = r * cos theta
-- >         y = r * sin theta
-- >     return (x,y)
--
-- OR by a more applicative style:
--
-- > logNormal = exp <$> stdNormal
--
-- Once defined (in any style), there are several ways to sample 'RVar's:
--
-- * Using an immutable pseudo-random number generator that has an instance for `RandomGen` with
--   `StateT` monad:
--
-- >>> import qualified Data.Random as Fu (uniform)
-- >>> import System.Random (mkStdGen)
-- >>> import Control.Monad.State (runState)
-- >>> runState (sampleStateRVar (Fu.uniform 1 (100 :: Integer))) (mkStdGen 2021)
-- (79,StdGen {unStdGen = SMGen 4687568268719557181 4805600293067301895})
--
-- * Using a mutable pseudo-random number generator that has an instance for `StatefulGen` with
--   `ReaderT` monad.
--
-- >>> import qualified Data.Random as Fu (uniform)
-- >>> import System.Random.MWC (initialize)
-- >>> import Control.Monad.Reader (runReaderT)
-- >>> import qualified Data.Vector.Storable as VS
-- >>> initialize (VS.singleton 2021) >>= runReaderT (sampleReaderRVar (Fu.uniform 1 (100 :: Integer)))
-- 8
--
type RVar = RVarT T.Identity

-- | Sample a random variable using a pure 'RandomGen' generator as the
-- source of entropy.
--
-- The generator is handed to 'runStateGen', which runs 'runRVar' on the
-- variable; the result pairs the sampled value with the generator that
-- 'runStateGen' returns.
pureRVar :: RandomGen g => RVar a -> g -> (a, g)
pureRVar rvar g = runStateGen g (runRVar rvar)

-- |\"Run\" an 'RVar' - samples the random variable from the provided
-- source of entropy.
--
-- Defined as 'runRVarTWith', where the 'T.Identity' base monad is lifted
-- into the sampling monad with @return . runIdentity@.
runRVar :: StatefulGen g m => RVar a -> g -> m a
runRVar = runRVarTWith (return . T.runIdentity)

-- |Sample an 'RVar' in a monad whose 'MonadReader' environment is the
-- generator.
--
-- Defined as 'sampleReaderRVarTWith', where the 'T.Identity' base monad is
-- lifted into the sampling monad with @return . runIdentity@.
sampleReaderRVar :: (StatefulGen g m, MonadReader g m) => RVar a -> m a
sampleReaderRVar = sampleReaderRVarTWith (return . T.runIdentity)

-- |Sample an 'RVar' in a monad whose 'MonadState' state is a pure
-- 'RandomGen' generator.
--
-- Defined as 'sampleStateRVarTWith', where the 'T.Identity' base monad is
-- lifted into the sampling monad with @return . runIdentity@.
sampleStateRVar :: (RandomGen g, MonadState g m) => RVar a -> m a
sampleStateRVar = sampleStateRVarTWith (return . T.runIdentity)

-- |A random variable with access to operations in an underlying monad.  Useful
-- examples include any form of state for implementing random processes with hysteresis,
-- or writer monads for implementing tracing of complicated algorithms.
--
-- For example, a simple random walk can be implemented as an 'RVarT' 'IO' value:
--
-- > rwalkIO :: IO (RVarT IO Double)
-- > rwalkIO = do
-- >     lastVal <- newIORef 0
-- >
-- >     let x = do
-- >             prev    <- lift (readIORef lastVal)
-- >             change  <- rvarT StdNormal
-- >
-- >             let new = prev + change
-- >             lift (writeIORef lastVal new)
-- >             return new
-- >
-- >     return x
--
-- To run the random walk it must first be initialized, after which it can be
-- sampled with 'runRVarT' (here @gen@ stands for any generator that has a
-- 'StatefulGen' instance for 'IO'):
--
-- > do
-- >     rw <- rwalkIO
-- >     x <- runRVarT rw gen
-- >     y <- runRVarT rw gen
-- >     ...
--
-- The same random-walk process as above can be implemented using MTL types
-- as follows (using @import Control.Monad.Trans as MTL@):
--
-- > rwalkState :: RVarT (State Double) Double
-- > rwalkState = do
-- >     prev <- MTL.lift get
-- >     change  <- rvarT StdNormal
-- >
-- >     let new = prev + change
-- >     MTL.lift (put new)
-- >     return new
--
-- Invocation is straightforward (although a bit noisy) if you're used to MTL:
--
-- > rwalk :: Int -> Double -> StdGen -> ([Double], StdGen)
-- > rwalk count start gen =
-- >     flip evalState start .
-- >         flip runStateT gen .
-- >             sampleStateRVarTWith MTL.lift $
-- >                 replicateM count rwalkState
newtype RVarT m a = RVarT { unRVarT :: PromptT Prim m a }

-- |\"Run\" an 'RVarT' whose base monad is already the sampling monad, drawing
-- entropy from the supplied generator.  This is 'runRVarTWith' with 'id' as
-- the lifting function.
runRVarT :: StatefulGen g m => RVarT m a -> g -> m a
runRVarT = runRVarTWith id


-- |Sample an 'RVarT' in a monad that keeps a pure 'RandomGen' generator as
-- its 'MonadState' state, by running it with 'runRVarT' against the
-- 'StateGenM' token.
sampleStateRVarT :: (RandomGen g, MonadState g m) => RVarT m a -> m a
sampleStateRVarT rvar = runRVarT rvar StateGenM

-- |Sample an 'RVarT' using the generator fetched from the 'MonadReader'
-- environment with 'ask', which is then passed to 'runRVarT'.
sampleReaderRVarT :: (StatefulGen g m, MonadReader g m) => RVarT m a -> m a
sampleReaderRVarT rvar = ask >>= runRVarT rvar

-- | \"Runs\" an 'RVarT', sampling the random variable it defines.
--
-- The first argument lifts the base monad into the sampling monad.  This
-- operation must obey the \"monad transformer\" laws:
--
-- > lift . return = return
-- > lift (x >>= f) = (lift x) >>= (lift . f)
--
-- One example of a useful non-standard lifting would be one that takes
-- @State s@ to another monad with a different state representation (such as
-- @IO@ with the state mapped to an @IORef@):
--
-- > embedState :: (Monad m) => m s -> (s -> m ()) -> State s a -> m a
-- > embedState get put = \m -> do
-- >     s <- get
-- >     (res,s) <- return (runState m s)
-- >     put s
-- >     return res
--
-- The ability to lift is very important - without it, every 'RVar' would have
-- to either be given access to the full capability of the monad in which it
-- will eventually be sampled (which, incidentally, would also have to be
-- monomorphic so you couldn't sample one 'RVar' in more than one monad)
-- or functions manipulating 'RVar's would have to use higher-ranked
-- types to enforce the same kind of isolation and polymorphism.
{-# INLINE runRVarTWith #-}
runRVarTWith :: forall m n g a. StatefulGen g m => (forall t. n t -> m t) -> RVarT n a -> g -> m a
runRVarTWith liftN (RVarT m) gen = runPromptT return bindP bindN m
    where
        -- A primitive request is answered by drawing from @gen@ and feeding
        -- the result to the continuation.
        bindP :: forall t. (Prim t -> (t -> m a) -> m a)
        bindP prim cont = uniformPrimM prim gen >>= cont

        -- An action of the base monad @n@ is moved into @m@ with @liftN@
        -- before the continuation runs.
        bindN :: forall t. n t -> (t -> m a) -> m a
        bindN nExp cont = liftN nExp >>= cont

-- |Answer a single 'Prim' request by calling the matching 'StatefulGen'
-- method on the generator @g@.
{-# INLINE uniformPrimM #-}
uniformPrimM :: StatefulGen g m => Prim t -> g -> m t
uniformPrimM prim g =
    case prim of
        PrimWord8             -> uniformWord8 g
        PrimWord16            -> uniformWord16 g
        PrimWord32            -> uniformWord32 g
        PrimWord64            -> uniformWord64 g
        PrimShortByteString n -> uniformShortByteString n g


-- |Like 'runRVarTWith', except that the generator is not passed explicitly:
-- each primitive request fetches it from the 'MonadReader' environment with
-- 'ask'.  The first argument lifts the base monad into the sampling monad.
{-# INLINE sampleReaderRVarTWith #-}
sampleReaderRVarTWith ::
       forall m n a g. (StatefulGen g m, MonadReader g m)
    => (forall t. n t -> m t)
    -> RVarT n a
    -> m a
sampleReaderRVarTWith liftN (RVarT m) = runPromptT return bindP bindN m
    where
        -- Look the generator up in the environment, then draw from it.
        bindP :: forall t. (Prim t -> (t -> m a) -> m a)
        bindP prim cont = ask >>= uniformPrimM prim >>= cont

        -- An action of the base monad @n@ is moved into @m@ with @liftN@.
        bindN :: forall t. n t -> (t -> m a) -> m a
        bindN nExp cont = liftN nExp >>= cont


-- |Like 'runRVarTWith', except that the generator is not passed explicitly:
-- each primitive request is served through the 'StateGenM' token, for a
-- monad whose 'MonadState' state is a pure 'RandomGen' generator.  The first
-- argument lifts the base monad into the sampling monad.
{-# INLINE sampleStateRVarTWith #-}
sampleStateRVarTWith ::
       forall m n a g. (RandomGen g, MonadState g m)
    => (forall t. n t -> m t)
    -> RVarT n a
    -> m a
sampleStateRVarTWith liftN (RVarT m) = runPromptT return bindP bindN m
    where
        -- Draw the primitive via the 'StateGenM' token.
        bindP :: forall t. (Prim t -> (t -> m a) -> m a)
        bindP prim cont = uniformPrimM prim StateGenM >>= cont

        -- An action of the base monad @n@ is moved into @m@ with @liftN@.
        bindN :: forall t. n t -> (t -> m a) -> m a
        bindN nExp cont = liftN nExp >>= cont

-- 'fmap' is obtained from the 'Monad' instance via 'liftM'.
instance Functor (RVarT n) where
    fmap = liftM

instance Monad (RVarT n) where
    -- The intermediate result is forced to WHNF with 'seq' before it is
    -- handed to the continuation.
    (RVarT m) >>= k = RVarT (m >>= \x -> x `seq` unRVarT (k x))

instance Applicative (RVarT n) where
    -- '$!' forces the value to WHNF before it is wrapped.
    pure x = RVarT (pure $! x)
    -- '<*>' is obtained from the 'Monad' instance via 'ap'.
    (<*>)  = ap

-- A 'Prim' request is issued in the underlying 'PromptT' and rewrapped.
instance MonadPrompt Prim (RVarT n) where
    prompt = RVarT . prompt

-- Base-monad actions are lifted into the underlying 'PromptT' and rewrapped.
instance T.MonadTrans RVarT where
    lift m = RVarT (MTL.lift m)

-- 'IO' actions are first lifted into the base monad @m@, then into 'RVarT'.
instance T.MonadIO m => T.MonadIO (RVarT m) where
    liftIO = T.lift . T.liftIO

#ifndef MTL2

-- These instances are compiled only when the @MTL2@ macro is not defined.

-- Base-monad actions are lifted into the underlying 'PromptT' and rewrapped.
instance MTL.MonadTrans RVarT where
    lift = RVarT . MTL.lift

-- 'IO' actions are first lifted into the base monad, then into 'RVarT'.
instance MTL.MonadIO m => MTL.MonadIO (RVarT m) where
    liftIO action = MTL.lift (MTL.liftIO action)

#endif

-- |A single-constructor generator token.  It is the generator type of the
-- 'StatefulGen' instance for @'RVarT' m@ defined in this module.
data RGen = RGen

-- Each method matches the 'RGen' token and issues the corresponding 'Prim'
-- request through 'prompt', so the request is answered by whichever
-- generator the 'RVarT' is eventually run with.
instance StatefulGen RGen (RVarT m) where
    uniformWord8 RGen = RVarT $ prompt PrimWord8
    {-# INLINE uniformWord8 #-}
    uniformWord16 RGen = RVarT $ prompt PrimWord16
    {-# INLINE uniformWord16 #-}
    uniformWord32 RGen = RVarT $ prompt PrimWord32
    {-# INLINE uniformWord32 #-}
    uniformWord64 RGen = RVarT $ prompt PrimWord64
    {-# INLINE uniformWord64 #-}
    uniformShortByteString n RGen = RVarT $ prompt (PrimShortByteString n)
    {-# INLINE uniformShortByteString #-}


-- |A random variable producing a value of any 'Uniform' type; defined as
-- 'uniformM' applied to the 'RGen' token.
uniformRVarT :: Uniform a => RVarT m a
uniformRVarT = uniformM RGen
{-# INLINE uniformRVarT #-}

-- |A random variable producing a value of any 'UniformRange' type from the
-- given pair of bounds; defined as 'uniformRM' applied to the bounds and the
-- 'RGen' token.
uniformRangeRVarT :: UniformRange a => (a, a) -> RVarT m a
uniformRangeRVarT r = uniformRM r RGen
{-# INLINE uniformRangeRVarT #-}