{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Control.Monad.Bayes.Inference.SMC2
( smc2,
SMC2,
)
where
import Control.Monad.Bayes.Class
( MonadDistribution (random),
MonadFactor (..),
MonadMeasure,
)
import Control.Monad.Bayes.Inference.MCMC
import Control.Monad.Bayes.Inference.RMSMC (rmsmc)
import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smcPush)
import Control.Monad.Bayes.Population as Pop (PopulationT, resampleMultinomial, runPopulationT)
import Control.Monad.Bayes.Sequential.Coroutine (SequentialT)
import Control.Monad.Bayes.Traced
import Control.Monad.Trans (MonadTrans (..))
import Numeric.Log (Log)
-- | Helper monad transformer used to preprocess the model passed to 'smc2'.
--
-- It wraps the stack @SequentialT (TracedT (PopulationT m))@ that the outer
-- level of 'smc2' runs in. The 'Functor', 'Applicative' and 'Monad'
-- instances are reused from that stack via newtype deriving.
newtype SMC2 m a = SMC2 (SequentialT (TracedT (PopulationT m)) a)
  deriving newtype (Functor, Applicative, Monad)
-- | Unwrap an 'SMC2' computation into the underlying
-- @SequentialT (TracedT (PopulationT m))@ stack.
setup :: SMC2 m a -> SequentialT (TracedT (PopulationT m)) a
setup (SMC2 m) = m
-- | Lift a base-monad action through all three wrapped layers
-- ('PopulationT', then 'TracedT', then 'SequentialT') and wrap the result
-- in 'SMC2'.
instance MonadTrans SMC2 where
  lift = SMC2 . lift . lift . lift
-- | Random draws are taken directly from the base monad @m@ via 'lift',
-- not from the 'random' of the wrapped transformer stack.
--
-- NOTE(review): presumably this keeps the inner-model randomness out of the
-- outer 'TracedT' trace, so that the outer MH moves only perturb the
-- parameters; confirm against the 'TracedT' semantics.
instance (MonadDistribution m) => MonadDistribution (SMC2 m) where
  random = lift random
-- | Scores are delegated to the wrapped
-- @SequentialT (TracedT (PopulationT m))@ stack's own 'score', rather than
-- being lifted to the base monad @m@ (contrast with 'random').
instance (Monad m) => MonadFactor (SMC2 m) where
  score = SMC2 . score
-- | 'SMC2' is a full probabilistic monad (its 'MonadDistribution' and
-- 'MonadFactor' instances combined) whenever the base monad can sample.
instance MonadDistribution m => MonadMeasure (SMC2 m)
-- | Sequential Monte Carlo squared (SMC²).
--
-- Outer level:
--
-- * Runs resample-move SMC ('rmsmc') over the parameters produced by
--   @param@.
-- * Uses @p@ outer particles, @k@ steps and multinomial resampling.
-- * Uses @t@ single-site Metropolis-Hastings steps ('numMCMCSteps') with
--   no burn-in.
--
-- Inner level:
--
-- * For each parameter value, runs the model @m@ under 'smcPush'.
-- * Uses @n@ inner particles, the same @k@ steps and multinomial
--   resampling.
-- * Collapses the inner population with 'runPopulationT'.
--
-- Each outer particle therefore carries the inner list of
-- @(value, weight)@ pairs.
smc2 ::
  (MonadDistribution m) =>
  -- | number of time steps (@numSteps@ for both the outer and inner SMC)
  Int ->
  -- | number of inner particles
  Int ->
  -- | number of outer particles
  Int ->
  -- | number of MH transitions (@numMCMCSteps@ of the outer 'rmsmc')
  Int ->
  -- | prior over the model parameters
  SequentialT (TracedT (PopulationT m)) b ->
  -- | model, given a parameter value
  (b -> SequentialT (PopulationT (SMC2 m)) a) ->
  PopulationT m [(a, Log Double)]
smc2 k n p t param m =
  rmsmc outerMCMC outerSMC (param >>= runInner)
  where
    -- MH rejuvenation of the outer (parameter) particles.
    outerMCMC = MCMCConfig {numMCMCSteps = t, proposal = SingleSiteMH, numBurnIn = 0}
    -- SMC over the outer (parameter) particles.
    outerSMC = SMCConfig {numParticles = p, numSteps = k, resampler = resampleMultinomial}
    -- SMC over the inner particles, run inside the 'SMC2' monad.
    innerSMC = SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}
    -- Run the inner SMC for one parameter value and unwrap it back into the
    -- outer stack.
    runInner = setup . runPopulationT . smcPush innerSMC . m