module Language.Hakaru.Inference
( priorAsProposal
, mh
, mcmc
, gibbsProposal
, slice
, sliceX
, incompleteBeta
, regBeta
, tCDF
, approxMh
, kl
) where
import Prelude (($), (.), error, Maybe(..), return)
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing
import Language.Hakaru.Syntax.AST (Term)
import Language.Hakaru.Syntax.ABT (ABT, binder)
import Language.Hakaru.Syntax.Prelude
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Expect (expect, normalize)
import Language.Hakaru.Disintegrate (determine, density, disintegrate)
import qualified Data.Text as Text
-- | A proposal kernel over pairs that refreshes exactly one component from
-- the prior.
--
-- A fair coin is flipped and an independent draw @fresh@ is taken from the
-- prior. If the coin comes up true, the first component of the current state
-- is kept and the second is taken from @fresh@. Otherwise the first component
-- comes from @fresh@ and the second is kept.
priorAsProposal
    :: (ABT Term abt, SingI a, SingI b)
    => abt '[] ('HMeasure (HPair a b))  -- ^ prior over pairs
    -> abt '[] (HPair a b)              -- ^ current state
    -> abt '[] ('HMeasure (HPair a b))
priorAsProposal prior current =
    bern (prob_ 0.5) >>= \keepFirst ->
    prior >>= \fresh ->
    -- Plain Haskell bindings: they only name sub-terms and do not
    -- introduce object-language lets.
    let keptFirst  = pair (fst current) (snd fresh)
        keptSecond = pair (fst fresh)   (snd current)
    in  dirac (if_ keepFirst keptFirst keptSecond)
-- | Build a Metropolis(-Hastings) step from a proposal kernel and a target.
--
-- The result maps an old state to a measure over pairs @(new, ratio)@:
--
-- * @new@ is drawn from the proposal applied to the old state;
-- * @ratio@ is @density target new / density target old@.
--
-- The density of @target@ is computed symbolically and bound once with
-- 'let_', so both evaluations share the same term. Calls 'error' if no
-- density can be determined.
--
-- NOTE(review): the ratio uses only the target density. It has no proposal
-- density (Hastings) correction, so it appears to assume a symmetric
-- proposal. Confirm with callers.
mh :: (ABT Term abt)
   => abt '[] (a ':-> 'HMeasure a)               -- ^ proposal kernel
   -> abt '[] ('HMeasure a)                       -- ^ target measure
   -> abt '[] (a ':-> 'HMeasure (HPair a 'HProb))
mh proposal target =
    case determine (density target) of
        Just dens -> step dens
        Nothing   -> error "mh: couldn't get density"
  where
    -- Singleton for the state type, recovered from the target measure.
    stateTy = sUnMeasure $ typeOf target

    step dens =
        let_ dens $ \mu ->
        lamWithVar Text.empty stateTy $ \old ->
        app proposal old >>= \new ->
        dirac (pair_ stateTy SProb new (app mu new / app mu old))
-- | One MCMC transition built on top of 'mh'.
--
-- From the current state, the 'mh' kernel yields a candidate and its
-- acceptance ratio. The candidate is accepted with probability
-- @min 1 ratio@; otherwise the chain stays at the current state.
mcmc :: (ABT Term abt)
     => abt '[] (a ':-> 'HMeasure a)  -- ^ proposal kernel
     -> abt '[] ('HMeasure a)          -- ^ target measure
     -> abt '[] (a ':-> 'HMeasure a)
mcmc proposal target =
    let_ (mh proposal target) $ \kernel ->
    lamWithVar Text.empty stateTy $ \cur ->
    app kernel cur >>= \proposed ->
    unpair proposed $ \cand ratio ->
    bern (min (prob_ 1) ratio) >>= \accepted ->
    dirac $ if_ accepted cand cur
  where
    -- Singleton for the state type, recovered from the target measure.
    stateTy = sUnMeasure $ typeOf target
