{-|
Module      : Control.Monad.Bayes.Population
Description : Representation of distributions using multiple samples
Copyright   : (c) Adam Scibior, 2015-2020
License     : MIT
Maintainer  : leonhard.markert@tweag.io
Stability   : experimental
Portability : GHC

'Population' turns a single sample into a collection of weighted samples.
-}

module Control.Monad.Bayes.Population
  ( Population
  , runPopulation
  , explicitPopulation
  , fromWeightedList
  , spawn
  , resampleMultinomial
  , resampleSystematic
  , extractEvidence
  , pushEvidence
  , proper
  , evidence
  , collapse
  , mapPopulation
  , normalize
  , popAvg
  , flatten
  , hoist
  ) where

-- 'sum' is taken from "Numeric.Log" (log-domain summation) instead.
import Prelude hiding (sum, all)

import Control.Arrow (second)
import Control.Monad (replicateM)
import Control.Monad.Trans
-- NOTE(review): 'ListT' from transformers is deprecated and was removed in
-- later transformers releases; confirm the pinned version provides it.
import Control.Monad.Trans.List
import qualified Data.List
import qualified Data.Vector as V

import Numeric.Log

import Control.Monad.Bayes.Class
import Control.Monad.Bayes.Weighted hiding (flatten, hoist)

-- | A collection of weighted samples, or particles.
--
-- Each particle is one branch of the 'ListT' layer, and its weight is
-- tracked by the 'Weighted' layer wrapped around it.
newtype Population m a = Population (Weighted (ListT m) a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadIO
    , MonadSample
    , MonadCond
    , MonadInfer
    )

-- | Lifting passes the inner action through both the 'ListT' and the
-- 'Weighted' layers.
instance MonadTrans Population where
  lift action = Population (lift (lift action))

-- | Explicit representation of the weighted sample with weights in the log
-- domain.
runPopulation :: Functor m => Population m a -> m [(a, Log Double)]
runPopulation (Population weighted) = runListT (runWeighted weighted)

-- | Explicit representation of the weighted sample, with every weight
-- converted out of the log domain into a plain 'Double'.
explicitPopulation :: Functor m => Population m a -> m [(a, Double)]
explicitPopulation population = map (second toDouble) <$> runPopulation population
  where
    toDouble = exp . ln

-- | Initialize 'Population' with a concrete weighted sample.
fromWeightedList :: Monad m => m [(a, Log Double)] -> Population m a
fromWeightedList weighted = Population (withWeight (ListT weighted))

-- | Increase the sample size by a given factor.
-- The weights are adjusted such that their sum is preserved.
-- It is therefore safe to use 'spawn' in arbitrary places in the program
-- without introducing bias. A factor of zero or less yields an empty
-- population.
spawn :: Monad m => Int -> Population m ()
-- Each of the n copies carries 1/n of its parent's weight.
spawn n = fromWeightedList $ pure $ replicate n ((), 1 / fromIntegral n)

-- | Shared implementation of the resampling schemes below.
--
-- The resampler is handed the weights divided by their total @z@ and
-- returns the indices of the selected ancestors. Each offspring is given
-- weight @z / n@, where @n@ is the population size. The total weight is
-- therefore preserved as long as the resampler returns @n@ indices, and both
-- resamplers in this module do.
--
-- When the total weight is zero, which includes the empty population, the
-- population is returned unchanged.
resampleGeneric
  :: MonadSample m
  => (V.Vector Double -> m [Int]) -- ^ resampler
  -> Population m a
  -> Population m a
resampleGeneric resampler m = fromWeightedList $ do
  pop <- runPopulation m
  let (xs, ps) = unzip pop
  let n = length xs
  let z = sum ps
  if z > 0
    then do
      let weights = V.fromList (map (exp . ln . (/ z)) ps)
      ancestors <- resampler weights
      let xvec = V.fromList xs
      let offsprings = map (xvec V.!) ancestors
      return $ map (, z / fromIntegral n) offsprings
    else
      -- if all weights are zero do not resample
      return pop

-- | Systematic resampling helper.
--
-- The first argument @u@ is expected to lie in [0, 1], and @ps@ holds the
-- normalized weights. The helper places @n = length ps@ points
-- @u/n, u/n + 1/n, ...@ and returns, for each point, the index of the
-- cumulative-weight interval it falls into. The indices are returned in
-- reverse order of the points.
--
-- Because the weights come from floating-point normalization, their running
-- sum can fall slightly short of the last points. Such points are assigned
-- to the last index with positive weight, instead of reading past the end of
-- the vector (which would raise an index-out-of-bounds error).
systematic :: Double -> V.Vector Double -> [Int]
systematic u ps = go 0 (u / fromIntegral n) 0 0 []
  where
    n = V.length ps
    inc = 1 / fromIntegral n
    -- Fallback ancestor for points beyond the cumulative sum. It is only
    -- forced when that situation actually occurs.
    lastPositive =
      maybe (n - 1) (\k -> n - 1 - k) (V.findIndex (> 0) (V.reverse ps))
    -- i: points placed so far; v: current point;
    -- j: weights accumulated so far; q: their cumulative sum.
    go i v j q acc
      | i == n = acc
      | v < q = go (i + 1) (v + inc) j q (j - 1 : acc)
      | j == n = go (i + 1) (v + inc) j q (lastPositive : acc)
      | otherwise = go i v (j + 1) (q + ps V.! j) acc

