{-# LANGUAGE
MultiParamTypeClasses, FlexibleInstances, FlexibleContexts,
UndecidableInstances, ForeignFunctionInterface, BangPatterns,
RankNTypes
#-}
{-# OPTIONS_GHC -fno-warn-type-defaults #-}
module Data.Random.Distribution.Normal
( Normal(..)
, normal, normalT
, stdNormal, stdNormalT
, doubleStdNormal
, floatStdNormal
, realFloatStdNormal
, normalTail
, normalPair
, boxMullerNormalPair
, knuthPolarNormalPair
) where
import Data.Bits
import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Data.Random.Distribution.Ziggurat
import Data.Random.RVar
import Data.Word
import Data.Vector.Generic (Vector)
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as UV
import Data.Number.Erf
import qualified System.Random.Stateful as Random
-- | A random variable that produces a pair of independent
-- normally-distributed values.
--
-- Currently this is simply 'boxMullerNormalPair'.
normalPair :: (Floating a, Distribution StdUniform a) => RVar (a, a)
normalPair = boxMullerNormalPair
-- | A random variable that produces a pair of independent
-- normally-distributed values using the Box-Muller transform.
--
-- Exactly two 'stdUniform' samples are consumed per pair (there is no
-- rejection loop, unlike 'knuthPolarNormalPair'), and only 'Floating' is
-- required (no 'Ord' constraint).
--
-- NOTE(review): if 'stdUniform' can return exactly 0, @log 0@ is
-- -Infinity and the radius becomes Infinity — confirm the range of the
-- underlying 'StdUniform' instance.
{-# INLINE boxMullerNormalPair #-}
boxMullerNormalPair :: (Floating a, Distribution StdUniform a) => RVar (a, a)
boxMullerNormalPair = do
    u <- stdUniform
    t <- stdUniform
    let -- radius of the sample point
        radius = sqrt (-2 * log u)
        -- angle: t scaled onto one full turn
        angle  = (2 * pi) * t
    return (radius * cos angle, radius * sin angle)
-- | A random variable that produces a pair of independent
-- normally-distributed values using Knuth's polar method.
--
-- Two values @v1@, @v2@ are drawn with @uniform (-1) 1@. If
-- @s = v1^2 + v2^2@ is at least 1 the whole action is retried (rejection),
-- so the number of uniform draws is not fixed. Otherwise both
-- coordinates are scaled by @sqrt (-2 * log s / s)@.
{-# INLINE knuthPolarNormalPair #-}
knuthPolarNormalPair :: (Floating a, Ord a, Distribution Uniform a) => RVar (a, a)
knuthPolarNormalPair = do
    v1 <- uniform (-1) 1
    v2 <- uniform (-1) 1
    let s = v1 * v1 + v2 * v2
    if s >= 1
        then knuthPolarNormalPair   -- outside the unit disc: try again
        else return (polarScale s v1 v2)
  where
    -- Scale an accepted point; s == 0 would make @log s / s@ undefined,
    -- so the origin maps to (0, 0).
    polarScale s v1 v2
        | s == 0    = (0, 0)
        | otherwise = let scale = sqrt (-2 * log s / s)
                      in (v1 * scale, v2 * scale)
-- | Draw from the tail of a normal distribution: the region beyond the
-- cut-off @r@ (presumably @r > 0@ — TODO confirm with callers).
--
-- This follows Marsaglia's method for sampling the normal tail:
-- @x = log u / r@ and @y = log v@ are computed from two 'stdUniformT'
-- draws (both non-positive when @u@, @v@ lie in (0, 1]). The candidate
-- @r - x@ is rejected while @x^2 + 2*y > 0@. The code writes that test
-- as @x*x + y + y@; it is deliberately @y + y@, not @y*y@.
{-# INLINE normalTail #-}
normalTail :: (Distribution StdUniform a, Floating a, Ord a) =>
              a -> RVarT m a
normalTail r = go
  where
    go = do
        !u <- stdUniformT
        let !x = log u / r
        !v <- stdUniformT
        let !y = log v
        if x * x + y + y > 0
            then go              -- rejected: resample both values
            else return (r - x)  -- x <= 0 here, so the result is >= r
-- | Construct a 'Ziggurat' for the standard normal distribution.
--
-- @p@ is log base 2 of the size: @2^p@ is passed to 'mkZigguratRec' as
-- its 'Int' size parameter (presumably the layer count — see
-- 'mkZigguratRec'). A negative @p@ makes @(^)@ throw. @getIU@ is the
-- action supplying (index, uniform value) pairs to the ziggurat.
normalZ ::
  (RealFloat a, Erf a, Vector v a, Distribution Uniform a, Integral b) =>
  b -> (forall m. RVarT m (Int, a)) -> Ziggurat v a
normalZ p getIU =
    mkZigguratRec True normalF normalFInv normalFInt normalFVol (2 ^ p) getIU
-- | Ziggurat target function: the upper half of an unnormalised Gaussian
-- density, @exp (-x^2 / 2)@ for @x > 0@, and constant 1 for @x <= 0@.
normalF :: (Floating a, Ord a) => a -> a
normalF x
    | x <= 0    = 1
    | otherwise = exp ((-0.5) * x * x)
-- | Inverse of 'normalF' on its decreasing branch: for @0 < y <= 1@,
-- @normalFInv y = sqrt (-2 * log y)@, so @normalFInv (normalF x) == x@
-- (up to rounding) for @x >= 0@.
normalFInv :: Floating a => a -> a
normalFInv y = sqrt ((-2) * log y)
-- | Integral of 'normalF' from 0 to @x@:
-- @sqrt (pi/2) * erf (x / sqrt 2)@ for @x > 0@, and 0 for @x <= 0@.
normalFInt :: (Floating a, Erf a, Ord a) => a -> a
normalFInt x
    | x <= 0    = 0
    | otherwise = normalFVol * erf (x * sqrt 0.5)
-- | Total volume under 'normalF' over @[0, infinity)@: @sqrt (pi / 2)@.
normalFVol :: Floating a => a
normalFVol = sqrt (0.5 * pi)
-- | A random variable sampling the standard normal distribution over any
-- 'RealFloat' type that also has 'Erf' and 'Uniform' instances.
--
-- It runs a ziggurat built by 'normalZ' with @2^6 = 64@ layers stored in
-- a boxed 'V.Vector'. NOTE(review): building the table is presumably
-- costly, so prefer 'doubleStdNormal' / 'floatStdNormal' where those
-- types suffice, or use this monomorphically so the table can be shared.
realFloatStdNormal :: (RealFloat a, Erf a, Distribution Uniform a) => RVarT m a
realFloatStdNormal = runZiggurat (boxedTable (normalZ p getIU))
  where
    -- Pins the ziggurat's storage to a boxed vector (replaces a
    -- @undefined@ type proxy with a total identity function).
    boxedTable :: Ziggurat V.Vector t -> Ziggurat V.Vector t
    boxedTable = id

    -- log2 of the table size. Must not exceed 8: the index below is
    -- masked out of a single random 'Word8'.
    p :: Int
    p = 6

    -- Index from the low p bits of a random byte, paired with a value
    -- drawn by @uniformT (-1) 1@.
    getIU :: (Num a, Distribution Uniform a) => RVarT m (Int, a)
    getIU = do
        i <- Random.uniformWord8 RGen
        u <- uniformT (-1) 1
        return (fromIntegral i .&. (2 ^ p - 1), u)
-- | A random variable sampling the standard normal distribution over
-- 'Double', using the shared precomputed table 'doubleStdNormalZ'.
doubleStdNormal :: RVarT m Double
doubleStdNormal = runZiggurat doubleStdNormalZ
-- | Size parameter passed to 'mkZiggurat_' for 'doubleStdNormalZ'.
-- The index is computed as @i .&. (doubleStdNormalC - 1)@, where @i@ is
-- the 12 excess bits returned by 'wordToDoubleWithExcess'. This should
-- therefore be a power of two no larger than @2^12@.
doubleStdNormalC :: Int
doubleStdNormalC = 512

-- | Real-valued ziggurat parameters passed to 'mkZiggurat_' alongside
-- 'doubleStdNormalC'. 'doubleStdNormalR' is also the cut-off handed to
-- 'normalTail'. NOTE(review): presumably the tail start R and layer
-- volume V for a 512-layer normal ziggurat — recompute them if
-- 'doubleStdNormalC' changes.
doubleStdNormalR, doubleStdNormalV :: Double
doubleStdNormalR = 3.852046150368388
doubleStdNormalV = 2.4567663515413507e-3
-- | The precomputed ziggurat table behind 'doubleStdNormal'. It is
-- marked NOINLINE, presumably so the table is built once and shared
-- rather than duplicated at use sites.
{-# NOINLINE doubleStdNormalZ #-}
doubleStdNormalZ :: Ziggurat UV.Vector Double
doubleStdNormalZ =
    mkZiggurat_ True
        normalF normalFInv
        doubleStdNormalC doubleStdNormalR doubleStdNormalV
        getIU
        (normalTail doubleStdNormalR)
  where
    -- One random Word64 supplies both values. The low 52 bits give u in
    -- [0, 1) (rescaled to [-1, 1) as u + u - 1), and the high 12 bits,
    -- masked to (doubleStdNormalC - 1), give the index.
    getIU :: RVarT m (Int, Double)
    getIU = do
        !w <- Random.uniformWord64 RGen
        let (u, i) = wordToDoubleWithExcess w
        return $! (fromIntegral i .&. (doubleStdNormalC - 1), u + u - 1)
-- | Map the low 52 bits @m@ of a word to the 'Double' @m * 2^-52@, which
-- lies in @[0, 1)@. The high 12 bits are ignored.
{-# INLINE wordToDouble #-}
wordToDouble :: Word64 -> Double
wordToDouble x = encodeFloat (toInteger (x .&. mantissaMask)) (-52)
  where
    mantissaMask :: Word64
    mantissaMask = 0x000fffffffffffff
-- | Split one 'Word64' into a 'Double' in @[0, 1)@ built from its low
-- 52 bits (via 'wordToDouble') and the remaining high 12 bits, shifted
-- down so that they lie in @[0, 2^12)@.
{-# INLINE wordToDoubleWithExcess #-}
wordToDoubleWithExcess :: Word64 -> (Double, Word64)
wordToDoubleWithExcess x = (wordToDouble x, x `shiftR` 52)
-- | A random variable sampling the standard normal distribution over
-- 'Float', using the shared precomputed table 'floatStdNormalZ'.
floatStdNormal :: RVarT m Float
floatStdNormal = runZiggurat floatStdNormalZ
-- | Size parameter passed to 'mkZiggurat_' for 'floatStdNormalZ'.
-- The index is computed as @i .&. (floatStdNormalC - 1)@, where @i@ is
-- the 9 excess bits returned by 'word32ToFloatWithExcess'. This should
-- therefore be a power of two no larger than @2^9@; 512 is exactly the
-- limit.
floatStdNormalC :: Int
floatStdNormalC = 512

-- | Real-valued ziggurat parameters passed to 'mkZiggurat_' alongside
-- 'floatStdNormalC'. They are the same literals as the 'Double' table,
-- rounded to 'Float'. 'floatStdNormalR' is also the cut-off handed to
-- 'normalTail'. NOTE(review): presumably the tail start R and layer
-- volume V for a 512-layer normal ziggurat — confirm.
floatStdNormalR, floatStdNormalV :: Float
floatStdNormalR = 3.852046150368388
floatStdNormalV = 2.4567663515413507e-3
-- | The precomputed ziggurat table behind 'floatStdNormal'. It is marked
-- NOINLINE, presumably so the table is built once and shared rather than
-- duplicated at use sites.
{-# NOINLINE floatStdNormalZ #-}
floatStdNormalZ :: Ziggurat UV.Vector Float
floatStdNormalZ =
    mkZiggurat_ True
        normalF normalFInv
        floatStdNormalC floatStdNormalR floatStdNormalV
        getIU
        (normalTail floatStdNormalR)
  where
    -- One random Word32 supplies both values. The low 23 bits give u in
    -- [0, 1) (rescaled to [-1, 1) as u + u - 1), and the high 9 bits,
    -- masked to (floatStdNormalC - 1), give the index.
    getIU :: RVarT m (Int, Float)
    getIU = do
        !w <- Random.uniformWord32 RGen
        let (u, i) = word32ToFloatWithExcess w
        return (fromIntegral i .&. (floatStdNormalC - 1), u + u - 1)
-- | Map the low 23 bits @m@ of a word to the 'Float' @m * 2^-23@, which
-- lies in @[0, 1)@. The high 9 bits are ignored.
{-# INLINE word32ToFloat #-}
word32ToFloat :: Word32 -> Float
word32ToFloat x = encodeFloat (toInteger (x .&. mantissaMask)) (-23)
  where
    mantissaMask :: Word32
    mantissaMask = 0x007fffff
-- | Split one 'Word32' into a 'Float' in @[0, 1)@ built from its low
-- 23 bits (via 'word32ToFloat') and the remaining high 9 bits, shifted
-- down so that they lie in @[0, 2^9)@.
{-# INLINE word32ToFloatWithExcess #-}
word32ToFloatWithExcess :: Word32 -> (Float, Word32)
word32ToFloatWithExcess x = (word32ToFloat x, x `shiftR` 23)
-- | Cumulative distribution function of a normal distribution with mean
-- @m@ and standard deviation @s@, evaluated at @x@. It is computed in
-- 'Double' by applying 'normcdf' to the standardised value
-- @(x - m) / s@. A zero @s@ divides by zero (Infinity or NaN argument).
normalCdf :: (Real a) => a -> a -> a -> Double
normalCdf m s x = normcdf z
  where
    z = (realToFrac x - realToFrac m) / realToFrac s
-- | Probability density of a normal distribution with mean @mu@ and
-- standard deviation @sigma@, evaluated at @x@.
--
-- Computed from the standardised value @z = (x - mu) / |sigma|@ as
-- @exp (-z^2 / 2) / (|sigma| * sqrt (2*pi))@. Working with @sigma@
-- directly instead of @sigma^2@ avoids spurious NaN or 0 results when
-- @sigma^2@ would underflow or overflow even though the density is
-- representable (e.g. @sigma = 1e-200@ or @1e200@ in 'Double'). The sign
-- of @sigma@ is ignored.
normalPdf :: (Real a, Floating b) => a -> a -> a -> b
normalPdf mu sigma x = exp (-0.5 * z * z) / (sd * sqrt (2 * pi))
  where
    sd = abs (realToFrac sigma)
    z  = (realToFrac x - realToFrac mu) / sd
-- | Natural log of the density of a normal distribution with mean @mu@
-- and standard deviation @sigma@, evaluated at @x@.
--
-- Computed directly in log space as
-- @-log |sigma| - log (2*pi) / 2 - z^2 / 2@ with
-- @z = (x - mu) / |sigma|@. This stays finite where taking the log of
-- the density's normaliser would not: squaring an extreme @sigma@
-- underflows or overflows and yields NaN or -Infinity. The sign of
-- @sigma@ is ignored.
normalLogPdf :: (Real a, Floating b) => a -> a -> a -> b
normalLogPdf mu sigma x = negate (log sd) - 0.5 * log (2 * pi) - 0.5 * z * z
  where
    sd = abs (realToFrac sigma)
    z  = (realToFrac x - realToFrac mu) / sd
-- | A specification of a normal distribution over the type @a@.
data Normal a
    -- | The standard normal distribution (mean 0, standard deviation 1).
    = StdNormal
    -- | @Normal m s@: a normal distribution with mean @m@ and standard
    -- deviation @s@.
    | Normal a a
-- | Samples 'Double' values with 'doubleStdNormal'. A general 'Normal'
-- shifts and scales the standard sample: @x * s + m@.
instance Distribution Normal Double where
    rvarT StdNormal    = doubleStdNormal
    rvarT (Normal m s) = (\x -> x * s + m) <$> doubleStdNormal
-- | Samples 'Float' values with 'floatStdNormal'. A general 'Normal'
-- shifts and scales the standard sample: @x * s + m@.
instance Distribution Normal Float where
    rvarT StdNormal    = floatStdNormal
    rvarT (Normal m s) = (\x -> x * s + m) <$> floatStdNormal
-- | CDF via 'normalCdf'; 'StdNormal' uses mean 0 and standard deviation 1.
instance (Real a, Distribution Normal a) => CDF Normal a where
    cdf StdNormal    = normalCdf 0 1
    cdf (Normal m s) = normalCdf m s
-- | Density and log-density via 'normalPdf' / 'normalLogPdf', evaluated
-- in 'Double'; 'StdNormal' uses mean 0 and standard deviation 1.
instance (Real a, Floating a, Distribution Normal a) => PDF Normal a where
    pdf StdNormal    = normalPdf 0 1
    pdf (Normal m s) = normalPdf m s

    logPdf StdNormal    = normalLogPdf 0 1
    logPdf (Normal m s) = normalLogPdf m s
-- | A random variable sampling the standard normal distribution
-- (mean 0, standard deviation 1) via its 'Distribution' instance.
{-# SPECIALIZE stdNormal :: RVar Double #-}
{-# SPECIALIZE stdNormal :: RVar Float #-}
stdNormal :: Distribution Normal a => RVar a
stdNormal = rvar StdNormal
-- | 'stdNormal' as an 'RVarT' over an arbitrary base monad @m@.
stdNormalT :: Distribution Normal a => RVarT m a
stdNormalT = rvarT StdNormal
-- | @normal m s@ samples a normal distribution with mean @m@ and
-- standard deviation @s@.
normal :: Distribution Normal a => a -> a -> RVar a
normal m s = rvar (Normal m s)
-- | 'normal' as an 'RVarT' over an arbitrary base monad @m@: samples mean
-- @m@ and standard deviation @s@.
normalT :: Distribution Normal a => a -> a -> RVarT m a
normalT m s = rvarT (Normal m s)