-- | A Gibbs-style proposal over pairs.
--
-- The joint measure @p@ is disintegrated with respect to its first
-- component. The first component of the current state is then kept, and the
-- second is resampled from the normalized conditional at that first
-- component. The old second component is ignored. Calls 'error' if the
-- disintegration cannot be determined.
gibbsProposal
    :: (ABT Term abt, SingI a, SingI b)
    => abt '[] ('HMeasure (HPair a b))  -- ^ joint measure
    -> abt '[] (HPair a b)              -- ^ current state
    -> abt '[] ('HMeasure (HPair a b))
gibbsProposal p xy =
    case determine (disintegrate p) of
        Just conditional ->
            unpair xy $ \x _ ->
            pair x <$> normalize (app conditional x)
        Nothing ->
            error "gibbsProposal: couldn't disintegrate"
-- | A slice-sampling kernel on the reals.
--
-- Given the current point @x@:
--
-- 1. draw a height @u@ uniformly from @[0, density target x]@;
-- 2. draw the next point from Lebesgue measure restricted to
--    @{x' | u < density target x'}@, normalized.
--
-- Calls 'error' if no density for @target@ can be determined.
slice
    :: (ABT Term abt)
    => abt '[] ('HMeasure 'HReal)          -- ^ target measure
    -> abt '[] ('HReal ':-> 'HMeasure 'HReal)
