{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}
module Control.Monad.Bayes.Inference.Lazy.MH where
import Control.Monad.Bayes.Class (Log (ln))
import Control.Monad.Bayes.Sampler.Lazy
( Sampler,
Tree (..),
Trees (..),
randomTree,
runSampler,
)
import Control.Monad.Bayes.Weighted (WeightedT, runWeightedT)
import Control.Monad.Extra (iterateM)
import Control.Monad.State.Lazy (MonadState (get, put), runState)
import System.Random (RandomGen (split), getStdGen, newStdGen)
import System.Random qualified as R
-- | Metropolis-Hastings over the 'Tree' of uniforms that drives a 'Sampler'.
--
-- * @p@ is forwarded unchanged to 'mutateTree' for every proposal; there it
--   is the threshold below which a node's uniform gets redrawn.
-- * @m@ is the weighted program. It is re-run from scratch (via
--   'runWeightedT' and 'runSampler') against every proposed tree.
--
-- The result is the chain of @(value, weight)@ pairs visited by the sampler,
-- starting with the initial state. The chain is built with 'iterateM' inside
-- a lazy 'Control.Monad.State.Lazy.State' computation.
-- NOTE(review): this presumably yields an unbounded, lazily produced list
-- that callers are expected to truncate (e.g. with 'take') - confirm against
-- the definition of 'iterateM'.
mh :: forall a. Double -> WeightedT Sampler a -> IO [(a, Log Double)]
mh p m = do
  -- Advance the global generator first, then read it, so that successive
  -- calls do not start from the same seed.
  g <- newStdGen >> getStdGen
  let -- One half seeds the initial tree, the other half feeds the
      -- propose/accept randomness threaded through 'step'.
      (treeGen, chainGen) = split g
      tree0 = randomTree treeGen
      (x0, w0) = runSampler (runWeightedT m) tree0
      (samples, _) = runState (iterateM step (tree0, x0, w0)) chainGen
  -- The trees are only needed to drive the chain; callers get value/weight.
  return [(x, w) | (_, x, w) <- samples]
  where
    -- One MH transition. Inferred type:
    --   (MonadState s n, RandomGen s)
    --     => (Tree, a, Log Double) -> n (Tree, a, Log Double)
    -- The generator lives in the state; the current tree, value and weight
    -- are passed in and out explicitly.
    step (tree, x, w) = do
      g <- get
      let -- 'proposeGen' builds the proposal, 'acceptGen' decides on it.
          (proposeGen, acceptGen) = split g
          tree' = mutateTree p proposeGen tree
          (x', w') = runSampler (runWeightedT m) tree'
          -- Weight ratio; the division is carried out on 'Log Double'.
          ratio = w' / w
          (r, acceptGen') = R.random acceptGen
      put acceptGen'
      -- Accept with probability min(1, w'/w); otherwise stay where we are.
      if r < min 1 (exp (ln ratio))
        then return (tree', x', w')
        else return (tree, x, w)
-- | Propose a new 'Tree' by independently perturbing its uniforms.
--
-- Two 'Double's are drawn from the generator @g@. If the first is below @p@,
-- the node's 'currentUniform' is replaced by the second; otherwise the
-- existing value is kept. The generator left over after both draws is handed
-- to 'mutateTrees' to process the node's 'lazyUniforms' in the same way.
--
-- Both draws are taken unconditionally, so the generator passed on to the
-- subtrees does not depend on whether this node was changed.
mutateTree :: forall g. (RandomGen g) => Double -> g -> Tree -> Tree
mutateTree p g (Tree a ts) =
  Tree
    { currentUniform = if coin < p then fresh else a,
      lazyUniforms = mutateTrees p gRest ts
    }
  where
    -- The annotation pins the draw to 'Double'; it relies on
    -- ScopedTypeVariables to refer to the outer @g@.
    (coin, gNext) = R.random g :: (Double, g)
    (fresh, gRest) = R.random gNext
-- | Apply 'mutateTree' to every tree in a 'Trees' sequence.
--
-- The generator is split at each step: one half mutates the 'headTree', the
-- other is carried into the recursive call on 'tailTrees', so each element
-- gets its own generator.
--
-- There is no terminating equation; the recursion is only productive because
-- the result fields are built lazily.
-- NOTE(review): this assumes 'Trees' has the single constructor matched here
-- with non-strict fields - confirm against its declaration.
mutateTrees :: (RandomGen g) => Double -> g -> Trees -> Trees
mutateTrees p g (Trees t ts) =
  Trees
    { headTree = mutateTree p gHead t,
      tailTrees = mutateTrees p gTail ts
    }
  where
    (gHead, gTail) = split g