{-# LANGUAGE RankNTypes, NoMonomorphismRestriction, BangPatterns #-}
{-# OPTIONS -W #-}

module Language.Hakaru.ImportanceSampler where

-- This is an interpreter that's like Interpreter except conditioning is
-- checked at run time rather than by static types.  In other words, we allow
-- models to be compiled whose conditioned parts do not match the observation
-- inputs.  In exchange, we get to make Measure an instance of Monad, and we
-- can express models whose number of observations is unknown at compile time.

import Language.Hakaru.Types
import Language.Hakaru.Mixture (Prob, empty, point, Mixture(..))
import Language.Hakaru.Sampler (Sampler, deterministic, smap, sbind)

import qualified System.Random.MWC as MWC
import Control.Monad.Primitive
import Data.Monoid
import Data.Dynamic
import System.IO.Unsafe
import qualified Data.Map.Strict as M

import qualified Data.Number.LogFloat as LF

-- Conditioned sampling
newtype Measure a = Measure { unMeasure :: [Cond] -> Sampler (a, [Cond]) }

bind :: Measure a -> (a -> Measure b) -> Measure b
bind measure continuation =
  Measure (\conds ->
    sbind (unMeasure measure conds)
          (\(a,cds) -> unMeasure (continuation a) cds))

instance Monad Measure where
  return x = Measure (\conds -> deterministic (point (x,conds) 1))
  (>>=)    = bind

updateMixture :: Typeable a => Cond -> Dist a -> Sampler a
updateMixture (Just cond) dist =
    case fromDynamic cond of
      Just y  -> deterministic (point (fromDensity y) density)
          where density = LF.logToLogFloat $ logDensity dist y
      Nothing -> error "did not get data from dynamic source"
updateMixture Nothing     dist = \g -> do e <- distSample dist g
                                          return $ point (fromDensity e) 1
    
conditioned, unconditioned :: Typeable a => Dist a -> Measure a
conditioned   dist = Measure (\(cond:conds) -> smap (\a->(a,conds))
                                               (updateMixture cond    dist))
unconditioned dist = Measure (\      conds  -> smap (\a->(a,conds))
                                               (updateMixture Nothing dist))

factor :: Prob -> Measure ()
factor p = Measure (\conds -> deterministic (point ((), conds) p))

condition :: Eq b => Measure (a, b) -> b -> Measure a
condition m b' =
    Measure (\ conds -> 
      sbind (unMeasure m conds)
            (\ ((a,b), cds) ->
                 deterministic (if b==b' then point (a,cds) 1 else empty)))

-- Drivers for testing
finish :: Mixture (a, [Cond]) -> Mixture a
finish (Mixture m) = Mixture (M.mapKeysMonotonic (\(a,[]) -> a) m)

empiricalMeasure :: (PrimMonad m, Ord a) => Int -> Measure a -> [Cond] -> m (Mixture a)
empiricalMeasure !n measure conds = do
  gen <- MWC.create
  go n gen empty
    where once = unMeasure measure conds
          go 0 _ m = return m
          go k g m = once g >>= \result -> go (k - 1) g $! mappend m (finish result)

sample :: Measure a -> [Cond] -> IO [(a, Prob)]
sample measure conds = do
  gen <- MWC.create
  unsafeInterleaveIO $ sampleNext gen
      where once = unMeasure measure conds
            mixToTuple = head . M.toList . unMixture
            sampleNext g = do
              u <- once g
              let x = mixToTuple (finish u)
              xs <- unsafeInterleaveIO $ sampleNext g
              return (x : xs)
 --  u <- once gen
 --  let x = mixToTuple (finish u)
 --  xs <- unsafeInterleaveIO $ sample measure conds gen
 --  return (x : xs)
 -- where once = unMeasure measure conds
 --       mixToTuple = head . M.toList . unMixture

logit :: Floating a => a -> a
logit !x = 1 / (1 + exp (- x))