{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}

-- |
-- Module      : Control.Monad.Bayes.Density.Free
-- Description : Free monad transformer over random sampling
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- 'DensityT' is a free monad transformer over random sampling.
module Control.Monad.Bayes.Density.Free
  ( DensityT (..),
    hoist,
    interpret,
    withRandomness,
    runDensityT,
    traced,
  )
where

import Control.Monad.Bayes.Class (MonadDistribution (random))
import Control.Monad.RWS
import Control.Monad.State (evalStateT)
import Control.Monad.Trans.Free.Church (FT, MonadFree (..), hoistFT, iterT, iterTM, liftF)
import Control.Monad.Writer (WriterT (..))
import Data.Functor.Identity (Identity, runIdentity)

-- | Random sampling functor.
--
-- A single 'Random' instruction requests one 'Double' and carries the
-- continuation that consumes it. Interpreters ('interpret',
-- 'withRandomness', 'runDensityT') decide where that value comes from.
newtype SamF a = Random (Double -> a)
  deriving stock (Functor)

-- | Free monad transformer over random sampling.
--
-- Uses the Church-encoded version of the free monad ('FT') for efficiency.
-- 'Functor', 'Applicative', 'Monad' and 'MonadTrans' are all inherited
-- unchanged from the underlying @'FT' 'SamF' m@ via newtype deriving.
newtype DensityT m a = DensityT {getDensityT :: FT SamF m a}
  deriving newtype (Functor, Applicative, Monad, MonadTrans)

-- | 'DensityT' is free over 'SamF'. Wrapping a layer strips the newtype
-- from each continuation, delegates to the 'FT' instance, and re-wraps
-- the result.
instance MonadFree SamF (DensityT m) where
  wrap = DensityT . wrap . fmap getDensityT

-- | A 'random' draw is recorded as a single 'Random' instruction whose
-- continuation returns the supplied value unchanged. No value is drawn
-- here; the interpreter that runs the computation provides it.
instance (Monad m) => MonadDistribution (DensityT m) where
  random = DensityT $ liftF (Random id)

-- | Hoist 'DensityT' through a natural transformation between the
-- underlying monads (presumably a monad morphism in practice, so that the
-- monad laws are respected).
--
-- The sampling functor 'SamF' is unchanged; only the underlying monad is
-- replaced, by delegating to 'hoistFT'.
hoist :: (Monad m, Monad n) => (forall x. m x -> n x) -> DensityT m a -> DensityT n a
hoist f (DensityT m) = DensityT (hoistFT f m)

-- | Execute random sampling in the transformed monad.
--
-- Every 'Random' instruction is answered by a fresh 'random' draw in @m@.
-- No randomness is supplied up front.
interpret :: (MonadDistribution m) => DensityT m a -> m a
interpret (DensityT m) = iterT f m
  where
    -- Answer one sampling request with a draw from the host monad.
    f (Random k) = random >>= k

-- | Execute computation with supplied values for random choices.
--
-- Random choices are answered in execution order from the front of the
-- list. Values left over when the computation finishes are discarded.
--
-- /Partial:/ calls 'error' if the computation makes more random choices
-- than the list provides.
withRandomness :: (Monad m) => [Double] -> DensityT m a -> m a
withRandomness randomness (DensityT m) = evalStateT (iterTM f m) randomness
  where
    -- The state holds the not-yet-consumed values; each request pops one.
    f (Random k) = do
      xs <- get
      case xs of
        [] -> error "DensityT: the list of randomness was too short"
        y : ys -> put ys >> k y

-- | Execute computation with supplied values for a subset of random choices.
-- Return the output value and a record of all random choices used, whether
-- taken as input or drawn using the transformed monad.
--
-- Supplied values are consumed in execution order. Once the list is
-- exhausted, the remaining choices are drawn with 'random' in @m@, so unlike
-- 'withRandomness' this never fails on a short list.
runDensityT :: (MonadDistribution m) => [Double] -> DensityT m a -> m (a, [Double])
runDensityT randomness (DensityT m) =
  runWriterT $ evalStateT (iterTM f $ hoistFT lift m) randomness
  where
    f (Random k) = do
      -- This block runs in StateT [Double] (WriterT [Double] m).
      -- StateT propagates consumed randomness while WriterT records
      -- randomness used, whether old or new.
      xs <- get
      x <- case xs of
        [] -> random
        y : ys -> put ys >> return y
      tell [x]
      k x

-- | Like 'runDensityT', but for a computation whose underlying monad is
-- 'Identity'. The program is first lifted into an arbitrary sampling monad
-- @m@, which then supplies any random choices not covered by the given
-- list. Returns the result together with every random choice used.
traced :: (MonadDistribution m) => [Double] -> DensityT Identity a -> m (a, [Double])
traced randomness m = runDensityT randomness $ hoist (pure . runIdentity) m