{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}

module Control.Monad.Bayes.Inference.Lazy.MH where

import Control.Monad.Bayes.Class (Log (ln))
import Control.Monad.Bayes.Sampler.Lazy
  ( Sampler,
    Tree (..),
    Trees (..),
    randomTree,
    runSampler,
  )
import Control.Monad.Bayes.Weighted (WeightedT, runWeightedT)
import Control.Monad.Extra (iterateM)
import Control.Monad.State.Lazy (MonadState (get, put), runState)
import System.Random (RandomGen (split), getStdGen, newStdGen)
import System.Random qualified as R

-- | Lazy Metropolis–Hastings over the tree of uniforms that drives a lazy
-- 'Sampler'.
--
-- @mh p m@ returns a lazily produced stream of @(result, weight)@ pairs.
-- Each step:
--
--   1. perturbs the current tree with 'mutateTree', which resamples each
--      visited uniform independently with probability @p@;
--   2. re-runs @m@ on the perturbed tree to obtain a new result and weight;
--   3. accepts the proposal with probability @min 1 (w' / w)@, otherwise
--      keeps the previous state.
--
-- The stream is produced by iterating the step function forever, so
-- consume it with 'take' (or similar) rather than forcing it whole.
mh :: forall a. Double -> WeightedT Sampler a -> IO [(a, Log Double)]
mh p m = do
  -- Obtain a fresh generator. 'newStdGen' splits the global generator,
  -- stores one half back as the global state and returns the other half.
  -- We must use the returned half: the previous @newStdGen >> getStdGen@
  -- returned the generator left in the global slot, so this run's
  -- randomness was shared with (and correlated to) any later user of the
  -- global generator, including a subsequent call to 'mh'.
  g <- newStdGen
  -- Split in two: one half seeds the initial tree (the starting point in
  -- "state space"), the other drives the MH proposals and accept/reject
  -- decisions.
  let (g1, g2) = split g
      t = randomTree g1
      (x, w) = runSampler (runWeightedT m) t
      -- Run 'step' over and over to get a stream of (tree, result, weight)s.
      -- This relies on the lazy State monad, so the stream is available
      -- incrementally even though the iteration never terminates.
      (samples, _) = runState (iterateM step (t, x, w)) g2
  -- Drop the trees, keeping only the result/weight pairs.
  return [(x', w') | (_, x', w') <- samples]
  where
    {- NB There are three kinds of randomness in the step function.
    1. The start tree 't', which is the source of randomness for simulating
       the program m to start with. This is, roughly, the point in the
       "state space".
    2. The randomness needed to propose a new tree ('g1').
    3. The randomness needed to decide whether to accept or reject that
       proposal ('g2').
    The tree t is both an argument and a result, but the other randomness
    (g, g1, g2) is threaded through a state monad ('get'/'put'). -}

    -- Inferred type:
    -- step :: (MonadState s n, RandomGen s) =>
    --         (Tree, a, Log Double) -> n (Tree, a, Log Double)
    step (t, x, w) = do
      g <- get
      let (g1, g2) = split g
          -- Propose: randomly change some sites of the current tree.
          t' = mutateTree p g1 t
          -- Rerun the model on the proposed tree to get a new weight w'.
          (x', w') = runSampler (runWeightedT m) t'
          -- MH acceptance ratio, computed in log space by 'Log'.
          ratio = w' / w
          (r, g2') = R.random g2
      put g2'
      -- Accept with probability min 1 (w' / w); otherwise keep the old state.
      if r < min 1 (exp $ ln ratio)
        then return (t', x', w')
        else return (t, x, w)

-- | Replace the labels of a tree randomly, with probability @p@.
--
-- For the root node, two 'R.random' 'Double's are drawn from @g@:
--
--   * a "coin" compared against @p@;
--   * a candidate replacement label.
--
-- If the coin is below @p@ the candidate becomes the node's
-- 'currentUniform', otherwise the old label is kept. The candidate is drawn
-- whether or not it is used. The generator left after both draws is handed
-- to 'mutateTrees' for the children.
--
-- 'mutateTree' and 'mutateTrees' recurse into each other with no base case,
-- so this relies on laziness: only the parts of the tree that are
-- subsequently inspected get mutated.
mutateTree :: forall g. (RandomGen g) => Double -> g -> Tree -> Tree
mutateTree p g (Tree a ts) =
  let (a', g') = R.random g :: (Double, g)
      (a'', g'') = R.random g'
   in Tree
        { currentUniform = if a' < p then a'' else a,
          lazyUniforms = mutateTrees p g'' ts
        }

-- | Apply 'mutateTree' (resampling probability @p@) to every tree in a
-- 'Trees' spine.
--
-- The generator is 'split' so that the head tree and the remaining spine
-- each receive their own generator. The recursion has no base case and
-- relies on laziness.
mutateTrees :: (RandomGen g) => Double -> g -> Trees -> Trees
mutateTrees p g (Trees t ts) =
  let (g1, g2) = split g
   in Trees
        { headTree = mutateTree p g1 t,
          tailTrees = mutateTrees p g2 ts
        }