{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- |
-- Module      : Control.Monad.Bayes.Inference.SMC2
-- Description : Sequential Monte Carlo squared (SMC²)
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- Sequential Monte Carlo squared (SMC²) sampling.
--
-- Nicolas Chopin, Pierre E. Jacob, and Omiros Papaspiliopoulos. 2013. SMC²: an efficient algorithm for sequential analysis of state space models. /Journal of the Royal Statistical Society Series B: Statistical Methodology/ 75 (2013), 397-426. Issue 3. <https://doi.org/10.1111/j.1467-9868.2012.01046.x>
module Control.Monad.Bayes.Inference.SMC2
  ( smc2,
    SMC2,
  )
where

import Control.Monad.Bayes.Class
  ( MonadDistribution (random),
    MonadFactor (..),
    MonadMeasure,
  )
import Control.Monad.Bayes.Inference.MCMC
import Control.Monad.Bayes.Inference.RMSMC (rmsmc)
import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smcPush)
import Control.Monad.Bayes.Population as Pop (Population, population, resampleMultinomial)
import Control.Monad.Bayes.Sequential.Coroutine (Sequential)
import Control.Monad.Bayes.Traced
import Control.Monad.Trans (MonadTrans (..))
import Numeric.Log (Log)

-- | Helper monad transformer for preprocessing the model for 'smc2'.
--
-- A thin wrapper around the stack @'Sequential' ('Traced' ('Population' m))@,
-- which is the stack the outer 'rmsmc' run in 'smc2' operates on. The inner
-- SMC ('smcPush') in 'smc2' is run with @'SMC2' m@ as its base monad.
--
-- The 'Functor', 'Applicative' and 'Monad' instances are reused from the
-- wrapped stack via newtype deriving. They require @'Monad' m@ on the base
-- monad.
newtype SMC2 m a = SMC2 (Sequential (Traced (Population m)) a)
  deriving newtype (Functor, Applicative, Monad)

-- | Unwrap an 'SMC2' computation into the underlying
-- @'Sequential' ('Traced' ('Population' m))@ stack consumed by 'rmsmc'.
setup :: SMC2 m a -> Sequential (Traced (Population m)) a
setup (SMC2 inner) = inner

-- | Lift a base-monad action through every layer of the wrapped stack:
-- first into 'Population', then 'Traced', then 'Sequential', and finally
-- wrap it in 'SMC2'.
instance MonadTrans SMC2 where
  lift = SMC2 . lift . lift . lift

-- | Random draws are delegated to the base monad @m@.
instance MonadDistribution m => MonadDistribution (SMC2 m) where
  -- The draw is lifted from @m@ through 'Population', 'Traced' and
  -- 'Sequential' rather than sampled in any of those layers. Presumably this
  -- keeps it out of the 'Traced' trace (and thus out of the outer MH moves).
  -- TODO confirm against 'Traced''s 'lift'.
  random = lift random

-- | Factors are forwarded to the top ('Sequential') layer of the wrapped
-- stack, unlike 'random', which goes straight to the base monad.
instance Monad m => MonadFactor (SMC2 m) where
  score = SMC2 . score

-- | 'SMC2' is a full probabilistic monad whenever its base monad can sample.
-- No methods are needed: 'MonadMeasure' just combines the 'MonadDistribution'
-- and 'MonadFactor' instances above.
instance (MonadDistribution m) => MonadMeasure (SMC2 m) where

-- | Sequential Monte Carlo squared.
--
-- Runs an outer resample-move SMC ('rmsmc') over the model parameters. Each
-- outer particle runs its own inner SMC ('smcPush') over the model
-- conditioned on its parameter value. Both levels use multinomial resampling
-- and the same number of time steps.
--
-- Each element of the returned outer population is the inner population
-- flattened by 'population', i.e. a list of @(sample, weight)@ pairs.
smc2 ::
  MonadDistribution m =>
  -- | number of time steps (shared by the inner and outer SMC)
  Int ->
  -- | number of inner particles
  Int ->
  -- | number of outer particles
  Int ->
  -- | number of MH transitions ('numMCMCSteps' of the outer 'rmsmc')
  Int ->
  -- | model parameters, run in the traced outer stack
  Sequential (Traced (Population m)) b ->
  -- | model, given a parameter value
  (b -> Sequential (Population (SMC2 m)) a) ->
  Population m [(a, Log Double)]
smc2 k n p t param m =
  rmsmc
    MCMCConfig {numMCMCSteps = t, proposal = SingleSiteMH, numBurnIn = 0}
    SMCConfig {numParticles = p, numSteps = k, resampler = resampleMultinomial}
    -- For each sampled parameter: run the inner SMC in 'SMC2' m, collect its
    -- weighted particles, and unwrap back into the outer stack.
    (param >>= setup . population . smcPush innerConfig . m)
  where
    -- Inner particle filter: @n@ particles over the same @k@ steps.
    innerConfig =
      SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}