{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}

{- | A GADT encoding of (a selection of) primitive distributions
    along with their corresponding sampling and density functions.
-}

module PrimDist (
  -- * Primitive distribution
    PrimDist(..)
  , PrimVal
  , IsPrimVal(..)
  , pattern PrimDistPrf
  , ErasedPrimDist(..)
  -- * Sampling
  , sample
  -- * Density
  , prob
  , logProb) where

import Data.Kind ( Constraint )
import Numeric.Log ( Log(..) )
import qualified Data.Map as Map
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as UV
import qualified OpenSum
import qualified System.Random.MWC.Distributions as MWC
import Statistics.Distribution ( ContDistr(density), DiscreteDistr(probability) )
import Statistics.Distribution.Beta ( betaDistr )
import Statistics.Distribution.Binomial ( binomial )
import Statistics.Distribution.CauchyLorentz ( cauchyDistribution )
import Statistics.Distribution.Dirichlet ( dirichletDensity, dirichletDistribution )
import Statistics.Distribution.DiscreteUniform ( discreteUniformAB )
import Statistics.Distribution.Gamma ( gammaDistr )
import Statistics.Distribution.Normal ( normalDistr )
import Statistics.Distribution.Poisson ( poisson )
import Statistics.Distribution.Uniform ( uniformDistr )
import Sampler
import Util ( boolToInt )

-- | A primitive distribution, indexed by the type of value it generates.
data PrimDist a where
  -- | A single Bernoulli trial.
  BernoulliDist ::
       Double          -- ^ probability of @True@
    -> PrimDist Bool
  -- | Beta distribution.
  BetaDist ::
       Double          -- ^ shape α
    -> Double          -- ^ shape β
    -> PrimDist Double
  -- | Binomial distribution, generating the number of successful trials.
  BinomialDist ::
       Int             -- ^ number of trials
    -> Double          -- ^ probability of successful trial
    -> PrimDist Int
  -- | Distribution over an explicit list of values.
  CategoricalDist ::
       (Eq a, Show a, OpenSum.Member a PrimVal)
    => [(a, Double)]   -- ^ values and associated probabilities
    -> PrimDist a
  -- | Cauchy distribution.
  CauchyDist ::
       Double          -- ^ location
    -> Double          -- ^ scale
    -> PrimDist Double
  -- | Cauchy distribution centred at @0@, restricted to non-negative values.
  HalfCauchyDist ::
       Double          -- ^ scale
    -> PrimDist Double
  -- | Point mass on a single value.
  DeterministicDist ::
       (Eq a, Show a, OpenSum.Member a PrimVal)
    => a               -- ^ value of probability @1@
    -> PrimDist a
  -- | Dirichlet distribution.
  DirichletDist ::
       [Double]        -- ^ concentrations
    -> PrimDist [Double]
  -- | Distribution over list indices, weighted by the given probabilities.
  DiscreteDist ::
       [Double]        -- ^ list of @n@ probabilities
    -> PrimDist Int    -- ^ an index from @0@ to @n - 1@
  -- | Uniform distribution over an integer range.
  DiscrUniformDist ::
       Int             -- ^ lower-bound @a@
    -> Int             -- ^ upper-bound @b@
    -> PrimDist Int
  -- | Gamma distribution.
  GammaDist ::
       Double          -- ^ shape k
    -> Double          -- ^ scale θ
    -> PrimDist Double
  -- | Normal distribution.
  NormalDist ::
       Double          -- ^ mean
    -> Double          -- ^ standard deviation
    -> PrimDist Double
  -- | Normal distribution centred at @0@, restricted to non-negative values.
  HalfNormalDist ::
       Double          -- ^ standard deviation
    -> PrimDist Double
  -- | Poisson distribution.
  PoissonDist ::
       Double          -- ^ rate λ
    -> PrimDist Int
  -- | Continuous uniform distribution.
  UniformDist ::
       Double          -- ^ lower-bound @a@
    -> Double          -- ^ upper-bound @b@
    -> PrimDist Double

-- | Structural equality: two distributions are equal when they use the same
-- constructor and all of their parameters are equal. Distributions built
-- from different constructors are never equal.
instance Eq (PrimDist a) where
  BernoulliDist p          == BernoulliDist p'           = p == p'
  BetaDist alpha beta      == BetaDist alpha' beta'      = alpha == alpha' && beta == beta'
  BinomialDist n p         == BinomialDist n' p'         = n == n' && p == p'
  CategoricalDist entries  == CategoricalDist entries'   = entries == entries'
  CauchyDist mu sigma      == CauchyDist mu' sigma'      = mu == mu' && sigma == sigma'
  HalfCauchyDist sigma     == HalfCauchyDist sigma'      = sigma == sigma'
  DeterministicDist x      == DeterministicDist x'       = x == x'
  DirichletDist alphas     == DirichletDist alphas'      = alphas == alphas'
  DiscreteDist ps          == DiscreteDist ps'           = ps == ps'
  DiscrUniformDist lo hi   == DiscrUniformDist lo' hi'   = lo == lo' && hi == hi'
  GammaDist k theta        == GammaDist k' theta'        = k == k' && theta == theta'
  NormalDist mu sigma      == NormalDist mu' sigma'      = mu == mu' && sigma == sigma'
  HalfNormalDist sigma     == HalfNormalDist sigma'      = sigma == sigma'
  PoissonDist lambda       == PoissonDist lambda'        = lambda == lambda'
  UniformDist lo hi        == UniformDist lo' hi'        = lo == lo' && hi == hi'
  _                        == _                          = False

