{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Control.Monad.Bayes.Sampler.Lazy where
import Control.Monad (ap)
import Control.Monad.Bayes.Class (MonadDistribution (random))
import Control.Monad.Bayes.Weighted (WeightedT, runWeightedT)
import Control.Monad.IO.Class
import Control.Monad.Identity (Identity (runIdentity))
import Control.Monad.Trans
import Numeric.Log (Log (..))
import System.Random
( RandomGen (split),
getStdGen,
newStdGen,
)
import System.Random qualified as R
-- | A lazily generated tree of random 'Double's.
--
-- Every node stores one value and a stream of child trees. 'Trees' has no
-- terminating constructor, so every total 'Tree' is infinite. The fields
-- must stay lazy: 'randomTree' / 'randomTrees' build this structure by
-- unbounded recursion, and strict fields would make that construction
-- diverge.
data Tree = Tree
  { -- | The value stored at this node ('randomTree' fills it with 'R.random').
    currentUniform :: Double
  , -- | The infinite stream of subtrees hanging off this node.
    lazyUniforms :: Trees
  }
-- | An infinite stream of 'Tree's: a head tree followed by the rest.
--
-- There is no empty case, so the stream never ends. Both fields are lazy on
-- purpose, for the same reason as in 'Tree'.
data Trees = Trees
  { -- | The first tree of the stream.
    headTree :: Tree
  , -- | All remaining trees.
    tailTrees :: Trees
  }
-- | 'SamplerT' over 'Identity': a pure computation driven only by its 'Tree'.
type Sampler = SamplerT Identity
-- | Evaluate a pure 'Sampler' on an explicitly supplied tree.
--
-- No effects are involved, so the same tree always yields the same result.
runSampler :: Sampler a -> Tree -> a
runSampler sampler tree = runIdentity (runSamplerT sampler tree)
-- | A sampling transformer: a computation that takes its randomness from a
-- 'Tree' and runs its effects in @m@.
newtype SamplerT m a = SamplerT
  { -- | Run the computation against an explicit tree.
    runSamplerT :: Tree -> m a
  }
  deriving (Functor)
-- | Split one tree into two trees built from non-overlapping parts of it.
--
-- The first result is the head of the child stream. The second keeps this
-- node's value but continues with the tail of the stream. '>>=' uses this to
-- give the two sides of a bind separate trees.
splitTree :: Tree -> (Tree, Tree)
splitTree tree = case tree of
  Tree root (Trees first rest) -> (first, Tree root rest)
-- | Grow an infinite tree of random 'Double's from a generator.
--
-- The node value comes from 'R.random'. The generator that call returns seeds
-- the child stream through 'randomTrees'. Nothing is forced up front; the
-- where-bound pattern is lazy.
randomTree :: (RandomGen g) => g -> Tree
randomTree gen = Tree sample (randomTrees gen')
  where
    (sample, gen') = R.random gen
-- | Grow an infinite stream of random trees from a generator.
--
-- The generator is 'split': one half seeds the head tree and the other seeds
-- the rest of the stream. The where-bound pattern keeps this lazy.
randomTrees :: (RandomGen g) => g -> Trees
randomTrees gen = Trees (randomTree left) (randomTrees right)
  where
    (left, right) = split gen
-- | 'pure' ignores the tree it is given.
--
-- '<*>' is spelled out exactly as 'Control.Monad.ap', so it goes through
-- '>>=' and splits the tree between the function and the argument. Writing it
-- with 'fmap' would instead feed both computations trees from the same root.
instance (Monad m) => Applicative (SamplerT m) where
  pure x = SamplerT (\_ -> pure x)
  samplerF <*> samplerX = do
    f <- samplerF
    x <- samplerX
    pure (f x)
-- | Bind divides the incoming tree with 'splitTree'.
--
-- The first computation runs on one part and the continuation on the other,
-- so the two never read from the same subtree. 'return' is left at its
-- default, which is 'pure'.
instance (Monad m) => Monad (SamplerT m) where
  SamplerT run >>= k = SamplerT $ \tree ->
    let (forFirst, forRest) = splitTree tree
     in run forFirst >>= \x -> runSamplerT (k x) forRest
-- | A lifted action ignores the tree entirely; it consumes no randomness.
instance MonadTrans SamplerT where
  lift action = SamplerT (\_ -> action)
-- | Run an 'IO' action in the base monad; the tree is not consulted.
instance (MonadIO m) => MonadIO (SamplerT m) where
  liftIO io = SamplerT (const (liftIO io))
-- | A primitive draw returns the value stored at the root of the tree this
-- computation is handed (its 'currentUniform').
instance (Monad m) => MonadDistribution (SamplerT m) where
  random = SamplerT $ \tree -> case tree of
    Tree u _ -> pure u
-- | Run a 'SamplerT' computation in any 'MonadIO', feeding it a fresh lazy
-- tree built by 'randomTree'.
--
-- The tree is seeded from the generator that 'newStdGen' /returns/: one half
-- of a split of the global generator, with the global state keeping the
-- other half.
--
-- Why not re-read the global generator (e.g. via 'getStdGen') after calling
-- 'newStdGen'?
--
-- * The tree would then be seeded from exactly the generator that the next
--   call splits again, so successive runs would not get independent
--   generators.
-- * It would also touch the global state twice, non-atomically, leaving room
--   for another thread to observe the same generator.
runSamplerTIO :: (MonadIO m) => SamplerT m a -> m a
runSamplerTIO sampler = runSamplerT sampler . randomTree =<< liftIO newStdGen
-- | Repeat a computation without end, collecting its results in an infinite
-- list.
--
-- This sequences infinitely many copies of the action, so it only returns
-- when @m@'s sequencing is lazy (e.g. 'Identity' or a reader @(->) r@).
-- NOTE(review): for a strict monad such as 'IO' the traversal would never
-- finish; confirm callers only use lazy monads.
independent :: (Monad m) => m a -> m [a]
independent draw = sequenceA (repeat draw)
-- | An infinite list of @(sample, weight)@ pairs from a weighted model.
--
-- The model's 'runWeightedT' is repeated with 'independent', and the whole
-- stream is run once through 'runSamplerTIO' on a single fresh tree.
--
-- NOTE(review): this sequences infinitely many @SamplerT m@ computations, so
-- it can only return when @m@'s bind is lazy. With a strict @m@ such as
-- 'IO' it appears never to terminate; confirm the intended base monads.
weightedSamples :: (MonadIO m) => WeightedT (SamplerT m) a -> m [(a, Log Double)]
weightedSamples model = runSamplerTIO (independent (runWeightedT model))