{-# LANGUAGE RankNTypes #-}
module Control.Monad.Bayes.Traced.Basic
( TracedT,
hoist,
marginal,
mhStep,
mh,
)
where
import Control.Applicative (Applicative (..))
import Control.Monad.Bayes.Class
( MonadDistribution (random),
MonadFactor (..),
MonadMeasure,
)
import Control.Monad.Bayes.Density.Free (DensityT)
import Control.Monad.Bayes.Traced.Common
( Trace (..),
bind,
mhTrans',
scored,
singleton,
)
import Control.Monad.Bayes.Weighted (WeightedT)
import Data.Functor.Identity (Identity)
import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList)
import Prelude hiding (Applicative (..))
-- | A traced probabilistic program, stored as two views of the same program.
data TracedT m a = TracedT
  { -- | The program interpreted as a weighted density computation.
    -- This is the component handed to 'mhTrans'' when a new trace is
    -- proposed.
    model :: WeightedT (DensityT Identity) a,
    -- | A computation in the base monad yielding a recorded 'Trace' whose
    -- 'output' is the program's result.
    traceDist :: m (Trace a)
  }
-- | Map a function over the result of both components.
-- The model is mapped directly; the trace distribution is mapped inside
-- each produced 'Trace'.
instance Monad m => Functor (TracedT m) where
  fmap g (TracedT weighted traces) =
    TracedT (g <$> weighted) (fmap g <$> traces)
-- | Lift pure values into both components, and combine two traced
-- computations componentwise.
instance Monad m => Applicative (TracedT m) where
  pure value = TracedT (pure value) (pure (pure value))
  TracedT wf tf <*> TracedT wx tx = TracedT weighted traces
    where
      weighted = wf <*> wx
      -- Pair up the two trace computations and merge each pair of traces
      -- with 'Trace''s own '<*>'.
      traces = liftA2 (<*>) tf tx
-- | Sequence two traced computations.
-- The model components are chained with the underlying '>>='. The trace
-- distributions are chained with 'bind' from
-- "Control.Monad.Bayes.Traced.Common".
instance Monad m => Monad (TracedT m) where
  TracedT weighted traces >>= k =
    TracedT
      (weighted >>= (model . k))
      (bind traces (traceDist . k))
-- | Perform 'random' in both components.
-- In the trace distribution, the drawn value is wrapped with 'singleton'
-- to form a 'Trace'.
instance MonadDistribution m => MonadDistribution (TracedT m) where
  random = TracedT random (singleton <$> random)
-- | Apply the factor in both components.
-- In the trace distribution, the base monad is scored first, and then a
-- trace built by 'scored' from the same weight is returned.
instance MonadFactor m => MonadFactor (TracedT m) where
  score w = TracedT (score w) $ do
    score w
    pure (scored w)
-- | Empty instance. It relies on the 'MonadDistribution' and 'MonadFactor'
-- instances for 'TracedT' (presumably the superclasses of 'MonadMeasure').
instance MonadMeasure m => MonadMeasure (TracedT m)
-- | Transform the trace-distribution component with a function that keeps
-- the base monad fixed. The model component is left untouched.
hoist :: (forall x. m x -> m x) -> TracedT m a -> TracedT m a
hoist nat traced = case traced of
  TracedT weighted traces -> TracedT weighted (nat traces)
-- | Run the trace distribution and keep only the 'output' of the resulting
-- trace, discarding the model component.
marginal :: Monad m => TracedT m a -> m a
marginal (TracedT _ traces) = output <$> traces
-- | Advance the chain by one step.
-- Each trace produced by the current trace distribution is fed through
-- 'mhTrans'' with this program's model; the model itself is unchanged.
mhStep :: MonadDistribution m => TracedT m a -> TracedT m a
mhStep (TracedT weighted traces) =
  TracedT weighted (traces >>= mhTrans' weighted)
-- | Draw an initial trace from the trace distribution, then apply
-- 'mhTrans'' @n@ times in sequence.
--
-- Returns the outputs of all @n + 1@ visited states, or only the initial
-- one when @n <= 0@. The most recent sample is at the head of the list and
-- the initial sample is last.
mh :: MonadDistribution m => Int -> TracedT m a -> m [a]
mh n (TracedT weighted traces) = do
  initial <- traces
  chain <- walk n (initial :| [])
  pure (map output (NE.toList chain))
  where
    -- Take @k@ more steps from the newest trace, consing each new trace
    -- onto the front of the history.
    walk k history@(current :| older)
      | k <= 0 = pure history
      | otherwise = do
          next <- mhTrans' weighted current
          walk (k - 1) (next :| current : older)