{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
module Control.Monad.Bayes.Density.Free
( DensityT (..),
hoist,
interpret,
withRandomness,
runDensityT,
traced,
)
where
import Control.Monad.Bayes.Class (MonadDistribution (random))
import Control.Monad.RWS
import Control.Monad.State (evalStateT)
import Control.Monad.Trans.Free.Church (FT, MonadFree (..), hoistFT, iterT, iterTM, liftF)
import Control.Monad.Writer (WriterT (..))
import Data.Functor.Identity (Identity, runIdentity)
-- | The sampling functor: one request for a random 'Double'. The payload is
-- the continuation that receives the drawn value.
--
-- The deriving strategy is stated explicitly ('DerivingStrategies' is enabled
-- in this module). Without it, 'GeneralizedNewtypeDeriving' would leave the
-- strategy for this newtype implicit.
newtype SamF a = Random (Double -> a)
  deriving stock (Functor)
-- | Free monad transformer over random sampling ('SamF').
--
-- It is represented by the Church-encoded free monad transformer 'FT'. All
-- monadic structure, plus 'lift', is inherited from 'FT' through newtype
-- deriving.
newtype DensityT m a = DensityT
  { -- | Unwrap to the underlying Church-encoded free monad transformer.
    getDensityT :: FT SamF m a
  }
  deriving newtype
    ( Functor,
      Applicative,
      Monad,
      MonadTrans
    )
-- | Free-monad structure for 'DensityT'. Unwrap every child computation in
-- the 'SamF' layer to 'FT', delegate to the 'FT' instance of 'wrap', then
-- re-wrap the result.
instance MonadFree SamF (DensityT m) where
  wrap layer = DensityT (wrap (getDensityT <$> layer))
-- | A random draw in 'DensityT' is a single free 'SamF' layer. Interpreting
-- the program decides where the value actually comes from.
instance (Monad m) => MonadDistribution (DensityT m) where
  random = DensityT (liftF request)
    where
      -- One sampling request whose continuation returns the drawn value as is.
      request = Random id
-- | Change the underlying monad of a 'DensityT' computation. The function of
-- type @forall x. m x -> n x@ is applied to every effect of the base monad,
-- using 'hoistFT'.
hoist :: (Monad m, Monad n) => (forall x. m x -> n x) -> DensityT m a -> DensityT n a
hoist nat = DensityT . hoistFT nat . getDensityT
-- | Run a 'DensityT' computation in its base monad. Each 'Random' request is
-- answered with a fresh draw from that monad's 'random'.
interpret :: (MonadDistribution m) => DensityT m a -> m a
interpret = iterT answer . getDensityT
  where
    -- Draw a value from the base monad and feed it to the continuation.
    answer (Random k) = k =<< random
-- | Run a computation with every random choice taken from the supplied list,
-- in order. One element is consumed per 'Random' request.
--
-- If the list runs out, this calls 'error'. Elements left over when the
-- computation finishes are discarded ('evalStateT' drops the final state).
withRandomness :: (Monad m) => [Double] -> DensityT m a -> m a
withRandomness supplied program =
  evalStateT (iterTM consume (getDensityT program)) supplied
  where
    -- The state holds the values that have not been consumed yet.
    consume (Random k) =
      get >>= \remaining -> case remaining of
        next : rest -> put rest >> k next
        [] -> error "DensityT: the list of randomness was too short"
-- | Run a computation, taking random choices from the supplied list while it
-- lasts. After the list is exhausted, fresh values are drawn with the base
-- monad's 'random'.
--
-- Returns the result together with every value the computation used, in draw
-- order, whether it was supplied or freshly drawn. Supplied values left over
-- at the end are discarded.
runDensityT :: (MonadDistribution m) => [Double] -> DensityT m a -> m (a, [Double])
runDensityT randomness (DensityT m) =
  runWriterT (evalStateT (iterTM record lifted) randomness)
  where
    -- Move the base effects under the WriterT layer so that iterTM can add
    -- StateT on top of it.
    lifted = hoistFT lift m
    -- This handler runs in StateT [Double] (WriterT [Double] m).
    -- StateT tracks the supplied values not yet consumed; WriterT logs every
    -- value actually used.
    record (Random k) = do
      pending <- get
      value <- case pending of
        next : rest -> put rest >> return next
        [] -> random
      tell [value]
      k value
-- | Like 'runDensityT', but for a computation over 'Identity'. The computation
-- is first moved into an arbitrary sampler @m@, which then supplies any random
-- choices not covered by the given list.
traced :: (MonadDistribution m) => [Double] -> DensityT Identity a -> m (a, [Double])
traced randomness program = runDensityT randomness (hoist fromIdentity program)
  where
    -- Unwrap the pure 'Identity' effect and re-inject it into the target monad.
    fromIdentity :: (Monad n) => Identity x -> n x
    fromIdentity = return . runIdentity