{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}

-- |
-- Module      : Control.Monad.Bayes.Sampler.Strict
-- Description : Pseudo-random sampling monads
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- 'SamplerIO' and 'SamplerST', both specialisations of 'SamplerT', are instances of
-- 'MonadDistribution'. Apply a 'MonadFactor' transformer to obtain a 'MonadMeasure'
-- that can execute probabilistic models.
module Control.Monad.Bayes.Sampler.Strict
  ( SamplerT (..),
    SamplerIO,
    SamplerST,
    sampleIO,
    sampleIOfixed,
    sampleWith,
    sampleSTfixed,
    sampleMean,
    sampler,
  )
where

import Control.Foldl qualified as F hiding (random)
import Control.Monad.Bayes.Class
  ( MonadDistribution
      ( bernoulli,
        beta,
        categorical,
        gamma,
        geometric,
        normal,
        random,
        uniform
      ),
  )
import Control.Monad.Reader (MonadIO, ReaderT (..))
import Control.Monad.ST (ST)
import Control.Monad.Trans (MonadTrans)
import Numeric.Log (Log (ln))
import System.Random.MWC.Distributions qualified as MWC
import System.Random.Stateful (IOGenM (..), STGenM, StatefulGen, StdGen, initStdGen, mkStdGen, newIOGenM, newSTGenM, uniformDouble01M, uniformRM)

-- | The sampling interpretation of a probabilistic program.
--
-- A @SamplerT g m a@ is a computation that reads a stateful random-number
-- generator @g@ and runs in the base monad @m@ (typically 'IO' or 'ST'),
-- producing a sampled value of type @a@. It is a newtype over 'ReaderT', and
-- all of its structure ('Functor' through 'MonadTrans') is inherited from
-- 'ReaderT' via newtype deriving.
newtype SamplerT g m a = SamplerT {runSamplerT :: ReaderT g m a}
  deriving newtype (Functor, Applicative, Monad, MonadIO, MonadTrans)

-- | 'SamplerT' specialised to 'IO', drawing randomness from a 'StdGen' held
-- in an 'IOGenM'. This is the sampler type consumed by 'sampleIO', 'sampler'
-- and 'sampleIOfixed'.
--
-- The synonym deliberately omits the result-type parameter so that it can be
-- used partially applied (e.g. as @SamplerIO@ on its own).
type SamplerIO = SamplerT (IOGenM StdGen) IO

-- | 'SamplerT' specialised to the 'ST' monad with state thread @s@, drawing
-- randomness from a 'StdGen' held in an 'STGenM'. This is the sampler type
-- consumed by 'sampleSTfixed'.
--
-- As with 'SamplerIO', the result-type parameter is left off so the synonym
-- can be partially applied.
type SamplerST s = SamplerT (STGenM StdGen s) (ST s)

-- | Sample directly from the generator held in the reader environment.
--
-- Every method has the same shape, @SamplerT $ ReaderT $ draw@, where @draw@
-- is a @g -> m a@ action. The draw actions come from "System.Random.Stateful"
-- ('uniformDouble01M', 'uniformRM') or "System.Random.MWC.Distributions"
-- (imported as @MWC@).
instance (StatefulGen g m) => MonadDistribution (SamplerT g m) where
  -- Standard uniform draw on the unit interval.
  random = SamplerT (ReaderT uniformDouble01M)

  -- Uniform draw over the range given by the two bounds.
  uniform a b = SamplerT $ ReaderT $ uniformRM (a, b)

  -- Parameters are passed through to MWC.normal unchanged (mean, std. dev.).
  normal m s = SamplerT $ ReaderT $ MWC.normal m s

  -- Parameters are passed through to MWC.gamma unchanged (shape, scale).
  gamma shape scale = SamplerT $ ReaderT $ MWC.gamma shape scale

  beta a b = SamplerT $ ReaderT $ MWC.beta a b

  bernoulli p = SamplerT $ ReaderT $ MWC.bernoulli p

  -- Draws an index into the weight vector @ps@.
  categorical ps = SamplerT $ ReaderT $ MWC.categorical ps

  -- Uses the @geometric0@ variant, whose support starts at 0 (it counts
  -- failures before the first success).
  geometric p = SamplerT $ ReaderT $ MWC.geometric0 p

-- | Sample with a random number generator of your choice, e.g. the one
-- from "System.Random".
--
-- The generator is supplied as the 'ReaderT' environment of the sampler.
-- Every primitive draw in the program reads from (and, for a stateful
-- generator, advances) that same generator.
--
-- >>> import Control.Monad.Bayes.Class
-- >>> import System.Random.Stateful hiding (random)
-- >>> newIOGenM (mkStdGen 1729) >>= sampleWith random
-- 4.690861245089605e-2
sampleWith :: SamplerT g m a -> g -> m a
sampleWith (SamplerT m) = runReaderT m

-- | Seed a fresh 'StdGen' with 'initStdGen' (system entropy), wrap it in an
-- 'IOGenM', and run the sampler with it.
--
-- Because the seed is not fixed, results differ between runs. Use
-- 'sampleIOfixed' for reproducible output. 'sampler' is an alias for
-- 'sampleIO'.
sampleIO, sampler :: SamplerIO a -> IO a
sampleIO x = initStdGen >>= newIOGenM >>= sampleWith x
sampler = sampleIO

-- | Run the sampler in 'IO' with a fixed random seed.
--
-- The generator is always @mkStdGen 1729@, so running the same program again
-- replays the same sequence of draws.
sampleIOfixed :: SamplerIO a -> IO a
sampleIOfixed x = newIOGenM (mkStdGen seed) >>= sampleWith x
  where
    -- Kept in sync with the seed used by 'sampleSTfixed'.
    seed = 1729

-- | Run the sampler in 'ST' with a fixed random seed.
--
-- The generator is always @mkStdGen 1729@, so running the same program again
-- replays the same sequence of draws.
sampleSTfixed :: SamplerST s b -> ST s b
sampleSTfixed x = newSTGenM (mkStdGen seed) >>= sampleWith x
  where
    -- Kept in sync with the seed used by 'sampleIOfixed'.
    seed = 1729

-- | Self-normalised weighted mean of a list of @(value, weight)@ pairs.
--
-- Each weight is a 'Log' 'Double', stored by its logarithm. It is taken out of
-- the log domain with @exp . ln@, and the result is
-- @sum (x_i * w_i) / sum w_i@. Both sums are computed in a single pass by
-- combining two "Control.Foldl" folds applicatively.
--
-- NOTE(review): an empty list, or weights that all underflow to 0 after
-- leaving the log domain, produce @0 / 0 = NaN@.
sampleMean :: [(Double, Log Double)] -> Double
sampleMean samples =
  let -- Weight as a plain Double (exp of the stored log); equivalent to
      -- @ln (exp y)@ under log-domain's Floating instance for Log.
      plainWeight y = exp (ln y)
      z = F.premap (plainWeight . snd) F.sum
      w = F.premap (\(x, y) -> x * plainWeight y) F.sum
      s = (/) <$> w <*> z
   in F.fold s samples