{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}

-- |
-- Module      : Control.Monad.Bayes.Enumerator
-- Description : Exhaustive enumeration of discrete random variables
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
module Control.Monad.Bayes.Enumerator
  ( Enumerator,
    logExplicit,
    explicit,
    evidence,
    mass,
    compact,
    enumerator,
    enumerate,
    expectation,
    normalForm,
    toEmpirical,
    toEmpiricalWeighted,
    normalizeWeights,
    enumerateToDistribution,
    removeZeros,
    fromList,
  )
where

import Control.Applicative (Alternative)
import Control.Arrow (second)
import Control.Monad (MonadPlus)
import Control.Monad.Bayes.Class
  ( MonadDistribution (bernoulli, categorical, logCategorical, random),
    MonadFactor (..),
    MonadMeasure,
  )
import Control.Monad.Writer (WriterT (..))
import Data.AEq (AEq, (===), (~==))
import Data.List (sortOn)
import Data.Map qualified as Map
import Data.Maybe (fromMaybe)
import Data.Monoid (Product (..))
import Data.Ord (Down (Down))
import Data.Vector qualified as VV
import Data.Vector.Generic qualified as V
import Numeric.Log as Log (Log (..), sum)

-- | Exact inference by exhaustive enumeration of discrete random variables:
-- every execution path of a program is kept, together with its weight.
--
-- Representation: a 'WriterT' over the list monad. The list holds one entry
-- per execution path, and the 'Product' of 'Log' 'Double' weights is the
-- writer's monoid, so weights multiply along each path.
newtype Enumerator a = Enumerator (WriterT (Product (Log Double)) [] a)
  deriving newtype (Functor, Applicative, Monad)
  deriving newtype (Alternative, MonadPlus)

-- | Discrete primitives are enumerated exhaustively: one execution path per
-- outcome, weighted by the outcome's probability stored in the log domain
-- ('Exp' of its 'log'). Categorical weights are used exactly as given (they
-- are not renormalized here). 'random' has infinite support and cannot be
-- enumerated, so it calls 'error'.
instance MonadDistribution Enumerator where
  random =
    error "Infinitely supported random variables not supported in Enumerator"
  bernoulli p =
    fromList
      [ (True, Exp (log p)),
        (False, Exp (log (1 - p)))
      ]
  categorical v =
    fromList [(i, Exp (log w)) | (i, w) <- zip [0 ..] (V.toList v)]

-- | Conditioning: 'score' yields a single execution path carrying the given
-- (log-domain) weight, which is then multiplied into the weight of the
-- enclosing path.
instance MonadFactor Enumerator where
  score weight = fromList [((), weight)]

-- | Empty instance; no methods are defined here.
instance MonadMeasure Enumerator

-- | Build an 'Enumerator' from values paired with their log-domain weights.
-- Each pair becomes exactly one execution path; nothing is merged,
-- reordered or normalized.
fromList :: [(a, Log Double)] -> Enumerator a
fromList pairs = Enumerator (WriterT [(x, Product w) | (x, w) <- pairs])

-- | The unnormalized posterior as a list of (value, weight) pairs, one per
-- execution path, with weights kept in the log domain.
--
-- No post-processing is applied: weights are not normalized and repeated
-- values are not aggregated.
logExplicit :: Enumerator a -> [(a, Log Double)]
logExplicit (Enumerator m) = [(x, getProduct w) | (x, w) <- runWriterT m]

-- | Like 'logExplicit', but each weight is converted out of the log domain
-- into a plain 'Double'. Weights are still neither normalized nor
-- aggregated.
explicit :: Enumerator a -> [(a, Double)]
explicit d = [(x, exp (ln w)) | (x, w) <- logExplicit d]

-- | Model evidence: the sum of the weights of all execution paths,
-- accumulated in the log domain with 'Log.sum'.
evidence :: Enumerator a -> Log Double
evidence d = Log.sum [w | (_, w) <- logExplicit d]

-- | Normalized probability of a single value.
--
-- Values missing from the output of 'enumerator' get probability 0. The
-- table is built once per partial application @mass d@, so the same
-- @mass d@ can be used to query many values cheaply.
mass :: (Ord a) => Enumerator a -> a -> Double
mass d = \x -> fromMaybe 0 (lookup x table)
  where
    table = enumerator d

-- | Aggregate weights of equal values: all weights attached to the same value
-- are summed with '(+)', so every value appears at most once in the result.
--
-- The result is ordered by total weight, heaviest first. Values with equal
-- total weight keep ascending value order, because the stable 'sortOn' is
-- applied to the value-ordered output of 'Map.toAscList'.
compact :: (Num r, Ord a, Ord r) => [(a, r)] -> [(a, r)]
compact = sortOn (Down . snd) . Map.toAscList . Map.fromListWith (+)

