module Data.Random.Distribution.Multinomial where
import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Binomial
-- | Random variable for a multinomial draw.
--
-- The first argument is the list of per-category weights and the second is
-- the total count to distribute; the result has one count per weight.
-- This simply wraps the 'Multinomial' distribution value with 'rvar'.
multinomial :: Distribution (Multinomial p) [a] => [p] -> a -> RVar [a]
multinomial ps = rvar . Multinomial ps
-- | Transformer variant of the multinomial draw.
--
-- Takes the per-category weights and the total count, and wraps the
-- corresponding 'Multinomial' distribution value with 'rvarT' so it can be
-- used in any 'RVarT' stack.
multinomialT :: Distribution (Multinomial p) [a] => [p] -> a -> RVarT m [a]
multinomialT ps = rvarT . Multinomial ps
-- | Multinomial distribution, indexed by the weight type @p@ and the type of
-- the sampled value.
data Multinomial p a where
    -- | Pairs a list of per-category weights with a total count.  The
    -- sampled value is a list of counts, hence the result index is a list
    -- of the count type.
    Multinomial :: [weight] -> count -> Multinomial weight [count]
-- | Samples a multinomial as a chain of conditional binomial draws.
--
-- Each category's count is drawn from a binomial over the trials that are
-- still unassigned, with success probability equal to the category's weight
-- divided by the total weight of the categories not yet processed.  Because
-- of that division the weights do not have to sum to one.
--
-- The result has one count per weight; an empty weight list yields an empty
-- result, and the final category always receives every remaining trial.
instance (Num a, Eq a, Fractional p, Distribution (Binomial p) a) => Distribution (Multinomial p) [a] where
    rvarT (Multinomial ps0 t) = go t ps0 (tailSums ps0) id
        where
            -- Arguments: trials still unassigned, remaining weights, the
            -- suffix totals aligned with those weights, and a
            -- difference-list accumulator that keeps the output in input
            -- order without reversing at the end.
            go _ []     _            f = return (f [])
            -- The last category takes whatever is left; no draw is needed.
            go n [_]    _            f = return (f [n])
            -- Nothing left to distribute: pad with zeros and skip sampling.
            go 0 (_:ps) (_:psums)    f = go 0 ps psums (f . (0:))
            go n (p:ps) (psum:psums) f = do
                x <- binomialT n (p / psum)
                -- Only the trials not assigned to this category carry on.
                go (n - x) ps psums (f . (x:))
            go _ _ _ _ = error "rvar/Multinomial: programming error! this case should be impossible!"

            -- Suffix sums with a trailing zero, so the list is always one
            -- element longer than the weights:
            -- tailSums [a, b, c] == [a + b + c, b + c, c, 0].
            tailSums weights = scanr (+) 0 weights