mcmc-samplers: Combinators for MCMC sampling

Safe Haskell: None




Sampler for a two-component Gaussian mixture model

Here is the code in the Hakaru language for generating the data used in this example:

p <- unconditioned (beta 2 2)
[m1,m2] <- replicateM 2 $ unconditioned (normal 100 30)
[s1,s2] <- replicateM 2 $ unconditioned (uniform 0 2)
let makePoint = do        
      b <- unconditioned (bern p)
      unconditioned (ifThenElse b (normal m1 s1)
                                  (normal m2 s2))
replicateM nPoints makePoint



-- | State of the Gaussian-mixture sampler.
--
-- The 'GMM' constructor takes its fields in the order labels, gaussParams,
-- bernParam, obs; 'focusLabels' and 'focusObs' pattern-match on it
-- positionally, so this order must not change.
data GaussianMixtureState = GMM
  { labels      :: [Bool]
    -- ^ The list of observation labels, paired element-wise with 'obs'.
    -- 'True' assigns an observation to the first Gaussian and 'False' to the
    -- second (see 'obsTarget').
  , gaussParams :: ((Double, Double), (Double, Double))
    -- ^ The parameters of the two Gaussians, each a (mean, scale) pair.
    -- 'obsTarget' squares the scale before passing it to @normal@, so it
    -- acts as a standard deviation rather than a covariance. (This assumes
    -- @normal@ takes a variance; compare @normal 100 900@ in
    -- 'gaussParamsTarget' with @normal 100 30@ in the Hakaru model.)
  , bernParam   :: Double
    -- ^ The mixture proportion.
  , obs         :: [Double]
    -- ^ The observed data.
  }


Focus combinators

-- | Lift a target over (mixture proportion, labels) to the whole sampler
-- state by projecting out exactly those two fields.
focusLabels :: Target (Double, [Bool]) -> Target GaussianMixtureState
focusLabels t = makeTarget $ \(GMM ls _ prop _) -> density t (prop, ls)

-- | Lift a target over the two Gaussians' parameters to the whole sampler
-- state; only the 'gaussParams' field is inspected.
focusGaussParams :: Target ((Double, Double), (Double, Double)) -> Target GaussianMixtureState
focusGaussParams t = makeTarget $ \st -> density t (gaussParams st)

-- | Lift a target over the mixture proportion to the whole sampler state;
-- only the 'bernParam' field is inspected.
focusBernParam :: Target Double -> Target GaussianMixtureState
focusBernParam t = makeTarget $ \st -> density t (bernParam st)

-- | Lift a target over (labels, Gaussian parameters, observations) to the
-- whole sampler state; the mixture proportion is ignored.
focusObs :: Target ([Bool], ((Double, Double), (Double, Double)), [Double])
         -> Target GaussianMixtureState
focusObs t = makeTarget $ \(GMM ls params _ xs) -> density t (ls, params, xs)

Record field targets

-- | Prior factor on the labels given the mixture proportion @p@: the product
-- of the @bern p@ density evaluated at every label.
--
-- (The lambda's leading backslash is restored; without it this line does not
-- parse.)
labelsTarget :: Target (Double, [Bool])
labelsTarget = makeTarget $ \(p, ls) -> product $ map (density $ bern p) ls

-- | Prior factor on the Gaussian parameters. Each mean is scored under
-- @normal 100 900@ and each scale under @uniform 0 200@; the four densities
-- are multiplied together.
--
-- NOTE(review): the Hakaru model in the module header draws the scales from
-- @uniform 0 2@, not @uniform 0 200@. Confirm which prior is intended.
gaussParamsTarget :: Target ((Double, Double), (Double, Double))
gaussParamsTarget = makeTarget $ \((mu1, sd1), (mu2, sd2)) ->
    meanDens mu1 * meanDens mu2 * scaleDens sd1 * scaleDens sd2
  where
    meanDens  = density (normal 100 900)
    scaleDens = density (uniform 0 200)

-- | Prior factor on the mixture proportion: the @beta 2 2@ distribution,
-- turned into a target via 'fromProposal'.
bernParamTarget :: Target Double
bernParamTarget = fromProposal $ beta 2 2

-- | Likelihood of the observations given their labels and the Gaussian
-- parameters. Observation i is scored under the first Gaussian when label i
-- is 'True' and under the second otherwise. Each Gaussian's scale is squared
-- before being passed to @normal@. Labels and observations are paired
-- positionally; any surplus in the longer list is ignored.
--
-- (The lambda's leading backslash is restored; without it this does not
-- parse. @zipWith@ replaces the equivalent @map@ over @zip@.)
obsTarget :: Target ([Bool], ((Double, Double), (Double, Double)), [Double])
obsTarget = makeTarget dens
  where
    dens (ls, ((m1, c1), (m2, c2)), os) =
      product (zipWith (\o l -> density (gauss l) o) os ls)
      where
        gauss l = if l then normal m1 (c1 * c1) else normal m2 (c2 * c2)

Target density factors

