{-# language GeneralizedNewtypeDeriving #-}
{-# options_ghc -Wno-unused-imports #-}
module System.Random.SplitMix.Distributions (
stdUniform, uniformR,
exponential,
stdNormal, normal,
beta,
gamma,
bernoulli,
Gen, sample,
GenT, sampleT,
withGen
) where
import Control.Monad (replicateM)
import Control.Monad.IO.Class (MonadIO(..))
import Data.Functor.Identity (Identity(..))
import GHC.Word (Word64)
import Data.Number.Erf (InvErf(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.State (MonadState(..), modify)
import System.Random.SplitMix (SMGen, mkSMGen, splitSMGen, nextInt, nextInteger, nextDouble)
import Control.Monad.Trans.State (StateT(..), runStateT, evalStateT, State, runState, evalState)
-- | Random generator transformer.
--
-- Threads splitmix's 'SMGen' through an underlying monad @m@ by wrapping a
-- 'StateT'. All instances are derived from that wrapped 'StateT', which lets
-- a 'GenT' be embedded in a larger effect stack ('MonadTrans', 'MonadIO').
newtype GenT m a = GenT
  { unGen :: StateT SMGen m a -- ^ underlying state computation over the PRNG
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadState SMGen
    , MonadTrans
    , MonadIO
    )
-- | Pure random generation: 'GenT' specialised to 'Identity', run with 'sample'.
type Gen = GenT Identity
-- | Run a generator inside an arbitrary monad @m@.
--
-- The splitmix state is seeded from the given 'Word64', and the final
-- generator state is discarded.
sampleT :: Monad m
        => Word64   -- ^ random seed
        -> GenT m a
        -> m a
sampleT seed = flip evalStateT (mkSMGen seed) . unGen
-- | Run a pure generator from a 'Word64' seed.
--
-- This is 'sampleT' run in 'Identity', with the result unwrapped.
sample :: Word64 -- ^ random seed
       -> Gen a
       -> a
sample seed = runIdentity . sampleT seed
-- | Bernoulli trial: yields 'True' when a single uniform draw falls
-- strictly below the bias @p@.
bernoulli :: Double -- ^ bias parameter \( 0 \lt p \lt 1 \)
          -> Gen Bool
bernoulli = withGen . bernoulliF
-- | Uniform draw between @lo@ and @hi@, obtained by affinely rescaling a
-- 'stdUniform' draw: @u * (hi - lo) + lo@.
uniformR :: Double -- ^ low
         -> Double -- ^ high
         -> Gen Double
uniformR lo hi = fmap rescale stdUniform
  where
    width = hi - lo
    rescale u = u * width + lo
-- | Standard normal variate: 'normal' with mean 0 and standard deviation 1.
stdNormal :: Gen Double
stdNormal = normal 0.0 1.0
-- | Standard uniform draw on the unit interval.
--
-- The draw comes directly from splitmix's 'nextDouble', lifted into 'Gen'
-- with 'withGen'.
stdUniform :: Gen Double
stdUniform = withGen nextDouble
-- | Beta distribution, sampled with Jöhnk's rejection algorithm.
--
-- The algorithm draws \( X = U_1^{1/\alpha} \) and \( Y = U_2^{1/\beta} \)
-- from two standard uniforms. It retries until \( X + Y \le 1 \), then
-- returns \( X / (X + Y) \).
--
-- For small shape parameters, both powers can underflow to @0@. The naive
-- ratio would then be @0 / 0 = NaN@, so in that case the ratio is
-- recomputed from the logarithms of the uniforms instead. A draw where both
-- uniforms are exactly @0@ carries no information and is rejected.
beta :: Double -- ^ shape parameter \( \alpha \gt 0 \)
     -> Double -- ^ shape parameter \( \beta \gt 0 \)
     -> Gen Double
beta a b = go
  where
    go :: Gen Double
    go = do
      (u1, u2) <- sample2
      let y1 = u1 ** (1/a)
          y2 = u2 ** (1/b)
          s  = y1 + y2
      if s <= 1 && u1 + u2 > 0
        then pure (ratio u1 u2 y1 y2 s)
        else go

    -- Two uniforms, drawn in the same order as before (u1 first).
    sample2 :: Gen (Double, Double)
    sample2 = (,) <$> stdUniform <*> stdUniform

    -- X / (X + Y). This falls back to log space when both powers
    -- underflowed to 0.
    ratio :: Double -> Double -> Double -> Double -> Double -> Double
    ratio u1 u2 y1 y2 s
      | s > 0 = y1 / s
      | otherwise =
          -- log X = log u1 / a and log Y = log u2 / b. Shift both by their
          -- max before exponentiating, so at least one term is exp 0 = 1.
          let lx  = log u1 / a
              ly  = log u2 / b
              m   = max lx ly
              lx' = lx - m
              ly' = ly - m
          in exp (lx' - log (exp lx' + exp ly'))
-- | Gamma distribution with shape \( k \) and scale \( \theta \).
--
-- Sampling uses the Ahrens-Dieter acceptance-rejection method (algorithm GD):
--
-- Ahrens, J. H.; Dieter, U. (January 1982). "Generating gamma variates by a
-- modified rejection technique". Communications of the ACM. 25 (1): 47–54
--
-- The shape is split as \( k = n + \delta \) with \( n = \lfloor k \rfloor \).
-- A variate \( \xi \sim \Gamma(\delta, 1) \) is drawn by rejection, and then
-- \( \xi - \sum_{i=1}^{n} \ln U_i \sim \Gamma(k, 1) \). That whole quantity
-- is multiplied by \( \theta \).
gamma :: Double -- ^ shape parameter \( k \gt 0 \)
      -> Double -- ^ scale parameter \( \theta \gt 0 \)
      -> Gen Double
gamma k th = do
  xi <- sampleXi
  -- n log-uniforms: each -log U contributes an Exp(1) term to Gamma(n, 1).
  us <- replicateM n (log <$> stdUniform)
  -- The scale multiplies the entire Gamma(k, 1) variate. Writing
  -- `th * xi - sum us` would leave the integer-shape part unscaled.
  pure $ th * (xi - sum us)
  where
    -- Gamma(delta, 1) by rejection: retry while eta lies above the
    -- (unnormalised) target density xi^(delta-1) * e^(-xi).
    sampleXi :: Gen Double
    sampleXi = do
      (xi, eta) <- sample2
      if eta > xi ** (delta - 1) * exp (- xi)
        then sampleXi
        else pure xi

    -- Integer and fractional parts of the shape parameter.
    (n, delta) = (floor k, k - fromIntegral n)

    ee :: Double
    ee = exp 1

    -- One (xi, eta) proposal built from three uniforms u, v, w.
    sample2 :: Gen (Double, Double)
    sample2 = f <$> stdUniform <*> stdUniform <*> stdUniform
      where
        f :: Double -> Double -> Double -> (Double, Double)
        f u v w
          | u <= ee / (ee + delta) =
              let xi = v ** (1 / delta)
              in (xi, w * xi ** (delta - 1))
          | otherwise =
              let xi = 1 - log v
              in (xi, w * exp (- xi))
-- | Normal distribution, sampled by inverse transform.
--
-- One uniform draw is mapped through the normal quantile function.
normal :: Double -- ^ mean
       -> Double -- ^ standard deviation
       -> Gen Double
normal mu = withGen . normalF mu
-- | Exponential distribution, sampled by inverse transform from a single
-- uniform draw.
exponential :: Double -- ^ rate parameter \( \lambda \gt 0 \)
            -> Gen Double
exponential = withGen . exponentialF
-- | Lift a splitmix-style state-passing step into 'GenT'.
--
-- The step is applied to the current generator, and the generator it
-- returns is stored back as the new state. The step's first component is
-- the result. The step's output is bound lazily, as a @let@.
withGen :: Monad m
        => (SMGen -> (a, SMGen)) -- ^ state-passing step
        -> GenT m a
withGen step = GenT $ get >>= \g ->
  let (x, g') = step g
  in put g' >> pure x
-- | State-passing exponential step: one 'nextDouble' draw mapped through
-- 'exponentialICDF' with rate @l@.
exponentialF :: Double -> SMGen -> (Double, SMGen)
exponentialF l g =
  let (u, g') = nextDouble g
  in (exponentialICDF l u, g')
-- | State-passing normal step: one 'nextDouble' draw mapped through
-- 'normalICDF' with mean @mu@ and standard deviation @sig@.
normalF :: Double -> Double -> SMGen -> (Double, SMGen)
normalF mu sig g =
  let (u, g') = nextDouble g
  in (normalICDF mu sig u, g')
-- | State-passing Bernoulli step: one 'nextDouble' draw, reporting whether
-- it lies strictly below the bias @p@.
bernoulliF :: Double -> SMGen -> (Bool, SMGen)
bernoulliF p g =
  let (u, g') = nextDouble g
  in (u < p, g')
-- | Normal quantile function, expressed via the inverse error function:
-- \( \mu + \sigma \sqrt{2} \, \mathrm{erf}^{-1}(2p - 1) \).
normalICDF :: InvErf a
           => a -- ^ mean
           -> a -- ^ standard deviation
           -> a -- ^ probability
           -> a
normalICDF mu sig p = mu + sig * sqrt 2 * z
  where
    z = inverf (2 * p - 1)
-- | Exponential quantile function for rate @l@.
--
-- Mathematically this is \( -\ln(1 - p) / \lambda \). It is evaluated as
-- @(-1 / l) * log (1 - p)@.
exponentialICDF :: Floating a
                => a -- ^ rate \( \lambda \)
                -> a -- ^ probability
                -> a
exponentialICDF l p = scale * log survival
  where
    scale = negate (1 / l)
    survival = 1 - p