-- | Renders a distribution as its constructor name followed by its
-- parameters in parentheses, separated by @", "@, e.g. @NormalDist(0.0, 1.0)@.
instance Show a => Show (PrimDist a) where
  show dist =
    case dist of
      CauchyDist mu sigma     -> render "CauchyDist" [show mu, show sigma]
      HalfCauchyDist sigma    -> render "HalfCauchyDist" [show sigma]
      NormalDist mu sigma     -> render "NormalDist" [show mu, show sigma]
      HalfNormalDist sigma    -> render "HalfNormalDist" [show sigma]
      BernoulliDist p         -> render "BernoulliDist" [show p]
      BinomialDist n p        -> render "BinomialDist" [show n, show p]
      DiscreteDist ps         -> render "DiscreteDist" [show ps]
      BetaDist alpha beta     -> render "BetaDist" [show alpha, show beta]
      GammaDist k theta       -> render "GammaDist" [show k, show theta]
      UniformDist lo hi       -> render "UniformDist" [show lo, show hi]
      DiscrUniformDist lo hi  -> render "DiscrUniformDist" [show lo, show hi]
      PoissonDist lambda      -> render "PoissonDist" [show lambda]
      CategoricalDist entries -> render "CategoricalDist" [show entries]
      DirichletDist alphas    -> render "DirichletDist" [show alphas]
      DeterministicDist x     -> render "DeterministicDist" [show x]
    where
      -- Joins the already-shown parameters with ", " (no trailing separator)
      -- and wraps them as @Name(...)@. Dropping the leading ", " produced by
      -- 'concatMap' also handles an empty parameter list.
      render :: String -> [String] -> String
      render name args = name ++ "(" ++ drop 2 (concatMap (", " ++) args) ++ ")"

-- | An ad-hoc specification of primitive value types, used (via
-- @OpenSum.Member x PrimVal@) to constrain the outputs of distributions.
--
-- NOTE(review): the element order is kept unchanged in case "OpenSum" derives
-- positional information from it; confirm before reordering.
type PrimVal =
  '[ Int
   , Double
   , [Double]
   , Bool
   , String
   ]

-- | Evidence that @x@ is a primitive value: it can be shown and is a member of
-- 'PrimVal'. Pattern-matching on the 'IsPrimVal' constructor brings both
-- constraints into scope.
data IsPrimVal x where
  IsPrimVal :: forall x. (Show x, OpenSum.Member x PrimVal) => IsPrimVal x

-- | For pattern-matching on an arbitrary @PrimDist@ with proof that it
-- generates a primitive value.
--
-- The match never fails. It binds the whole distribution to @d@, and by
-- matching the 'IsPrimVal' evidence returned by 'primDistPrf' it brings
-- @Show x@ and @OpenSum.Member x PrimVal@ into scope.
pattern PrimDistPrf :: () => (Show x, OpenSum.Member x PrimVal) => PrimDist x -> PrimDist x
pattern PrimDistPrf d <- d@(primDistPrf -> IsPrimVal)

