{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
module Language.Kuifje.Distribution where
import Prelude hiding (filter, foldr, return, (>>=))
import Data.List (genericLength)
import Data.Map.Strict
-- | Probabilities are represented exactly as rationals, so sums and
-- products of probabilities introduce no rounding error.
type Prob = Rational

-- | A finite distribution: a map from outcomes to their probabilities.
--
-- The map may contain entries whose probability is 0 (see 'unpackD' and
-- 'reduction' for dropping them). Nothing in the representation forces the
-- probabilities to sum to 1 (see 'weight').
newtype Dist a = D
  { runD :: Map a Prob
  }

-- | A hyper-distribution: a distribution whose outcomes are themselves
-- distributions.
type Hyper a = Dist (Dist a)
-- | Push a distribution forward through a function.
--
-- Outcomes that @f@ sends to the same value have their probabilities added.
-- The result equals @dx >>= (return . f)@; the keys of the underlying map
-- are relabelled directly.
fmap :: (Ord b) => (a -> b) -> Dist a -> Dist b
fmap f (D m) = D (mapKeysWith (+) f m)
-- | The point distribution: its argument with probability 1.
return :: (Ord a) => a -> Dist a
return outcome = D (singleton outcome 1)
-- | Synonym for 'return': the distribution that gives its argument
-- probability 1.
point :: Ord a => a -> Dist a
point x = return x
-- | Monadic bind: draw @x@ from @d@, then draw an outcome from @f x@.
--
-- Each inner probability is multiplied by the outer probability of the
-- @x@ that produced it, and the probabilities of equal outcomes are added.
-- Zero-probability entries are kept, not removed.
(>>=) :: (Ord b) => Dist a -> (a -> Dist b) -> Dist b
d >>= f = D (fromListWith (+) (concatMap scaled (toList (runD d))))
  where
    -- The entries of @f x@, each scaled by the outer probability @p@.
    scaled (x, p) = [(y, p * q) | (y, q) <- toList (runD (f x))]
-- | Flatten a hyper-distribution into a distribution over inner outcomes.
-- Each inner distribution is weighted by its outer probability.
join :: (Ord a) => Hyper a -> Dist a
join = (>>= id)
-- | Distributions are compared after dropping zero-probability entries.
-- A distribution containing an outcome with probability 0 therefore equals
-- the same distribution without that outcome.
instance Ord a => Eq (Dist a) where
  lhs == rhs = unpackD lhs == unpackD rhs
-- | Orders distributions by their maps after zero-probability entries are
-- dropped, consistently with the 'Eq' instance.
--
-- 'compare' is defined directly, rather than only '<='. With only '<=', the
-- class default @compare@ first calls '==' and then '<='. That normalises and
-- traverses both operands twice per comparison. It matters because a 'Hyper'
-- stores distributions as 'Map' keys, so every insertion and lookup compares
-- them. The derived '<=', '<', etc. give the same answers as before.
instance Ord a => Ord (Dist a) where
  compare d1 d2 = compare (unpackD d1) (unpackD d2)
-- | Uniform distribution over the elements of a list.
--
-- Each list element contributes @1 / length l@, so repeated elements add up:
-- @uniform [x, x, y]@ gives @x@ probability 2/3. The empty list yields the
-- empty distribution, whose total weight is 0.
uniform :: (Ord a) => [a] -> Dist a
uniform l = D $ fromListWith (+) [(x, p) | x <- l]
  where
    -- Computed once instead of once per element, so construction stays linear.
    -- It is forced only when some element exists, so @uniform []@ never
    -- divides by zero.
    p = 1 / genericLength l
-- | Binary choice: @x@ with probability @p@ and @y@ with probability @1 - p@.
-- When @x == y@ the two weights are added, so that outcome gets
-- probability 1.
--
-- NOTE(review): @p@ is not checked to lie in [0, 1]. Values outside that
-- range produce negative entries.
choose :: (Ord a) => Prob -> a -> a -> Dist a
choose p x y = D (insertWith (+) y (1 - p) (singleton x p))
-- | The underlying map of a distribution, with every entry whose probability
-- is exactly 0 removed.
unpackD :: Dist a -> Map a Prob
unpackD (D m) = filter (/= 0) m
-- | Normalise the representation by dropping zero-probability entries.
-- The result is equal (under '==') to the input.
reduction :: Dist a -> Dist a
reduction d = D (unpackD d)
-- | Total probability mass of a distribution: the sum of all its entries.
-- This is 1 for a proper distribution but may be anything else, e.g. 0 for
-- the empty distribution.
--
-- The sum uses the strict right fold 'foldr'', which evaluates each partial
-- sum as the map is traversed. The lazy 'foldr' first builds an O(n) chain of
-- suspended additions.
weight :: Dist a -> Prob
weight (D l) = foldr' (+) 0 l