{-# language GeneralizedNewtypeDeriving #-}
{-# options_ghc -Wno-unused-imports #-}
{-|
Random samplers for a few common distributions, with an interface similar to that of @mwc-probability@.

= Usage

Compose your random samplers out of simpler ones using the Applicative and Monad interfaces. For example, this is how you would declare and sample a binary mixture of Gaussian random variables:

@
import Control.Monad (replicateM)
import System.Random.SplitMix.Distributions (Gen, sample, bernoulli, normal)

process :: `Gen` Double
process = do
  coin <- `bernoulli` 0.7
  if coin
    then
      `normal` 0 2
    else
      normal 3 1

dataset :: [Double]
dataset = `sample` 1234 $ replicateM 20 process
@

and sample your data in a pure (`sample`) or monadic (`sampleT`) setting.

== Implementation details

The library is built on top of @splitmix@, so the caveats on safety and performance that apply there are relevant here as well.


-}
module System.Random.SplitMix.Distributions (
  -- * Distributions
  -- ** Continuous
  stdUniform, uniformR,
  exponential,
  stdNormal, normal,
  beta,
  gamma,
  -- ** Discrete
  bernoulli,
  -- * PRNG
  -- ** Pure
  Gen, sample,
  -- ** Monadic
  GenT, sampleT,
  withGen
                                            ) where

import Control.Monad (replicateM)
import Control.Monad.IO.Class (MonadIO(..))
import Data.Functor.Identity (Identity(..))
import GHC.Word (Word64)
import Numeric (log1p)

-- erf
import Data.Number.Erf (InvErf(..))
-- mtl
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.State (MonadState(..), modify)
-- splitmix
import System.Random.SplitMix (SMGen, mkSMGen, splitSMGen, nextInt, nextInteger, nextDouble)
-- transformers
import Control.Monad.Trans.State (StateT(..), runStateT, evalStateT, State, runState, evalState)

-- | Random generator transformer.
--
-- A thin newtype over @'StateT' 'SMGen' m@: samplers read the current
-- 'splitmix' generator from the state and store the successor generator back.
--
-- Because the base monad @m@ is a parameter, a 'GenT' can be embedded inside
-- a larger effect stack (see the 'MonadTrans' and 'MonadIO' instances).
newtype GenT m a = GenT { unGen :: StateT SMGen m a }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadState SMGen
    , MonadTrans
    , MonadIO
    )

-- | Pure random generation: 'GenT' specialised to the 'Identity' base monad.
--
-- Kept eta-reduced: a synonym declared as @type Gen a = ...@ could not be
-- used unapplied (e.g. as the monad argument of another transformer).
type Gen = GenT Identity

-- | Monadic evaluation: run a generator in its base monad, starting from a
-- fresh 'SMGen' built from the seed. The final generator state is discarded.
sampleT :: Monad m =>
           Word64 -- ^ random seed
        -> GenT m a -> m a
sampleT seed = flip evalStateT (mkSMGen seed) . unGen

-- | Pure evaluation: run a 'Gen' from a seed and keep only the sampled value
-- (the final generator state is dropped).
sample :: Word64 -- ^ random seed
       -> Gen a
       -> a
sample seed gen = fst (runState (unGen gen) (mkSMGen seed))


-- | Bernoulli trial: yields 'True' when a uniform draw from \( [0, 1) \)
-- falls below the bias @p@, 'False' otherwise.
bernoulli :: Double -- ^ bias parameter \( 0 \lt p \lt 1 \)
          -> Gen Bool
bernoulli = withGen . bernoulliF

-- | Uniform between two values, obtained by affinely rescaling 'stdUniform'
-- onto the interval from @lo@ to @hi@.
uniformR :: Double -- ^ low
         -> Double -- ^ high
         -> Gen Double
uniformR lo hi = fmap (\u -> lo + (hi - lo) * u) stdUniform

-- | Standard normal: zero mean and unit standard deviation.
stdNormal :: Gen Double
stdNormal = normal 0 1

-- | Uniform in [0, 1): lifts 'splitmix''s 'nextDouble' into 'Gen'.
stdUniform :: Gen Double
stdUniform = withGen nextDouble

-- | Beta distribution, sampled with Jöhnk's rejection algorithm.
--
-- Each attempt draws two standard uniform samples \( u_1, u_2 \) and forms
-- \( x = u_1^{1/\alpha} \) and \( y = u_2^{1/\beta} \). The attempt is
-- accepted when \( 0 \lt x + y \le 1 \), and then \( x / (x + y) \) is
-- returned. Otherwise a new pair is drawn.
--
-- NB: the acceptance rate drops as \( \alpha + \beta \) grows, so this
-- sampler is best suited to small shape parameters.
beta :: Double -- ^ shape parameter \( \alpha \gt 0 \)
     -> Double -- ^ shape parameter \( \beta \gt 0 \)
     -> Gen Double
beta a b = go
  where
    go :: Gen Double
    go = do
      (x, y) <- candidate
      let s = x + y
      -- Reject s == 0 too: for small shapes both powers can underflow to 0
      -- (and a uniform draw can be exactly 0), which would give 0 / 0 = NaN.
      if s > 0 && s <= 1
        then pure (x / s)
        else go

    -- one pair of powered uniforms (u1^(1/a), u2^(1/b))
    candidate :: Gen (Double, Double)
    candidate = powers <$> stdUniform <*> stdUniform
      where
        powers :: Double -> Double -> (Double, Double)
        powers u1 u2 = (u1 ** (1 / a), u2 ** (1 / b))

-- | Gamma distribution, using Ahrens-Dieter accept-reject (algorithm GD):
--
-- Ahrens, J. H.; Dieter, U (January 1982). "Generating gamma variates by a modified rejection technique". Communications of the ACM. 25 (1): 47–54
--
-- The shape is split as \( k = n + \delta \), with \( n = \lfloor k \rfloor \).
-- A \( \Gamma(\delta, 1) \) variate \( \xi \) is drawn by rejection. The
-- integer part contributes a sum of \( n \) standard exponentials
-- \( -\ln U_i \). The scale multiplies the total:
-- \( \theta \left( \xi - \sum_{i=1}^{n} \ln U_i \right) \sim \Gamma(k, \theta) \).
gamma :: Double -- ^ shape parameter \( k \gt 0 \)
      -> Double -- ^ scale parameter \( \theta \gt 0 \)
      -> Gen Double
gamma k th = do
  xi <- sampleXi
  logUs <- replicateM n (log <$> uniformOpen)
  -- theta scales the whole Gamma(k, 1) variate, not only the fractional part
  pure $ th * (xi - sum logUs)
  where
    n :: Int
    n = floor k

    delta :: Double
    delta = k - fromIntegral n

    ee :: Double
    ee = exp 1

    -- Uniform on (0, 1]: 'stdUniform' is on [0, 1), and flipping it keeps
    -- 'log' (here and in the proposal) finite.
    uniformOpen :: Gen Double
    uniformOpen = (1 -) <$> stdUniform

    -- Gamma(delta, 1) variate, by rejection.
    -- For an integer shape the fractional component is identically 0.
    sampleXi :: Gen Double
    sampleXi
      | delta == 0 = pure 0
      | otherwise = do
          (xi, eta) <- proposal
          if eta > xi ** (delta - 1) * exp (- xi)
            then sampleXi
            else pure xi

    -- Candidate pair (xi, eta), built from three uniforms.
    proposal :: Gen (Double, Double)
    proposal = mk <$> uniformOpen <*> uniformOpen <*> uniformOpen
      where
        mk :: Double -> Double -> Double -> (Double, Double)
        mk u v w
          | u <= ee / (ee + delta) =
              let xi = v ** (1 / delta)
              in (xi, w * xi ** (delta - 1))
          | otherwise =
              let xi = 1 - log v
              in (xi, w * exp (- xi))


-- | Normal distribution with the given mean and standard deviation, sampled
-- by pushing a uniform draw through the normal inverse CDF.
normal :: Double -- ^ mean
       -> Double -- ^ standard deviation \( \sigma \gt 0 \)
       -> Gen Double
normal mu = withGen . normalF mu

-- | Exponential distribution with rate @l@, sampled by inverse-transform of a
-- single uniform draw.
exponential :: Double -- ^ rate parameter \( \lambda > 0 \)
            -> Gen Double
exponential = withGen . exponentialF

-- | Wrap a 'splitmix' PRNG function.
--
-- The function receives the generator currently held in the state and
-- returns a result paired with the successor generator, which becomes the
-- new state.
withGen :: Monad m =>
           (SMGen -> (a, SMGen)) -- ^ explicit generator passing (e.g. 'nextDouble')
        -> GenT m a
withGen step = GenT (state step)

-- | Pure exponential step: map one uniform double through the exponential
-- inverse CDF, returning the sample together with the advanced generator.
exponentialF :: Double -> SMGen -> (Double, SMGen)
exponentialF l g =
  let (u, g') = nextDouble g
  in (exponentialICDF l u, g')

-- | Pure normal step: draw a uniform double and map it through the normal
-- inverse CDF, returning the sample together with the advanced generator.
--
-- 'nextDouble' yields values in \( [0, 1) \). An exact 0 would be mapped to
-- \( -\infty \) by the inverse CDF, so such draws are discarded and a fresh
-- uniform is taken from the successor generator.
normalF :: Double -> Double -> SMGen -> (Double, SMGen)
normalF mu sig = go
  where
    go :: SMGen -> (Double, SMGen)
    go g
      | u == 0 = go g'
      | otherwise = (normalICDF mu sig u, g')
      where
        (u, g') = nextDouble g

-- | Pure Bernoulli step: draw one uniform double and report whether it lies
-- below the bias @p@, together with the advanced generator.
bernoulliF :: Double -> SMGen -> (Bool, SMGen)
bernoulliF p g =
  let (u, g') = nextDouble g
  in (u < p, g')


-- | Inverse CDF (quantile function) of a normal random variable, expressed
-- through the inverse error function:
-- \( \mu + \sigma \sqrt{2} \, \mathrm{erf}^{-1}(2p - 1) \).
normalICDF :: InvErf a =>
              a -- ^ mean
           -> a -- ^ std dev
           -> a -> a
normalICDF mu sig p = mu + spread * z
  where
    spread = sig * sqrt 2
    z = inverf (2 * p - 1)

-- | Inverse CDF (quantile function) of an exponential random variable:
-- \( -\ln(1 - p) / \lambda \).
--
-- 'log1p' computes \( \ln(1 - p) \) without first forming @1 - p@. This keeps
-- full relative precision for small @p@, which would otherwise be rounded
-- away (e.g. @p = 1e-20@ would map to 0).
exponentialICDF :: Floating a =>
                   a -- ^ rate
                -> a -- ^ probability \( p \)
                -> a
exponentialICDF l p = (- 1 / l) * log1p (- p)