{-|
Module      : Control.Monad.Bayes.Traced.Dynamic
Description : Distributions on execution traces that can be dynamically frozen
Copyright   : (c) Adam Scibior, 2015-2020
License     : MIT
Maintainer  : leonhard.markert@tweag.io
Stability   : experimental
Portability : GHC

-}

module Control.Monad.Bayes.Traced.Dynamic (
  Traced,
  hoistT,
  marginal,
  freeze,
  mhStep,
  mh
) where

import Control.Monad (join)
import Control.Monad.Trans

import Control.Monad.Bayes.Class
import Control.Monad.Bayes.Weighted as Weighted
import Control.Monad.Bayes.Free as FreeSampler

import Control.Monad.Bayes.Traced.Common

-- | A tracing monad where only a subset of random choices are traced
-- and this subset can be adjusted dynamically.
newtype Traced m a = Traced (m (Weighted (FreeSampler m) a, Trace a))
runTraced :: Traced m a -> m (Weighted (FreeSampler m) a, Trace a)
runTraced (Traced c) = c

pushM :: Monad m => m (Weighted (FreeSampler m) a) -> Weighted (FreeSampler m) a
pushM = join . lift . lift

instance Monad m => Functor (Traced m) where
  fmap f (Traced c) = Traced $ do
    (m, t) <- c
    let m' = fmap f m
    let t' = fmap f t
    return (m', t')

instance Monad m => Applicative (Traced m) where
  pure x = Traced $ pure (pure x, pure x)
  (Traced cf) <*> (Traced cx) = Traced $ do
    (mf, tf) <- cf
    (mx, tx) <- cx
    return (mf <*> mx, tf <*> tx)

instance Monad m => Monad (Traced m) where
  (Traced cx) >>= f = Traced $ do
    (mx, tx) <- cx
    let m = mx >>= pushM . fmap fst . runTraced . f
    t <- return tx `bind` (fmap snd . runTraced . f)
    return (m, t)

instance MonadTrans Traced where
  lift m = Traced $ fmap ((,) (lift $ lift m) . pure) m

instance MonadSample m => MonadSample (Traced m) where
  random = Traced $ fmap ((,) random . singleton) random

instance MonadCond m => MonadCond (Traced m) where
  score w = Traced $ fmap (score w,) (score w >> pure (scored w))

instance MonadInfer m => MonadInfer (Traced m)

hoistT :: (forall x. m x -> m x) -> Traced m a -> Traced m a
hoistT f (Traced c) = Traced (f c)

marginal :: Monad m => Traced m a -> m a
marginal (Traced c) = fmap (output . snd) c

-- | Freeze all traced random choices to their current
-- values and stop tracing them.
freeze :: Monad m => Traced m a -> Traced m a
freeze (Traced c) = Traced $ do
  (_, t) <- c
  let x = output t
  return (return x, pure x)

mhStep :: MonadSample m => Traced m a -> Traced m a
mhStep (Traced c) = Traced $ do
  (m, t) <- c
  t' <- mhTrans m t
  return (m, t')

mh :: MonadSample m => Int -> Traced m a -> m [a]
mh n (Traced c) = do
  (m,t) <- c
  let f 0 = return [t]
      f k = do
        ~(x:xs) <- f (k-1)
        y <- mhTrans m x
        return (y:x:xs)
  ts <- f n
  let xs = map output ts
  return xs