-- | The label prior, lifted to the full sampler state.
labelsFactor :: Target GaussianMixtureState
labelsFactor = focusLabels $ labelsTarget

-- | The Gaussian-parameter prior, lifted to the full sampler state.
gaussParamsFactor :: Target GaussianMixtureState
gaussParamsFactor = focusGaussParams $ gaussParamsTarget

-- | The mixture-proportion prior, lifted to the full sampler state.
bernParamFactor :: Target GaussianMixtureState
bernParamFactor = focusBernParam $ bernParamTarget

-- | The observation likelihood, lifted to the full sampler state.
obsFactor :: Target GaussianMixtureState
obsFactor = focusObs $ obsTarget

Target density

-- | Target density over full sampler states, built with 'productDensity'
-- from the four factors (labels, Gaussian parameters, mixture proportion,
-- observations).
gmmTarget :: Target GaussianMixtureState
gmmTarget = makeTarget (productDensity factors)
  where
    factors = [labelsFactor, gaussParamsFactor, bernParamFactor, obsFactor]


Proposal update boilerplate

-- | Lift a proposal on the labels to a proposal on the whole state.
-- Sampling replaces only the 'labels' field of @x@. The density of a
-- candidate @y@ is the field proposal's density at @y@'s labels.
updateLabels :: ([Bool] -> Proposal [Bool]) -> GaussianMixtureState -> Proposal GaussianMixtureState
updateLabels f x = makeProposal dens sf
  where
    fieldProp = f (labels x)
    dens y    = density fieldProp (labels y)
    sf g      = sampleFrom fieldProp g >>= \fresh -> return x { labels = fresh }

-- | Lift a proposal on the Gaussian parameters to a proposal on the whole
-- state. Sampling replaces only the 'gaussParams' field of @x@. The density
-- of a candidate @y@ is the field proposal's density at @y@'s parameters.
updateGaussParams :: (((Double, Double), (Double, Double)) -> Proposal ((Double, Double), (Double, Double)))
                     -> GaussianMixtureState -> Proposal GaussianMixtureState
updateGaussParams f x = makeProposal dens sf
  where
    fieldProp = f (gaussParams x)
    dens y    = density fieldProp (gaussParams y)
    sf g      = sampleFrom fieldProp g >>= \fresh -> return x { gaussParams = fresh }

-- | Lift a proposal on the mixture proportion to a proposal on the whole
-- state. Sampling replaces only the 'bernParam' field of @x@. The density of
-- a candidate @y@ is the field proposal's density at @y@'s proportion.
updateBernParam :: (Double -> Proposal Double) -> GaussianMixtureState -> Proposal GaussianMixtureState
updateBernParam f x = makeProposal dens sf
  where
    fieldProp = f (bernParam x)
    dens y    = density fieldProp (bernParam y)
    sf g      = sampleFrom fieldProp g >>= \fresh -> return x { bernParam = fresh }

Field proposals

-- | Proposal on the label vector: choose one position via 'chooseProposal'
-- and flip that label. The flip uses @bern 0@ for a 'True' label and
-- @bern 1@ for a 'False' one, presumably point masses on the opposite value.
--
-- The number of candidate positions is the length of the vector being
-- updated, not the global 'nPoints', so the proposal stays consistent for
-- data sets of any size. NOTE(review): behaviour on an empty vector depends
-- on 'chooseProposal' with a count of 0; confirm it is never called that way.
labelsProposal :: [Bool] -> Proposal [Bool]
labelsProposal ls = chooseProposal (length ls) (\n -> updateNth n flipBool ls)
    where flipBool bn = bern (if bn then 0 else 1)

-- | Proposal on the Gaussian parameters: an equally weighted mixture of four
-- single-component moves. Each move perturbs one of the two means or two
-- scales with @normal v 1@ centred on its current value @v@.
gaussParamsProposal :: ((Double, Double), (Double, Double)) -> Proposal ((Double, Double), (Double, Double))
gaussParamsProposal params =
    mixProposals [ (move, 1) | move <- moves ]
  where
    jitter v = normal v 1
    moves =
      [ updateFirst  (updateFirst  jitter) params  -- first Gaussian, mean
      , updateFirst  (updateSecond jitter) params  -- first Gaussian, scale
      , updateSecond (updateFirst  jitter) params  -- second Gaussian, mean
      , updateSecond (updateSecond jitter) params  -- second Gaussian, scale
      ]

-- | Proposal on the mixture proportion: draw uniformly from
-- @[p/2, 1 - p/2]@.
--
-- (The divisions are restored; the rendered text read @p2@, an unbound name,
-- because Haddock treats @/…/@ as emphasis.)
--
-- NOTE(review): this interval contains the current value @p@ only when
-- @p <= 2/3@. Confirm the asymmetric proposal is intended and that the
-- Metropolis-Hastings correction handles the resulting support mismatch.
bernParamProposal :: Double -> Proposal Double
bernParamProposal p = uniform (p / 2) (1 - p / 2)

Field updaters

-- | Whole-state proposal that flips a single label.
labelsUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
labelsUpdater st = updateLabels labelsProposal st

