{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Control.Monad.Bayes.Inference.SMC2
( smc2,
SMC2,
)
where
import Control.Monad.Bayes.Class
  ( MonadDistribution (random),
    MonadFactor (..),
    MonadMeasure,
  )
import Control.Monad.Bayes.Inference.MCMC
import Control.Monad.Bayes.Inference.RMSMC (rmsmc)
import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smcPush)
import Control.Monad.Bayes.Population as Pop (Population, population, resampleMultinomial)
import Control.Monad.Bayes.Sequential.Coroutine (Sequential)
import Control.Monad.Bayes.Traced
import Control.Monad.Trans (MonadTrans (..))
import Numeric.Log (Log)
-- | Helper monad transformer used by 'smc2' to run the inner, per-parameter
-- SMC on top of the outer 'Sequential' / 'Traced' / 'Population' stack.
--
-- The 'Functor', 'Applicative' and 'Monad' instances are exactly those of the
-- wrapped @Sequential (Traced (Population m))@ computation.
newtype SMC2 m a = SMC2 (Sequential (Traced (Population m)) a)
  deriving newtype (Functor, Applicative, Monad)
-- | Unwrap an 'SMC2' computation into the outer
-- 'Sequential' / 'Traced' / 'Population' stack it is built on.
setup :: SMC2 m a -> Sequential (Traced (Population m)) a
setup (SMC2 computation) = computation
-- | Lift a base-monad action into 'Population', then into 'Traced', then into
-- 'Sequential', and finally wrap it in 'SMC2'.
instance MonadTrans SMC2 where
  lift = SMC2 . lift . lift . lift
-- | Sampling is delegated to the base monad: 'random' is lifted through all
-- layers of the stack with the 'MonadTrans' instance for 'SMC2'.
instance MonadDistribution m => MonadDistribution (SMC2 m) where
  random = lift random
-- | Scoring is delegated to the wrapped
-- @Sequential (Traced (Population m))@ computation's own 'score'.
instance Monad m => MonadFactor (SMC2 m) where
  score = SMC2 . score
-- | Empty instance: no methods are defined here, so 'MonadMeasure' presumably
-- just bundles the 'MonadDistribution' and 'MonadFactor' capabilities of
-- 'SMC2'. It is what lets 'smc2' run 'smcPush' with 'SMC2' as the base monad.
instance (MonadDistribution m) => MonadMeasure (SMC2 m) where
-- | Sequential Monte Carlo squared (SMC^2).
--
-- An outer resample-move SMC ('rmsmc') runs over the model parameters drawn
-- from @param@. For every outer particle, an inner SMC ('smcPush') runs the
-- model instantiated at those parameters.
--
-- Both levels use multinomial resampling and the same number of time steps.
-- The outer MCMC moves use a single-site Metropolis-Hastings proposal with no
-- burn-in.
--
-- The result is the outer population. Each outer particle carries the list of
-- inner particles (values paired with their weights) that it produced.
smc2 ::
  MonadDistribution m =>
  -- | number of time steps (shared by the inner and outer SMC)
  Int ->
  -- | number of inner particles
  Int ->
  -- | number of outer particles
  Int ->
  -- | number of Metropolis-Hastings transitions ('numMCMCSteps') for the outer 'rmsmc'
  Int ->
  -- | prior over the model parameters
  Sequential (Traced (Population m)) b ->
  -- | model, given the parameters
  (b -> Sequential (Population (SMC2 m)) a) ->
  Population m [(a, Log Double)]
smc2 nSteps nInner nOuter nMCMC param model =
  rmsmc mcmcConfig outerConfig (param >>= runInner)
  where
    -- MCMC kernel used by the outer resample-move sampler.
    mcmcConfig =
      MCMCConfig {numMCMCSteps = nMCMC, proposal = SingleSiteMH, numBurnIn = 0}

    -- Outer SMC over the parameters, running in the base monad.
    outerConfig =
      SMCConfig
        { numParticles = nOuter,
          numSteps = nSteps,
          resampler = resampleMultinomial
        }

    -- Inner SMC over the model. It runs with 'SMC2' as its base monad, so its
    -- draws and scores end up in the outer stack once unwrapped by 'setup'.
    innerConfig =
      SMCConfig
        { numParticles = nInner,
          numSteps = nSteps,
          resampler = resampleMultinomial
        }

    -- Run the inner filter for one parameter value and expose its weighted
    -- particles as a value of the outer computation.
    runInner = setup . population . smcPush innerConfig . model