slice target =
    case determine (density target) of
        Nothing     -> error "slice: couldn't get density"
        Just densAt ->
            -- Haskell-level helper: the density at a point, as a real.
            let height y = fromProb (app densAt y)
            in  lam $ \x ->
                    uniform (real_ 0) (height x) >>= \u ->
                    normalize $
                        lebesgue >>= \x' ->
                        withGuard (u < height x') $
                        dirac x'
-- | The joint "slice" measure.
--
-- Each draw @x@ from @target@ is paired (via 'bindx') with a height drawn
-- uniformly from @[0, density target x]@. Calls 'error' if no density for
-- @target@ can be determined.
sliceX
    :: (ABT Term abt, SingI a)
    => abt '[] ('HMeasure a)                 -- ^ target measure
    -> abt '[] ('HMeasure (HPair a 'HReal))
sliceX target =
    case determine (density target) of
        Just densAt ->
            bindx target $ \x ->
            uniform (real_ 0) (fromProb (app densAt x))
        Nothing ->
            error "sliceX: couldn't get density"
-- | The (lower) incomplete beta function as a symbolic integral:
--
-- > B(x; a, b) = integral from 0 to x of t^(a-1) * (1-t)^(b-1) dt
--
-- NOTE(review): the factor @(1 - t)@ is passed through 'unsafeProb', so this
-- presumably expects @x <= 1@. Confirm with callers.
incompleteBeta
    :: (ABT Term abt)
    => abt '[] 'HProb  -- ^ upper integration limit @x@
    -> abt '[] 'HProb  -- ^ first shape parameter @a@
    -> abt '[] 'HProb  -- ^ second shape parameter @b@
    -> abt '[] 'HProb
incompleteBeta x a b =
    let one' = real_ 1 in
    integrate (real_ 0) (fromProb x) $ \t ->
        -- t^(a-1) * (1-t)^(b-1); the exponents are reals, hence fromProb.
        unsafeProb t ** (fromProb a - one')
        * unsafeProb (one' - t) ** (fromProb b - one')
-- | The regularized incomplete beta function:
-- @I_x(a, b) = B(x; a, b) / B(a, b)@.
regBeta
    :: (ABT Term abt)
    => abt '[] 'HProb  -- ^ upper integration limit @x@
    -> abt '[] 'HProb  -- ^ first shape parameter @a@
    -> abt '[] 'HProb  -- ^ second shape parameter @b@
    -> abt '[] 'HProb
regBeta x a b =
    let partialB  = incompleteBeta x a b
        completeB = betaFunc a b
    in  partialB / completeB
-- | CDF of Student's t-distribution with @v@ degrees of freedom, at @x@.
--
-- Uses the regularized incomplete beta identity. With
-- @L = 1/2 * I_{v/(v+x^2)}(v/2, 1/2)@:
--
-- * @F(x) = L@ when @x < 0@ (the lower tail);
-- * @F(x) = 1 - L@ when @x >= 0@.
--
-- The previous form @1 - L@ alone is symmetric in @x@, so it returned
-- upper-half values for negative inputs. The tail term is bound once with
-- 'let_' so both branches share it.
tCDF :: (ABT Term abt) => abt '[] 'HReal -> abt '[] 'HProb -> abt '[] 'HProb
tCDF x v =
    let_ (prob_ 0.5
            * regBeta (v / (unsafeProb (x*x) + v)) (v / prob_ 2) (prob_ 0.5))
        $ \lowerTail ->
    if_ (x < real_ 0)
        lowerTail
        (unsafeProb $ real_ 1 - fromProb lowerTail)
-- | An (incomplete) approximate Metropolis-Hastings kernel.
--
-- The joint density of @prior >>= proposal@ (via 'bindx') is used to form
-- the threshold @u0@ from a uniform draw @u@. A t-distribution test statistic
-- @delta@ is then computed. When @delta < eps@, the candidate is accepted or
-- rejected by comparing @u0@ with @l0@. Otherwise the decision is deferred to
-- the next recursion level, which consumes one more list element.
--
-- Only the length of the list argument is used; its elements are ignored.
-- Reaching the empty list calls 'error'.
--
-- NOTE(review): several pieces are visibly stubbed:
--
-- * @l@ is a constant placeholder;
-- * @udif@ applies 'unsafeProb' to a difference that may be negative;
-- * every finite list eventually reaches the 'error' clause when the else
--   branch is evaluated.
--
-- Treat this as work in progress.
approxMh
    :: (ABT Term abt, SingI a)
    => (abt '[] a -> abt '[] ('HMeasure a))  -- ^ proposal kernel
    -> abt '[] ('HMeasure a)                 -- ^ prior
    -> [abt '[] a -> abt '[] ('HMeasure a)]  -- ^ recursion budget (length only)
    -> abt '[] (a ':-> 'HMeasure a)
approxMh _ _ [] = error "TODO: approxMh for empty list"
approxMh proposal prior (_:xs) =
    case determine . density $ bindx prior proposal of
    Nothing -> error "approxMh: couldn't get density"
    Just theDensity ->
        lam $ \old ->
        let_ theDensity $ \mu ->
        unsafeProb <$> uniform (real_ 0) (real_ 1) >>= \u ->
        proposal old >>= \new ->
        let_ (u * mu `app` pair new old / mu `app` pair old new) $ \u0 ->
        let_ (l new new / l old old) $ \l0 ->
        -- n - 1 is the degrees-of-freedom style argument of the t test.
        let_ (tCDF (n - real_ 1) (udif l0 u0)) $ \delta ->
        if_ (delta < eps)
            (if_ (u0 < l0)
                (dirac new)
                (dirac old))
            (approxMh proposal prior xs `app` old)
  where
    n   = real_ 2000
    eps = prob_ 0.05
    -- Difference of two probabilities, forced back into HProb.
    udif lo hi = unsafeProb $ fromProb lo - fromProb hi
    -- Placeholder likelihood; always 2.
    l = \_d1 _d2 -> prob_ 2
-- | Kullback-Leibler divergence @KL(p || q)@ as an expectation under @p@ of
-- @log (density p / density q)@.
--
-- Returns 'Nothing' if a density for either @p@ or @q@ cannot be determined.
-- The density of @p@ is attempted first.
--
-- NOTE(review): 'expect' takes an HProb integrand, so the pointwise log
-- ratio, which can be negative, is passed through 'unsafeProb'. Verify that
-- this is acceptable for the intended uses.
kl :: (ABT Term abt)
   => abt '[] ('HMeasure a)  -- ^ @p@
   -> abt '[] ('HMeasure a)  -- ^ @q@
   -> Maybe (abt '[] 'HProb)
kl p q =
    case (determine (density p), determine (density q)) of
    (Just dp, Just dq) ->
        Just $ expect p $
            binder Text.empty (sUnMeasure $ typeOf p) $ \i ->
            unsafeProb $ log (app dp i / app dq i)
    _ -> Nothing