{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}

-- |
-- Module      : Control.Monad.Bayes.Inference.MCMC
-- Description : Markov Chain Monte Carlo (MCMC)
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : tweag.io
-- Stability   : experimental
-- Portability : GHC
module Control.Monad.Bayes.Inference.MCMC where

import Control.Monad.Bayes.Class (MonadDistribution)
import Control.Monad.Bayes.Traced.Basic qualified as Basic
import Control.Monad.Bayes.Traced.Common
  ( MHResult (MHResult, trace),
    Trace (probDensity),
    burnIn,
    mhTransWithBool,
  )
import Control.Monad.Bayes.Traced.Dynamic qualified as Dynamic
import Control.Monad.Bayes.Traced.Static qualified as Static
import Control.Monad.Bayes.Weighted (WeightedT, unweighted)
import Pipes ((>->))
import Pipes qualified as P
import Pipes.Prelude qualified as P

-- | The proposal strategy recorded in an 'MCMCConfig'.
--
-- 'SingleSiteMH' is the only constructor declared here.
data Proposal
  = SingleSiteMH

-- | Settings shared by the MCMC entry points in this module.
data MCMCConfig = MCMCConfig
  { -- | Proposal strategy. The functions in this module bind this field
    -- through the record wildcard but do not inspect it.
    proposal :: Proposal,
    -- | Step count passed as the first argument to the underlying @mh@
    -- function by 'mcmc', 'mcmcBasic' and 'mcmcDynamic'.
    numMCMCSteps :: Int,
    -- | Burn-in count passed to 'burnIn' (list-returning entry points) or
    -- to @Pipes.Prelude.drop@ (in 'mcmcP').
    numBurnIn :: Int
  }

-- | Default configuration: 'SingleSiteMH' proposal, a single MCMC step
-- and no burn-in.
defaultMCMCConfig :: MCMCConfig
defaultMCMCConfig =
  MCMCConfig
    { proposal = SingleSiteMH,
      numMCMCSteps = 1,
      numBurnIn = 0
    }

-- | Run Metropolis-Hastings on a model traced with the "Static"
-- representation and collect the results as a list.
--
-- The pipeline is, from the inside out:
--
-- 1. @Static.mh numMCMCSteps@ runs the chain inside @WeightedT m@;
-- 2. 'unweighted' removes the @WeightedT@ layer;
-- 3. @burnIn numBurnIn@ post-processes the resulting list
--    (see 'burnIn' for exactly which elements are discarded).
--
-- The 'proposal' field of the configuration is not consulted.
mcmc :: (MonadDistribution m) => MCMCConfig -> Static.TracedT (WeightedT m) a -> m [a]
mcmc (MCMCConfig {..}) m = burnIn numBurnIn $ unweighted $ Static.mh numMCMCSteps m

-- | Run Metropolis-Hastings on a model traced with the "Basic"
-- representation and collect the results as a list.
--
-- The pipeline is, from the inside out:
--
-- 1. @Basic.mh numMCMCSteps@ runs the chain inside @WeightedT m@;
-- 2. 'unweighted' removes the @WeightedT@ layer;
-- 3. @burnIn numBurnIn@ post-processes the resulting list
--    (see 'burnIn' for exactly which elements are discarded).
--
-- The 'proposal' field of the configuration is not consulted.
mcmcBasic :: (MonadDistribution m) => MCMCConfig -> Basic.TracedT (WeightedT m) a -> m [a]
mcmcBasic (MCMCConfig {..}) m = burnIn numBurnIn $ unweighted $ Basic.mh numMCMCSteps m

-- | Run Metropolis-Hastings on a model traced with the "Dynamic"
-- representation and collect the results as a list.
--
-- The pipeline is, from the inside out:
--
-- 1. @Dynamic.mh numMCMCSteps@ runs the chain inside @WeightedT m@;
-- 2. 'unweighted' removes the @WeightedT@ layer;
-- 3. @burnIn numBurnIn@ post-processes the resulting list
--    (see 'burnIn' for exactly which elements are discarded).
--
-- The 'proposal' field of the configuration is not consulted.
mcmcDynamic :: (MonadDistribution m) => MCMCConfig -> Dynamic.TracedT (WeightedT m) a -> m [a]
mcmcDynamic (MCMCConfig {..}) m = burnIn numBurnIn $ unweighted $ Dynamic.mh numMCMCSteps m

-- | Draw independent traces until one has non-zero probability density.
--
-- The trace-producing action stored in the 'Static.TracedT' is run
-- repeatedly; the weighted-density component is ignored. Every trace whose
-- 'probDensity' equals @0@ is emitted downstream wrapped as
-- @'MHResult' False@. The first trace whose density is not @0@ is not
-- emitted: it terminates the producer and becomes its return value.
independentSamples :: (Monad m) => Static.TracedT m a -> P.Producer (MHResult a) m (Trace a)
independentSamples (Static.TracedT _w d) =
  P.repeatM d
    -- takeWhile' returns the first element that fails the predicate,
    -- which is how the first non-zero-density trace reaches the caller.
    >-> P.takeWhile' ((== 0) . probDensity)
    >-> P.map (MHResult False)

-- | Convert a probabilistic program into a producer of samples.
--
-- An initial trace is obtained from 'independentSamples'; the zero-density
-- results it emits along the way are discarded with @P.drain@. Starting
-- from that trace, the chain is unfolded by repeatedly applying
-- 'mhTransWithBool' to the current trace: each resulting 'MHResult' is
-- yielded and its 'trace' becomes the state for the next step. The first
-- 'numBurnIn' results are dropped.
--
-- The unfold step never signals termination, so the producer does not end
-- on its own; the consumer decides how many samples to take. The
-- 'numMCMCSteps' and 'proposal' fields are not consulted here.
mcmcP :: (MonadDistribution m) => MCMCConfig -> Static.TracedT m a -> P.Producer (MHResult a) m ()
mcmcP MCMCConfig {..} m@(Static.TracedT w _) = do
  initialValue <- independentSamples m >-> P.drain
  P.unfoldr step initialValue >-> P.drop numBurnIn
  where
    -- One chain transition: emit the result and carry its trace forward.
    step current = Right . (\k -> (k, trace k)) <$> mhTransWithBool w current