{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
module Control.Monad.Bayes.Weighted
( WeightedT,
weightedT,
extractWeight,
unweighted,
applyWeight,
hoist,
runWeightedT,
)
where
import Control.Monad.Bayes.Class
( MonadDistribution,
MonadFactor (..),
MonadMeasure,
factor,
)
import Control.Monad.State (MonadIO, MonadTrans, StateT (..), lift, mapStateT, modify)
import Numeric.Log (Log)
-- | A transformer that threads a running weight of type @'Log' Double@
-- through a computation in @m@.
--
-- The weight lives in the state of the wrapped 'StateT'. The 'MonadFactor'
-- instance multiplies it by each scored factor, and 'runWeightedT' starts
-- it at 1.
newtype WeightedT m a = WeightedT (StateT (Log Double) m a)
  -- Every instance below is inherited unchanged from
  -- @StateT (Log Double) m@ via newtype deriving.
  deriving newtype
    ( Functor,
      Applicative,
      Monad,
      MonadIO,
      MonadTrans,
      MonadDistribution
    )
-- | Scoring multiplies the accumulated weight by the given factor.
--
-- Only 'Monad' is required of @m@, because scoring touches nothing but the
-- weight state.
instance (Monad m) => MonadFactor (WeightedT m) where
  score w = WeightedT (modify (* w))
-- | 'WeightedT' over any 'MonadDistribution' is a 'MonadMeasure'. It
-- combines the derived sampling operations with the 'MonadFactor' instance
-- for 'WeightedT'; no extra methods are needed.
instance MonadDistribution m => MonadMeasure (WeightedT m)
-- | Run a weighted computation, starting the weight at 1.
--
-- Returns the result paired with the final accumulated weight.
runWeightedT :: WeightedT m a -> m (a, Log Double)
runWeightedT (WeightedT m) = runStateT m 1
-- | Run a weighted computation and discard the accumulated weight.
unweighted :: (Functor m) => WeightedT m a -> m a
unweighted = fmap fst . runWeightedT
-- | Run a weighted computation and keep only the accumulated weight,
-- discarding the result.
extractWeight :: (Functor m) => WeightedT m a -> m (Log Double)
extractWeight = fmap snd . runWeightedT
-- | Embed a computation that reports its own weight.
--
-- The reported weight is multiplied into the running total, and the result
-- is returned. This is the inverse direction of 'runWeightedT'.
weightedT :: (Monad m) => m (a, Log Double) -> WeightedT m a
weightedT m = WeightedT $ do
  (x, w) <- lift m
  modify (* w)
  return x
-- | Run a weighted computation and re-emit its accumulated weight as a
-- single 'factor' in the underlying monad, then return the result.
applyWeight :: (MonadFactor m) => WeightedT m a -> m a
applyWeight m = do
  (x, w) <- runWeightedT m
  factor w
  return x
-- | Change the underlying monad using a polymorphic function @t@.
--
-- 'mapStateT' applies @t@ to the inner action, and the weight is threaded
-- through unchanged.
hoist :: (forall x. m x -> n x) -> WeightedT m a -> WeightedT n a
hoist t (WeightedT m) = WeightedT $ mapStateT t m