-- | Whole-state proposal that perturbs one Gaussian parameter.
gaussParamsUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
gaussParamsUpdater st = updateGaussParams gaussParamsProposal st

-- | Whole-state proposal that redraws the mixture proportion.
bernParamUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
bernParamUpdater st = updateBernParam bernParamProposal st

The combined proposal

-- | Combined proposal: a weighted mixture of the three field updaters. Label
-- flips have weight 10, Gaussian-parameter moves 1, and mixture-proportion
-- moves 2.
gmmProposal :: GaussianMixtureState -> Proposal GaussianMixtureState
gmmProposal = mixCondProposals
    [ (labelsUpdater,      10)
    , (gaussParamsUpdater, 1)
    , (bernParamUpdater,   2)
    ]

Running the sampler

Transition kernel

-- | Metropolis-Hastings transition kernel targeting 'gmmTarget' with
-- proposals from 'gmmProposal'.
gmmMH :: Step GaussianMixtureState
gmmMH = metropolisHastings target proposal
  where
    target   = gmmTarget
    proposal = gmmProposal

Visualization methods

-- | Count how many times each distinct element occurs in the list.
--
-- Built with 'Map.fromListWith' rather than a lazy 'foldl', which would
-- build a chain of unevaluated inserts proportional to the input length
-- before doing any work.
histogram :: Ord a => [a] -> Map.Map a Int
histogram ls = Map.fromListWith (+) [ (e, 1) | e <- ls ]

-- | Project each state onto its labels, Gaussian parameters and mixture
-- proportion. The observations are dropped.
printFields :: PrintF GaussianMixtureState ([Bool], ((Double, Double), (Double, Double)), Double)
printFields = map $ \st -> (labels st, gaussParams st, bernParam st)

-- | Project each state onto its @n@-th label, counting from 1.
-- Uses '!!', so an @n@ outside @[1, length labels]@ raises an error.
printLabelN :: Int -> PrintF GaussianMixtureState Bool
printLabelN n = map $ \st -> labels st !! (n - 1)

-- | Project each state onto the pair of its @n@-th and @m@-th labels
-- (1-based). Uses '!!', so out-of-range indices raise an error.
compareLabels :: Int -> Int -> PrintF GaussianMixtureState (Bool,Bool)
compareLabels n m = map $ \st ->
    let ls = labels st
    in (ls !! (n - 1), ls !! (m - 1))

-- | Print a histogram of the projected values of a batch's samples.
-- Empty batches produce no output.
printHist :: (Ord s, Show s) => PrintF x s -> Batch x -> IO ()
printHist f (samples, _)
  | null samples = return ()
  | otherwise    = print (histogram (f samples))

-- | Batch action that prints a histogram of the projected values for each
-- batch of @n@ samples. The same histogram printer is also passed as the
-- first argument to 'pack'.
batchHist :: (Ord s, Show s) => PrintF x s -> Int -> BatchAction x IO ()
batchHist f n = pack showHist (inBatches showHist n)
  where
    showHist = printHist f


-- | Number of observations in the example data set (currently 6).
-- Derived from 'sampleData' instead of being hard-coded, so the two cannot
-- drift apart when the data changes.
nPoints :: Int
nPoints = length sampleData

-- | The example observations. Three values lie near 62-64 and two near 132;
-- the remaining value, 62.14..., is also in the lower group.
sampleData :: [Double]
sampleData =
  [ 63.13941114139962
  , 132.02763712240528
  , 62.59642260289356
  , 132.2616834236893
  , 64.10610391933461
  , 62.143820541377934
  ]

-- | Initial sampler state. The first three observations are labelled 'True'
-- and the last three 'False'. The Gaussians start at means 63 and 132, each
-- with scale 100, and the mixture proportion starts at 0.5.
--
-- NOTE(review): several starting labels disagree with the nearest mean
-- (e.g. 132.03 is labelled 'True' while the first mean is 63). This is
-- presumably deliberate, so the sampler has to move the labels; confirm.
gmmStart :: GaussianMixtureState
gmmStart = GMM
  { labels      = [True, True, True, False, False, False]
  , gaussParams = ((63, 100), (132, 100))
  , bernParam   = 0.5
  , obs         = sampleData
  }

-- | Run 'gmmMH' for 10^6 steps from 'gmmStart', using a fresh system-seeded
-- generator and the collector @every 50 collect@. Then print @"Done"@ and
-- the labels of the first 20 collected states.
--
-- A label-histogram action can be built with
-- @every 50 (batchHist (compareLabels 5 6) 50)@; it is not used here.
-- (The previous version bound that action but never used it.)
gmmTest :: IO ()
gmmTest = do
  g <- MWC.createSystemRandom
  let collector = every 50 collect
  -- Explicit exponent type avoids a -Wtype-defaults warning; value is 10^6.
  ls <- walk gmmMH gmmStart (10 ^ (6 :: Int)) g collector
  putStrLn "Done"
  print $ take 20 (map labels ls)