-- 'primDistPrf' covers every constructor, so this single pattern is
-- exhaustive; the pragma stops callers getting incomplete-pattern warnings.
{-# COMPLETE PrimDistPrf #-}

-- | Proof that all primitive distributions generate a primitive value.
--
-- For constructors with a fixed result type, the required 'Show' and
-- @OpenSum.Member@ instances are resolved from that type. 'CategoricalDist'
-- and 'DeterministicDist' carry them as constructor constraints.
primDistPrf :: PrimDist x -> IsPrimVal x
primDistPrf BernoulliDist{}     = IsPrimVal
primDistPrf BetaDist{}          = IsPrimVal
primDistPrf BinomialDist{}      = IsPrimVal
primDistPrf CategoricalDist{}   = IsPrimVal
primDistPrf CauchyDist{}        = IsPrimVal
primDistPrf HalfCauchyDist{}    = IsPrimVal
primDistPrf DeterministicDist{} = IsPrimVal
primDistPrf DirichletDist{}     = IsPrimVal
primDistPrf DiscreteDist{}      = IsPrimVal
primDistPrf DiscrUniformDist{}  = IsPrimVal
primDistPrf GammaDist{}         = IsPrimVal
primDistPrf NormalDist{}        = IsPrimVal
primDistPrf HalfNormalDist{}    = IsPrimVal
primDistPrf PoissonDist{}       = IsPrimVal
primDistPrf UniformDist{}       = IsPrimVal

-- | Existential wrapper that hides the result type of a 'PrimDist', keeping
-- only a 'Show' dictionary for that type.
data ErasedPrimDist where
  ErasedPrimDist :: Show a => PrimDist a -> ErasedPrimDist

-- | Shows the wrapped distribution exactly as its own 'Show' instance would.
instance Show ErasedPrimDist where
  show erased =
    case erased of
      ErasedPrimDist inner -> show inner

-- | Draw a value from a primitive distribution in the @Sampler@ monad.
--
-- The half distributions draw from the corresponding distribution centred at
-- @0@ and take the absolute value. 'BinomialDist' counts the successful
-- trials returned by @sampleBinomial@. 'CategoricalDist' samples an index
-- from the weights and returns the value at that index.
sample ::
     PrimDist a
  -> Sampler a
sample dist =
  case dist of
    HalfCauchyDist sigma    -> abs <$> createSampler (sampleCauchy 0 sigma)
    CauchyDist mu sigma     -> createSampler (sampleCauchy mu sigma)
    HalfNormalDist sigma    -> abs <$> createSampler (sampleNormal 0 sigma)
    NormalDist mu sigma     -> createSampler (sampleNormal mu sigma)
    UniformDist lo hi       -> createSampler (sampleUniform lo hi)
    DiscrUniformDist lo hi  -> createSampler (sampleDiscreteUniform lo hi)
    GammaDist k theta       -> createSampler (sampleGamma k theta)
    BetaDist alpha beta     -> createSampler (sampleBeta alpha beta)
    BinomialDist n p        ->
      length . filter id <$> createSampler (sampleBinomial n p)
    BernoulliDist p         -> createSampler (sampleBernoulli p)
    CategoricalDist entries -> do
      idx <- createSampler (sampleCategorical (V.fromList (map snd entries)))
      pure (fst (entries !! idx))
    DiscreteDist ps         -> createSampler (sampleDiscrete ps)
    PoissonDist lambda      -> createSampler (samplePoisson lambda)
    DirichletDist alphas    -> createSampler (sampleDirichlet alphas)
    DeterministicDist x     -> pure x

-- | Compute the density of a primitive distribution generating an observed
-- value. For discrete distributions this is the probability mass of the
-- observation.
--
-- Some observations outside a distribution's support yield @0@:
--
-- * a negative value for the half distributions;
-- * an out-of-range index for 'DiscreteDist';
-- * any value other than the fixed one for 'DeterministicDist'.
--
-- Two cases raise an error instead: a 'CategoricalDist' observation that is
-- not among the listed values, and a 'DirichletDist' whose concentrations
-- are rejected by @dirichletDistribution@.
prob ::
  -- | distribution
     PrimDist a
  -- | observed value
  -> a
  -- | density
  -> Double
prob (DirichletDist alphas) ys =
  -- The concentrations are deliberately not normalised. Normalising them
  -- would describe a different Dirichlet distribution from the one 'sample'
  -- draws from (@sampleDirichlet alphas@).
  case dirichletDistribution (UV.fromList alphas) of
    Left err -> error ("dirichlet error: " ++ err)
    Right d  ->
      -- dirichletDensity returns a 'Log' value; 'Exp' holds its logarithm.
      let Exp logDensity = dirichletDensity d (UV.fromList ys)
      in  exp logDensity
prob (HalfCauchyDist sigma) y
  | y < 0     = 0
  | otherwise = 2 * density (cauchyDistribution 0 sigma) y
prob (CauchyDist mu sigma) y =
  density (cauchyDistribution mu sigma) y
prob (HalfNormalDist sigma) y
  | y < 0     = 0
  | otherwise = 2 * density (normalDistr 0 sigma) y
prob (NormalDist mu sigma) y =
  density (normalDistr mu sigma) y
prob (UniformDist lo hi) y =
  density (uniformDistr lo hi) y
prob (GammaDist k theta) y =
  density (gammaDistr k theta) y
prob (BetaDist alpha beta) y =
  density (betaDistr alpha beta) y
prob (DiscrUniformDist lo hi) y =
  probability (discreteUniformAB lo hi) y
prob (BinomialDist n p) y =
  probability (binomial n p) y
prob (BernoulliDist p) b =
  -- Bernoulli(p) is Binomial(1, p); boolToInt is assumed to map True to 1.
  probability (binomial 1 p) (boolToInt b)
prob (CategoricalDist entries) y =
  case lookup y entries of
    Nothing -> error $ "Couldn't find " ++ show y ++ " in categorical dist"
    Just p  -> p
prob (DiscreteDist ps) idx
  -- Indices outside 0 .. length ps - 1 lie outside the support.
  | idx < 0   = 0
  | otherwise =
      case drop idx ps of
        p : _ -> p
        []    -> 0
prob (PoissonDist lambda) y =
  probability (poisson lambda) y
prob (DeterministicDist x) y
  -- A point mass: all probability sits on x itself.
  | y == x    = 1
  | otherwise = 0

-- | Compute the log density of a primitive distribution generating an
-- observed value, as the natural logarithm of 'prob'. An observation with
-- zero density therefore gives negative infinity.
logProb ::
  -- | distribution
     PrimDist a
  -- | observed value
  -> a
  -- | log density
  -> Double
logProb dist observed = log (prob dist observed)