-- | Resample the population using the underlying monad and a systematic
-- resampling scheme. The total weight is preserved.
resampleSystematic :: (MonadSample m) => Population m a -> Population m a
resampleSystematic = resampleGeneric (\ps -> (`systematic` ps) <$> random)

-- | Multinomial resampler. It draws as many ancestor indices as there are
-- weights, each one independently via 'categorical'.
multinomial :: MonadSample m => V.Vector Double -> m [Int]
multinomial ps = replicateM (V.length ps) (categorical ps)

-- | Resample the population using the underlying monad and a multinomial
-- resampling scheme. The total weight is preserved.
resampleMultinomial :: (MonadSample m) => Population m a -> Population m a
resampleMultinomial = resampleGeneric multinomial

-- | Separate the sum of weights into the 'Weighted' transformer.
-- Weights are normalized after this operation. If they sum to zero, every
-- particle is given the same weight instead.
extractEvidence :: Monad m => Population m a -> Population (Weighted m) a
extractEvidence population = fromWeightedList $ do
  particles <- lift (runPopulation population)
  let (values, weights) = unzip particles
      total = sum weights
      -- A zero total cannot be divided by, so fall back to uniform weights.
      rescale
        | total > 0 = (/ total)
        | otherwise = const (1 / fromIntegral (length weights))
  -- Record the total weight (the evidence estimate) in the outer 'Weighted'.
  factor total
  pure (zip values (map rescale weights))

-- | Push the evidence estimator as a score to the transformed monad.
-- Weights are normalized after this operation.
pushEvidence :: MonadCond m => Population m a -> Population m a
pushEvidence population = hoist applyWeight (extractEvidence population)

-- | A properly weighted single sample. The value is picked at random
-- according to the weights. It is returned in 'Weighted' together with the
-- sum of all weights, which 'extractEvidence' records there.
proper :: (MonadSample m) => Population m a -> Weighted m a
proper population = do
  particles <- runPopulation (extractEvidence population)
  let (values, logWeights) = unzip particles
  chosen <- logCategorical (V.fromList logWeights)
  pure (values !! chosen)

-- | Model evidence estimator, also known as pseudo-marginal likelihood.
evidence :: (Monad m) => Population m a -> m (Log Double)
evidence population = extractWeight (runPopulation (extractEvidence population))

-- | Picks one point from the population and uses model evidence as a 'score'
-- in the transformed monad.
-- This way a single sample can be selected from a population without
-- introducing bias.
collapse :: (MonadInfer m) => Population m a -> m a
collapse population = applyWeight (proper population)

-- | Applies a random transformation to a population. The transformation
-- sees, and returns, the explicit list of particles with log-domain weights.
mapPopulation ::
  (Monad m) =>
  ([(a, Log Double)] -> m [(a, Log Double)]) ->
  Population m a ->
  Population m a
mapPopulation f m = fromWeightedList (f =<< runPopulation m)

-- | Normalizes the weights in the population so that their sum is 1.
-- This transformation introduces bias.
normalize :: (Monad m) => Population m a -> Population m a
normalize population = hoist prior (extractEvidence population)

-- | Population average of a function, computed using unnormalized weights.
-- The result is the sum of @f x * w@ over all particles, without dividing by
-- the total weight.
popAvg :: (Monad m) => (a -> Double) -> Population m a -> m Double
popAvg f p = weightedTotal <$> explicitPopulation p
  where
    weightedTotal = Data.List.sum . map (\(x, w) -> f x * w)

-- | Combine a population of populations into a single population.
--
-- Every inner particle's weight is multiplied by the weight of the outer
-- particle that produced it.
flatten :: Monad m => Population (Population m) a -> Population m a
flatten nested = fromWeightedList (merge <$> runPopulation (runPopulation nested))
  where
    merge outer = [(y, p * q) | (inner, p) <- outer, (y, q) <- inner]

-- | Applies a transformation to the inner monad.
hoist ::
  (Monad m, Monad n) =>
  (forall x. m x -> n x) ->
  Population m a ->
  Population n a
hoist f population = fromWeightedList (f (runPopulation population))