module System.Random.MWC.Distributions
(
normal
, standard
, exponential
, truncatedExp
, gamma
, chiSquare
, beta
, categorical
, geometric0
, geometric1
, bernoulli
, dirichlet
, uniformPermutation
, uniformShuffle
) where
import Prelude hiding (mapM)

import Control.Monad (liftM,when)
import Data.Bits ((.&.))
import Data.Foldable (Foldable,foldl')
import Data.Traversable (Traversable,mapM)
import Data.Word (Word32)
import Numeric (expm1, log1p)

import Control.Monad.Primitive (PrimMonad, PrimState)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as M
import qualified Data.Vector.Unboxed as I
import System.Random.MWC (Gen, uniform, uniformR)
-- | A strict pair of 'Double's used as loop state.
--
-- In 'gamma' it carries a standard-normal draw together with its cubed
-- transform.  When building 'blocks' it carries a layer boundary together
-- with its unnormalised density.
data T
  = T !Double  -- first component (strict)
      !Double  -- second component (strict)
-- | Draw a normally distributed variate with the given mean and standard
-- deviation by shifting and scaling a 'standard' variate.
normal :: PrimMonad m
       => Double             -- ^ Mean
       -> Double             -- ^ Standard deviation
       -> Gen (PrimState m)  -- ^ Generator
       -> m Double
normal mu sigma g =
  standard g >>= \z -> return $! mu + sigma * z
-- | Draw a normally distributed variate with zero mean and unit variance.
--
-- This is a ziggurat rejection sampler over 128 layers.  The layer index is
-- the low 7 bits of a random 'Word32', and the signed uniform is drawn
-- separately.  It appears to follow Doornik's modified ziggurat.  Layer 0
-- also covers the tail beyond 'rNorm', which is sampled by @normalTail@.
standard :: PrimMonad m => Gen (PrimState m) -> m Double
standard gen = loop
  where
    loop = do
      -- Map the generator's uniform Double onto an interval symmetric
      -- around zero; the sign of u becomes the sign of the result.
      u  <- (subtract 1 . (*2)) `liftM` uniform gen
      ri <- uniform gen
      let i  = fromIntegral ((ri :: Word32) .&. 127)
          bi = I.unsafeIndex blocks i
          bj = I.unsafeIndex blocks (i+1)
      case () of
        _ | abs u < I.unsafeIndex ratios i -> return $! u * bi
          | i == 0                         -> normalTail (u < 0)
          | otherwise                      -> do
              -- Wedge test.  With phi(t) = exp(-t^2/2):
              --   d = phi(bi) / phi(x)
              --   e = phi(bj) / phi(x)
              -- Accept x when a uniform height between the two layer-edge
              -- densities falls below the density at x.
              let x  = u * bi
                  xx = x * x
                  d  = exp (-0.5 * (bi * bi - xx))
                  e  = exp (-0.5 * (bj * bj - xx))
              c <- uniform gen
              if e + c * (d - e) < 1
                then return x
                else loop
    -- Sample from the tail |x| > rNorm (Marsaglia's tail method); neg
    -- selects the negative tail.
    normalTail neg = tailing
      where
        tailing = do
          x <- ((/ rNorm) . log) `liftM` uniform gen
          y <- log `liftM` uniform gen
          if y * (-2) < x * x
            then tailing
            else return $! if neg then x - rNorm else rNorm - x
-- | Layer boundaries of the ziggurat used by 'standard'.
--
-- The table has 129 entries:
--
-- * index 0 is @v/f@, the width of a rectangle of area @v@ and height @f@;
-- * index 1 is 'rNorm';
-- * then 126 progressively smaller boundaries;
-- * and a final 0.
--
-- Each new boundary @h@ is derived from the previous boundary @b@ (with
-- density @g = exp(-b^2/2)@) by requiring @b * (exp(-h^2/2) - g) == v@,
-- so every rectangular layer has the same area @v@.
blocks :: I.Vector Double
blocks = (`I.snoc` 0) . I.cons (v / f) . I.cons rNorm . I.unfoldrN 126 go $! T rNorm f
  where
    -- State: previous boundary b and its unnormalised density g.  The
    -- bang keeps the state evaluated so no thunks pile up across steps.
    go (T b g) =
      let h  = sqrt (-2 * log (v / b + g))
          !u = T h (exp (-0.5 * h * h))
      in  Just (h, u)
    -- Common layer area; presumably chosen together with rNorm.
    v = 9.91256303526217e-3
    -- Unnormalised density at rNorm.
    f = exp (-0.5 * rNorm * rNorm)
-- Keep the table a single shared top-level value instead of letting it be
-- inlined into call sites.
{-# NOINLINE blocks #-}
-- | Right edge of the widest ziggurat rectangle.
--
-- 'standard' treats samples beyond ±rNorm as the tail and draws them with
-- its @normalTail@ helper.  'blocks' starts its boundary recursion from
-- this value.  It is presumably chosen together with the layer area @v@
-- used by 'blocks'; change both together.
rNorm :: Double
rNorm = 3.442619855899
-- | Fast-path acceptance thresholds for 'standard'.
--
-- @ratios ! i == blocks ! (i+1) / blocks ! i@ for @i@ in @[0 .. 127]@.
-- When a sample's |u| is below this ratio, @u * blocks ! i@ lies inside
-- the next-narrower boundary, so it can be accepted without evaluating
-- any density.
ratios :: I.Vector Double
ratios = I.zipWith (/) (I.tail blocks) blocks
-- Keep the table a single shared top-level value instead of letting it be
-- inlined into the hot sampling loop.
{-# NOINLINE ratios #-}
-- | Draw an exponentially distributed variate by inversion.
--
-- The parameter acts as a /rate/ λ: the result is @-log u / λ@ for a
-- uniform draw @u@, so the mean is @1/λ@.
exponential :: PrimMonad m
            => Double             -- ^ Rate parameter λ (mean is 1/λ)
            -> Gen (PrimState m)  -- ^ Generator
            -> m Double
exponential b gen = do
  x <- uniform gen
  -- Assuming uniform yields values in (0,1], log x <= 0.  Negate it so the
  -- variate is non-negative.
  return $! negate (log x) / b
-- | Draw an exponentially distributed variate truncated to @[a, b]@ by
-- inverting the truncated CDF.
--
-- The first argument is the exponential's rate λ.
truncatedExp :: PrimMonad m
             => Double             -- ^ Rate parameter λ
             -> (Double,Double)    -- ^ Truncation interval @(a, b)@; values may be negative
             -> Gen (PrimState m)  -- ^ Generator
             -> m Double
truncatedExp scale (a,b) gen = do
  -- Shift a to 0, sample on [0, b - a], then shift back by a.
  let delta = b - a
  p <- uniform gen
  -- Inverse CDF:
  --   x = -log((1 - p) + p*exp(-scale*delta)) / scale
  -- The log argument equals 1 + p*expm1(-scale*delta).  Using log1p and
  -- expm1 keeps precision when scale*delta is small.
  return $! a - log1p (p * expm1 (negate scale * delta)) / scale
-- | Draw a gamma-distributed variate with shape @a@ and scale @b@ using
-- Marsaglia & Tsang's squeeze/rejection method.
--
-- For @a < 1@, a variate with shape @a + 1@ is drawn and multiplied by
-- @U ** (1/a)@, which yields shape @a@.
--
-- Calls 'error' if the shape parameter is not positive.
gamma :: PrimMonad m
      => Double             -- ^ Shape parameter (must be > 0)
      -> Double             -- ^ Scale parameter
      -> Gen (PrimState m)  -- ^ Generator
      -> m Double
gamma a b gen
  | a <= 0    = pkgError "gamma" "negative alpha parameter"
  | otherwise = mainloop
  where
    mainloop = do
      T x v <- innerloop
      u     <- uniform gen
      -- Reject only when both the cheap squeeze test and the exact log test
      -- fail.  0.0331 is Marsaglia & Tsang's squeeze constant.  The squeeze
      -- only lower-bounds the exact test, so the constant affects how often
      -- the log is evaluated, not which samples are accepted.
      let cont =  u > 1 - 0.0331 * sqr (sqr x)
               && log u > 0.5 * sqr x + a1 * (1 - v + log v)
      case () of
        _ | cont      -> mainloop
          | a >= 1    -> return $! a1 * v * b
          | otherwise -> do
              y <- uniform gen
              return $! y ** (1 / a) * a1 * v * b
    -- Draw x ~ N(0,1) until v = (1 + a2*x)^3 is positive.
    innerloop = do
      x <- standard gen
      case 1 + a2 * x of
        v | v <= 0    -> innerloop
          | otherwise -> return $! T x (v * v * v)
    -- Shapes below 1 are boosted to a + 1; mainloop corrects for this.
    a' = if a < 1 then a + 1 else a
    a1 = a' - 1/3           -- d in Marsaglia & Tsang
    a2 = 1 / sqrt (9 * a1)  -- c in Marsaglia & Tsang
-- | Draw a chi-squared variate with @n@ degrees of freedom, computed as
-- twice a unit-scale gamma variate with shape @n/2@.
--
-- Calls 'error' when @n@ is not positive.
chiSquare :: PrimMonad m
          => Int                -- ^ Degrees of freedom
          -> Gen (PrimState m)  -- ^ Generator
          -> m Double
chiSquare n gen
  | n > 0     = gamma (fromIntegral n / 2) 1 gen >>= \g -> return $! g * 2
  | otherwise = pkgError "chiSquare" "number of degrees of freedom must be positive"
-- | Draw a geometrically distributed variate that counts the failures
-- before the first success.  The support is {0, 1, 2, ...}.
--
-- Calls 'error' unless @0 < p <= 1@.
geometric0 :: PrimMonad m
           => Double             -- ^ Success probability @p@
           -> Gen (PrimState m)  -- ^ Generator
           -> m Int
geometric0 p gen
  | p == 1         = return 0
  | p > 0 && p < 1 = do
      q <- uniform gen
      -- Inversion: floor (log q / log (1 - p)).  log1p (-p) replaces
      -- log (1 - p) so that a tiny p neither loses precision nor rounds
      -- 1 - p to exactly 1, which would divide by zero.
      return $! floor $ log q / log1p (negate p)
  | otherwise      = pkgError "geometric0" "probability out of [0,1] range"
-- | Draw a geometrically distributed variate with support {1, 2, ...}:
-- a 'geometric0' draw shifted up by one.
geometric1 :: PrimMonad m
           => Double             -- ^ Success probability @p@
           -> Gen (PrimState m)  -- ^ Generator
           -> m Int
geometric1 p gen = geometric0 p gen >>= \k -> return $! k + 1
-- | Draw a beta-distributed variate from two independent unit-scale gamma
-- variates: @X / (X + Y)@ with @X ~ Gamma(a)@ and @Y ~ Gamma(b)@.
beta :: PrimMonad m
     => Double             -- ^ First shape parameter α
     -> Double             -- ^ Second shape parameter β
     -> Gen (PrimState m)  -- ^ Generator
     -> m Double
beta a b gen =
  gamma a 1 gen >>= \gx ->
  gamma b 1 gen >>= \gy ->
  return $! gx / (gx + gy)
-- | Draw a Dirichlet-distributed container.
--
-- One unit-scale gamma variate is drawn per concentration parameter, and
-- each is divided by the sum of all draws.  The container shape is
-- preserved.
dirichlet :: (PrimMonad m, Traversable t)
          => t Double           -- ^ Concentration parameters
          -> Gen (PrimState m)  -- ^ Generator
          -> m (t Double)
dirichlet alphas gen = do
  draws <- mapM draw alphas
  let norm = foldl' (+) 0 draws
  return (fmap (/ norm) draws)
  where
    draw alpha = gamma alpha 1 gen
-- | Draw a Bernoulli variate: 'True' when a uniform draw falls strictly
-- below @p@.
bernoulli :: PrimMonad m
          => Double             -- ^ Probability of 'True'
          -> Gen (PrimState m)  -- ^ Generator
          -> m Bool
bernoulli p gen = do
  u <- uniform gen
  return (u < p)
-- | Draw an index in @[0, length weights)@ with probability proportional
-- to the corresponding weight.
--
-- Weights need not sum to 1 and are assumed non-negative; this is not
-- checked.  Calls 'error' on an empty weight vector, or when no cumulative
-- weight reaches the drawn threshold.
categorical :: (PrimMonad m, G.Vector v Double)
            => v Double           -- ^ Weights
            -> Gen (PrimState m)  -- ^ Generator
            -> m Int
categorical v gen
  | G.null v  = pkgError "categorical" "empty weights!"
  | otherwise = do
      let cumulative = G.scanl1' (+) v
          total      = G.last cumulative
      u <- uniform gen
      let threshold = total * u
          found     = G.findIndex (>= threshold) cumulative
      return $! maybe (pkgError "categorical" "bad weights!") id found
-- | Generate a uniformly random permutation of @[0 .. n-1]@ with an
-- in-place Fisher–Yates shuffle.
--
-- Calls 'error' when @n@ is not positive.
uniformPermutation :: forall m v. (PrimMonad m, G.Vector v Int)
                   => Int                -- ^ Size of the permutation
                   -> Gen (PrimState m)  -- ^ Generator
                   -> m (v Int)
uniformPermutation n gen = do
  when (n <= 0) (pkgError "uniformPermutation" "size must be >0")
  -- The identity vector is built right here and not referenced elsewhere,
  -- so thawing it without a copy is safe.
  v <- G.unsafeThaw (G.generate n id :: v Int)
  let lst = n - 1
      -- Swap position i with a uniformly chosen position in [i, lst].
      loop i | i == lst  = G.unsafeFreeze v
             | otherwise = do
                 j <- uniformR (i, lst) gen
                 M.unsafeSwap v i j
                 loop (i + 1)
  loop 0
-- | Shuffle a vector uniformly at random by permuting it with a
-- 'uniformPermutation' of its indices.
--
-- Vectors of length 0 or 1 are returned unchanged without consuming
-- randomness.
uniformShuffle :: (PrimMonad m, G.Vector v a, G.Vector v Int)
               => v a                -- ^ Vector to shuffle
               -> Gen (PrimState m)  -- ^ Generator
               -> m (v a)
uniformShuffle xs gen
  | len > 1   = uniformPermutation len gen >>= \perm -> return $! G.backpermute xs perm
  | otherwise = return xs
  where
    len = G.length xs
-- | Square a 'Double'.
sqr :: Double -> Double
sqr = \x -> x * x
-- | Abort with an 'error' whose message is prefixed by this module's name
-- and the reporting function, e.g.
-- @System.Random.MWC.Distributions.gamma: ...@.
pkgError :: String  -- ^ Name of the function reporting the error
         -> String  -- ^ Error message
         -> a
pkgError func msg =
  error (concat ["System.Random.MWC.Distributions.", func, ": ", msg])