{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- |
-- Module      : Control.Monad.Bayes.Inference.SMC2
-- Description : Sequential Monte Carlo squared (SMC²)
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- Sequential Monte Carlo squared (SMC²) sampling.
--
-- Nicolas Chopin, Pierre E. Jacob, and Omiros Papaspiliopoulos. 2013. SMC²: an efficient algorithm for sequential analysis of state space models. /Journal of the Royal Statistical Society Series B: Statistical Methodology/ 75 (2013), 397-426. Issue 3. <https://doi.org/10.1111/j.1467-9868.2012.01046.x>
module Control.Monad.Bayes.Inference.SMC2
  ( smc2,
    SMC2,
  )
where

import Control.Monad.Bayes.Class
  ( MonadDistribution (random),
    MonadFactor (..),
    MonadMeasure,
  )
import Control.Monad.Bayes.Inference.MCMC
import Control.Monad.Bayes.Inference.RMSMC (rmsmc)
import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smcPush)
import Control.Monad.Bayes.Population as Pop (PopulationT, resampleMultinomial, runPopulationT)
import Control.Monad.Bayes.Sequential.Coroutine (SequentialT)
import Control.Monad.Bayes.Traced
import Control.Monad.Trans (MonadTrans (..))
import Numeric.Log (Log)

-- | Helper monad transformer for preprocessing the model for 'smc2'.
--
-- It wraps the stack that the outer level of 'smc2' works in,
-- @'SequentialT' ('TracedT' ('PopulationT' m))@. The model passed to 'smc2'
-- runs its inner particle system in @'PopulationT' ('SMC2' m)@.
newtype SMC2 m a = SMC2 (SequentialT (TracedT (PopulationT m)) a)
  -- The monad structure is inherited unchanged from the wrapped stack.
  deriving newtype (Functor, Applicative, Monad)

-- | Unwrap an 'SMC2' computation into the transformer stack it is built on.
setup :: SMC2 m a -> SequentialT (TracedT (PopulationT m)) a
setup (SMC2 inner) = inner

-- | Lift a base-monad action through every layer of the wrapped stack.
--
-- The layers are entered innermost first: 'PopulationT', then 'TracedT',
-- then 'SequentialT'. The result is then wrapped in 'SMC2'.
instance MonadTrans SMC2 where
  lift = SMC2 . lift . lift . lift

-- | Random draws come from the base monad @m@ and are lifted through the
-- whole stack.
--
-- NOTE(review): the draw is lifted past the 'TracedT' layer instead of using
-- that layer's own 'random'. It is therefore presumably not recorded in the
-- MH trace, so only randomness in the parameter model given to 'smc2' gets
-- rejuvenated. Confirm against "Control.Monad.Bayes.Traced".
instance (MonadDistribution m) => MonadDistribution (SMC2 m) where
  random = lift random

-- | Scores are passed to the 'SequentialT' layer of the wrapped stack.
--
-- NOTE(review): presumably that layer also suspends after each score, which
-- is what marks the time steps counted by 'smc2'. Confirm against
-- "Control.Monad.Bayes.Sequential.Coroutine".
instance (Monad m) => MonadFactor (SMC2 m) where
  score = SMC2 . score

-- | 'SMC2' is a full measure monad whenever the base monad can sample.
-- The instance body is empty and defines no methods of its own.
instance MonadDistribution m => MonadMeasure (SMC2 m)

-- | Sequential Monte Carlo squared.
--
-- __Outer level.__ A resample-move SMC ('rmsmc') runs over the model
-- parameters with @p@ particles, @k@ steps and multinomial resampling. Its
-- MCMC kernel is @t@ 'SingleSiteMH' transitions with no burn-in.
--
-- __Inner level.__ For every parameter value, an inner SMC ('smcPush') runs
-- the model with @n@ particles, @k@ steps and multinomial resampling. Each
-- outer particle's value is the weighted particle list of its inner
-- population, as returned by 'runPopulationT'.
smc2 ::
  (MonadDistribution m) =>
  -- | number of time steps
  Int ->
  -- | number of inner particles
  Int ->
  -- | number of outer particles
  Int ->
  -- | number of MH transitions
  Int ->
  -- | model parameters
  SequentialT (TracedT (PopulationT m)) b ->
  -- | model
  (b -> SequentialT (PopulationT (SMC2 m)) a) ->
  PopulationT m [(a, Log Double)]
smc2 k n p t param model =
  rmsmc outerMCMC outerSMC (param >>= innerFilter)
  where
    -- Rejuvenation kernel applied to the outer (parameter) particles.
    outerMCMC =
      MCMCConfig {numMCMCSteps = t, proposal = SingleSiteMH, numBurnIn = 0}
    -- Outer particle system: @p@ particles over the parameters.
    outerSMC =
      SMCConfig {numParticles = p, numSteps = k, resampler = resampleMultinomial}
    -- Inner particle system: @n@ particles over the model, per parameter.
    innerSMC =
      SMCConfig {numParticles = n, numSteps = k, resampler = resampleMultinomial}
    -- Run the inner SMC for one parameter value, collect its weighted
    -- particles, and unwrap the result back into the outer stack.
    innerFilter = setup . runPopulationT . smcPush innerSMC . model