{-# LANGUAGE GADTs, MultiParamTypeClasses, FlexibleInstances, FlexibleContexts #-}
module Data.Random.Distribution.Multinomial where

import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Binomial

-- | Build a multinomial sample as an 'RVar'.
--
-- @multinomial ps n@ describes a draw of @n@ trials spread over the outcomes
-- weighted by @ps@.  It is simply 'rvar' applied to the 'Multinomial'
-- constructor; the actual sampling algorithm is whatever the
-- @Distribution (Multinomial p) [a]@ instance in scope provides.
multinomial :: Distribution (Multinomial p) [a] => [p] -> a -> RVar [a]
multinomial ps n = rvar (Multinomial ps n)

-- | Build a multinomial sample as an 'RVarT' over any base monad @m@.
--
-- @multinomialT ps n@ is the transformer counterpart of @multinomial@: it is
-- 'rvarT' applied to the 'Multinomial' constructor, so the sampling algorithm
-- is whatever the @Distribution (Multinomial p) [a]@ instance in scope
-- provides.
multinomialT :: Distribution (Multinomial p) [a] => [p] -> a -> RVarT m [a]
multinomialT ps n = rvarT (Multinomial ps n)

-- | Description of a multinomial distribution.
--
-- The GADT return type pins the sampled value to a list: a
-- @Multinomial w [c]@ yields a @[c]@, one count per entry of the weight list.
data Multinomial w r where
    Multinomial
        :: [w]               -- ^ per-outcome weights
        -> c                 -- ^ number of trials
        -> Multinomial w [c]

-- | Samples a multinomial draw by conditional binomial decomposition.
--
-- The weights are walked left to right. The count for weight @p_i@ is drawn
-- as @binomialT remaining (p_i / (p_i + p_{i+1} + ... + p_k))@, and the last
-- outcome receives whatever trials remain. Because each weight is divided by
-- the sum of itself and all later weights, the weights do not need to sum to
-- one.
--
-- The result holds exactly one count per weight, in the same order. An empty
-- weight list yields @[]@.
--
-- NOTE(review): suppose every remaining weight is zero while trials remain
-- and more than one outcome is left. Then @p / psum@ divides by zero, and
-- what 'binomialT' does with that probability is not visible here.
instance (Num a, Eq a, Fractional p, Distribution (Binomial p) a)
      => Distribution (Multinomial p) [a] where
    -- TODO: implement faster version based on Categorical for small n, large (length ps)
    rvarT (Multinomial ps0 t) = go t (zip ps0 (scanr (+) 0 ps0)) id
      where
        -- @scanr (+) 0 ps0@ is @map sum (tails ps0)@ computed in one pass.
        -- Zipping truncates its trailing 0, so every weight is paired with
        -- the sum of itself and all later weights. Pairing up front means
        -- the two lists can never get out of step, so no "impossible" case
        -- is needed.
        --
        -- Arguments of go:
        --   * remaining trials;
        --   * the (weight, suffix sum) pairs still to process;
        --   * a difference list of the counts emitted so far. It keeps the
        --     output order without a final reverse.
        go _ []                 acc = return (acc [])
        go n [_]                acc = return (acc [n])
        -- No trials left: every remaining outcome gets 0, without sampling.
        go 0 (_ : rest)         acc = go 0 rest (acc . (0 :))
        go n ((p, psum) : rest) acc = do
            x <- binomialT n (p / psum)
            go (n - x) rest (acc . (x :))