-- | Aggregate and normalize the weights.
--
-- The weights of all execution paths are normalized to sum to one (the
-- normalization is done on the log-domain weights and then converted to
-- 'Double'). Equal values are merged with 'compact', and values whose
-- resulting probability is exactly zero are dropped.
--
-- Like 'compact', the result is ordered by probability, highest first, with
-- ties in ascending order of value. Unlike 'normalForm', the weights are
-- normalized.
--
-- NOTE(review): when the total weight is zero, 'normalize' divides by zero;
-- confirm how callers expect such (impossible) models to be reported.
enumerator, enumerate :: (Ord a) => Enumerator a -> [(a, Double)]
enumerator d = filter ((/= 0) . snd) $ compact (zip xs ws)
  where
    -- Normalize while still in the log domain, then leave it.
    (xs, ws) = second (map (exp . ln) . normalize) $ unzip (logExplicit d)

-- | Deprecated synonym for 'enumerator'.
enumerate d = enumerator d

-- | Expected value of @f@: each path's weight is first divided by the total
-- weight ('normalizeWeights'), then converted out of the log domain and
-- used to weight @f@ applied to that path's value.
expectation :: (a -> Double) -> Enumerator a -> Double
expectation f d =
  Prelude.sum [f x * exp (ln w) | (x, w) <- normalizeWeights (logExplicit d)]

-- | Scale weights so that they sum to one by dividing each by their total.
-- An empty list stays empty; a zero total is not special-cased.
normalize :: (Fractional b) => [b] -> [b]
normalize xs = [x / total | x <- xs]
  where
    total = Prelude.sum xs

-- | Divide every weight by the sum of all weights, keeping each weight next
-- to its value and preserving the order of the list.
normalizeWeights :: (Fractional b) => [(a, b)] -> [(a, b)]
normalizeWeights ls = zip (map fst ls) (normalize (map snd ls))

-- | Canonical form: the weights from 'explicit' (not normalized) are
-- aggregated with 'compact', and entries whose weight is exactly zero are
-- removed.
normalForm :: (Ord a) => Enumerator a -> [(a, Double)]
normalForm d = [entry | entry@(_, w) <- compact (explicit d), w /= 0]

-- | Empirical distribution of a list of samples: each sample counts with
-- weight 1, equal samples are aggregated with 'compact', and the counts are
-- turned into frequencies with 'normalizeWeights'.
toEmpirical :: (Fractional b, Ord a, Ord b) => [a] -> [(a, b)]
toEmpirical ls = normalizeWeights (compact [(x, 1) | x <- ls])

-- | Like 'toEmpirical', but each sample carries its own weight. The weights
-- of equal samples are summed with 'compact' and then normalized with
-- 'normalizeWeights'.
toEmpiricalWeighted :: (Fractional b, Ord a, Ord b) => [(a, b)] -> [(a, b)]
toEmpiricalWeighted samples = normalizeWeights (compact samples)

-- | Turn an enumerated distribution into a sampler in any
-- 'MonadDistribution'. One execution path is picked with 'logCategorical',
-- using the log weights exactly as returned by 'logExplicit' (they are not
-- normalized here), and that path's value is returned.
enumerateToDistribution :: (MonadDistribution n) => Enumerator a -> n a
enumerateToDistribution model = do
  let (support, logprobs) = unzip (logExplicit model)
  chosen <- logCategorical (VV.fromList logprobs)
  return (support !! chosen)

-- | Drop every execution path whose weight is exactly zero. The remaining
-- paths are kept as they are: not merged, reordered or normalized.
removeZeros :: Enumerator a -> Enumerator a
removeZeros (Enumerator (WriterT paths)) =
  Enumerator (WriterT [path | path@(_, Product w) <- paths, w /= 0])

-- | Exact equality of the 'normalForm's: the same values with the same
-- aggregated (unnormalized) weights, in the same order.
instance (Ord a) => Eq (Enumerator a) where
  lhs == rhs = normalForm lhs == normalForm rhs

-- | Approximate equality of enumerated distributions, judged on their
-- 'normalForm's (so weights are aggregated but not normalized).
--
-- '===' requires the same values in the same order, with weights that are
-- '===' elementwise. For identical weights, 'normalForm' produces identical
-- orderings, so positional comparison is sound here.
--
-- '~==' first drops entries whose weight is '~==' 0. It then sorts the
-- remaining entries by value before comparing. The sort is needed because
-- 'normalForm' orders entries by weight: two distributions whose weights
-- differ only by rounding can list (nearly) tied values in different orders,
-- and would otherwise compare unequal.
instance (Ord a) => AEq (Enumerator a) where
  p === q = xs == ys && ps === qs
    where
      (xs, ps) = unzip (normalForm p)
      (ys, qs) = unzip (normalForm q)
  p ~== q = xs == ys && ps ~== qs
    where
      (xs, ps) = unzip (comparable p)
      (ys, qs) = unzip (comparable q)
      -- Drop (approximately) zero weights, then order by value so the
      -- comparison does not depend on the weight-based ordering.
      comparable = sortOn fst . filter (not . (~== 0) . snd) . normalForm