{-# LANGUAGE
MultiParamTypeClasses,
FlexibleInstances, FlexibleContexts,
CPP
#-}
{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-}
module Data.Random.Distribution.Categorical
( Categorical
, categorical, categoricalT
, weightedCategorical, weightedCategoricalT
, fromList, toList, totalWeight, numEvents
, fromWeightedList, fromObservations
, mapCategoricalPs, normalizeCategoricalPs
, collectEvents, collectEventsBy
) where
import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Control.Arrow
import Control.Monad
import Control.Monad.ST
import Data.STRef
import Data.List
import Data.Function
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
-- | Construct a random variable that samples from a categorical
-- distribution given as @(weight, event)@ pairs.
--
-- The weights are handed to 'fromList' unchanged (no normalization);
-- use 'weightedCategorical' to normalize them first.
categorical :: (Num p, Distribution (Categorical p) a) => [(p,a)] -> RVar a
categorical weighted = rvar (fromList weighted)
-- | Transformer version of 'categorical': sample from the distribution
-- described by @(weight, event)@ pairs inside any 'RVarT'.
--
-- The weights are passed to 'fromList' unchanged (no normalization).
categoricalT :: (Num p, Distribution (Categorical p) a) => [(p,a)] -> RVarT m a
categoricalT weighted = rvarT (fromList weighted)
-- | Like 'categorical', but the weights are first normalized with
-- 'fromWeightedList'. The weights are rescaled to sum to 1 and
-- zero-weight events are dropped.
weightedCategorical :: (Fractional p, Eq p, Distribution (Categorical p) a) => [(p,a)] -> RVar a
weightedCategorical weighted = rvar (fromWeightedList weighted)
-- | Transformer version of 'weightedCategorical'. The weights are
-- normalized with 'fromWeightedList' before sampling.
weightedCategoricalT :: (Fractional p, Eq p, Distribution (Categorical p) a) => [(p,a)] -> RVarT m a
weightedCategoricalT weighted = rvarT (fromWeightedList weighted)
{-# INLINE fromList #-}
-- | Build a 'Categorical' from @(weight, event)@ pairs, keeping their order.
--
-- The weights are stored as running totals, so the last entry holds the
-- total weight. No normalization is done. Zero-weight and repeated
-- events are kept as separate entries.
fromList :: (Num p) => [(p,a)] -> Categorical p a
fromList events = Categorical $ V.fromList $ scanl1 accumulate events
    where
        -- add this event's weight to the running total of everything before it
        accumulate (runningTotal, _) (w, e) = (runningTotal + w, e)
{-# INLINE toList #-}
-- | Recover the per-event @(weight, event)@ pairs by differencing
-- consecutive cumulative weights. This is the inverse of 'fromList'; for
-- floating-point weights the recovered values may differ by rounding.
toList :: (Num p) => Categorical p a -> [(p,a)]
toList (Categorical cumulative) = V.foldr' peel [] cumulative
    where
        -- The result is built right to left. The head of the accumulator
        -- still carries a cumulative weight, so subtracting the current
        -- cumulative weight from it yields that next event's own weight.
        peel here [] = [here]
        peel here@(cumHere, _) ((cumNext, e) : rest) = here : (cumNext - cumHere, e) : rest
-- | Total weight of all events: the last cumulative weight, or 0 when the
-- distribution has no events.
totalWeight :: Num p => Categorical p a -> p
totalWeight (Categorical ds) = if V.null ds then 0 else fst (V.last ds)
-- | Number of stored entries. Because 'fromList' keeps zero-weight and
-- repeated events, each such entry is counted separately.
numEvents :: Categorical p a -> Int
numEvents (Categorical entries) = V.length entries
-- | Like 'fromList', but the weights are normalized with
-- 'normalizeCategoricalPs'. They are rescaled to sum to 1 and
-- zero-weight events are dropped. A list whose total weight is 0 yields
-- the empty distribution.
fromWeightedList :: (Fractional p, Eq p) => [(p,a)] -> Categorical p a
fromWeightedList weighted = normalizeCategoricalPs (fromList weighted)
-- | Empirical distribution of a sample. Each distinct observation is
-- weighted by how often it occurs, and the result is normalized by
-- 'fromWeightedList'. Events appear in ascending order.
fromObservations :: (Fractional p, Eq p, Ord a) => [a] -> Categorical p a
fromObservations observations =
    fromWeightedList
        [ (genericLength run, representative)
        -- 'group' never yields an empty run, so this pattern always matches
        | run@(representative : _) <- group (sort observations)
        ]
-- | A categorical distribution over events of type @a@ with weights of
-- type @p@.
--
-- Internally each event is stored next to its /cumulative/ weight (see
-- 'fromList'), so the last entry carries the total weight. The derived
-- 'Eq' compares these stored vectors directly. Two values that describe
-- the same distribution but list events in a different order, or contain
-- extra zero-weight or duplicate entries, therefore compare unequal.
newtype Categorical p a = Categorical (V.Vector (p, a))
    deriving (Eq)
-- | Renders as @fromList [(weight, event), ...]@ using the per-event
-- weights from 'toList'. This is the form parsed by the 'Read' instance.
instance (Num p, Show p, Show a) => Show (Categorical p a) where
    showsPrec d cat = showParen (d > 10) $
        showString "fromList " . showsPrec 11 (toList cat)
-- | Parses @fromList [(weight, event), ...]@. The list holds per-event
-- weights, which 'fromList' re-accumulates.
instance (Num p, Read p, Read a) => Read (Categorical p a) where
    readsPrec d = readParen (d > 10) $ \input ->
        [ (fromList pairs, rest)
        | ("fromList", afterKeyword) <- lex input
        , (pairs, rest) <- readsPrec 11 afterKeyword
        ]
-- | Sampling draws @u@ with 'uniformT' between 0 and the total weight. It
-- then returns the event at the least index whose cumulative weight is
-- @>= u@, found by binary search. This assumes non-negative weights, so
-- the cumulative weights are non-decreasing.
--
-- Sampling an empty distribution is an 'error'. A one-event distribution
-- returns its event without drawing any randomness.
instance (Fractional p, Ord p, Distribution Uniform p) => Distribution (Categorical p) a where
    rvarT (Categorical ds) = case V.length ds of
        0 -> error "categorical distribution over empty set cannot be sampled"
        1 -> return (snd (V.head ds))
        n -> do
            u <- uniformT 0 (fst (V.last ds))
            let cumAt i   = fst (ds V.! i)
                eventAt i = snd (ds V.! i)
                -- Invariants: lo <= hi, the answer lies in [lo, hi], and
                -- (lo == 0 || u > cumAt lo). Each step strictly shrinks hi - lo.
                search lo hi
                    | hi <= lo       = eventAt hi
                    | u <= cumAt mid = search lo mid
                    | otherwise      = search (max mid (lo + 1)) hi
                    where
                        -- rounds down, so lo < hi implies mid < hi
                        mid = (lo + hi) `div` 2
            return $! if u <= 0 then eventAt 0 else search 0 (n - 1)
-- | Relabels events while leaving their (cumulative) weights untouched.
-- Events that become equal are not merged.
instance Functor (Categorical p) where
    fmap f (Categorical ds) = Categorical (V.map relabel ds)
        where
            -- lazy pattern: the pair is not forced until a component is needed
            relabel ~(w, e) = (w, f e)
-- | Folds over the events in stored order, ignoring their weights.
instance Foldable (Categorical p) where
    foldMap f (Categorical ds) = foldMap f (map snd (V.toList ds))
-- | Traverses the events in stored order, keeping each event's weight
-- paired with its result. 'sequenceA' uses the class default
-- (@traverse id@), which is exactly the former hand-written definition.
instance Traversable (Categorical p) where
    traverse f (Categorical ds) =
        Categorical . V.fromList <$> traverse visit (V.toList ds)
        where
            visit (w, e) = (,) w <$> f e
-- | Sequencing builds the joint distribution. Each outcome of the first
-- stage is paired with each outcome of the continuation, and their
-- probabilities are multiplied. Equal outcomes reached along different
-- paths are not merged; use 'collectEvents' for that.
instance Fractional p => Monad (Categorical p) where
    -- 'return' is deliberately left to its default ('pure').
    -- Overriding it while 'pure' points back at 'return' is the
    -- non-canonical form flagged by -Wnoncanonical-monad-instances.
#if __GLASGOW_HASKELL__ < 808
    fail _ = Categorical V.empty
#endif
    xs >>= f = fromList $ do
        (p, x) <- toList xs
        (q, y) <- toList (f x)
        return (p * q, y)

#if __GLASGOW_HASKELL__ >= 808
-- | Since GHC 8.8, 'fail' lives in 'MonadFail'. This instance keeps the
-- behaviour of the older in-'Monad' definition: a failed pattern match
-- in @do@ notation yields the empty distribution.
instance Fractional p => MonadFail (Categorical p) where
    fail _ = Categorical V.empty
#endif

-- | 'pure' is the point mass (weight 1) on a single event. '<*>' is
-- derived from the 'Monad' instance.
instance Fractional p => Applicative (Categorical p) where
    pure x = Categorical (V.singleton (1, x))
    (<*>) = ap
-- | Apply a function to each event's individual weight (not the running
-- totals) and rebuild the distribution. Events keep their order and no
-- normalization is done.
mapCategoricalPs :: (Num p, Num q) => (p -> q) -> Categorical p e -> Categorical q e
mapCategoricalPs f cat = fromList [ first f entry | entry <- toList cat ]
-- | Rescale the weights so that they sum to 1, and drop every event whose
-- weight is zero (an entry whose cumulative weight equals the last kept
-- one).
--
-- A distribution with total weight 0, including the empty one, becomes
-- the empty distribution. Otherwise the final cumulative weight is set
-- to exactly 1, so rescaling round-off cannot leave the top of the range
-- uncovered.
normalizeCategoricalPs :: (Fractional p, Eq p) => Categorical p e -> Categorical p e
normalizeCategoricalPs orig@(Categorical ds)
    | total == 0 = Categorical V.empty
    | otherwise  = runST $ do
        prevRef  <- newSTRef 0   -- cumulative weight of the last kept entry
        dropsRef <- newSTRef 0   -- entries dropped so far
        out      <- V.thaw ds    -- compacted in place; prefix is frozen below
        forM_ [0 .. len - 1] $ \i -> do
            let (cum, e) = ds V.! i
            prev <- readSTRef prevRef
            if cum == prev
                then modifySTRef' dropsRef (1 +)
                else do
                    drops <- readSTRef dropsRef
                    MV.write out (i - drops) (cum * scale, e)
                    writeSTRef prevRef $! cum
        drops <- readSTRef dropsRef
        let kept = len - drops
        -- pin the final cumulative weight to exactly 1
        (_, lastEvent) <- MV.read out (kept - 1)
        MV.write out (kept - 1) (1, lastEvent)
        Categorical <$> V.unsafeFreeze (MV.unsafeSlice 0 kept out)
    where
        total = totalWeight orig
        scale = recip total
        len   = V.length ds
#if __GLASGOW_HASKELL__ < 706
-- | Strict variant of 'modifySTRef': the new value is forced before it is
-- stored. Defined here only for GHC < 7.6, whose "Data.STRef" does not
-- provide it.
modifySTRef' :: STRef s a -> (a -> a) -> ST s ()
modifySTRef' ref f = readSTRef ref >>= \old -> let new = f old in new `seq` writeSTRef ref new
#endif
-- | Merge events that are equal under 'compare', summing their weights.
-- Each merged group is represented by its first event after the stable
-- sort in 'collectEventsBy', i.e. the earliest in the original order.
collectEvents :: (Ord e, Num p, Ord p) => Categorical p e -> Categorical p e
collectEvents = collectEventsBy compare mergeGroup
    where
        -- groups handed to the combiner by 'collectEventsBy' are never empty
        mergeGroup grp = let (ws, es) = unzip grp in (sum ws, head es)
-- | Generalized 'collectEvents'. The per-event @(weight, event)@ pairs
-- are stably sorted by @compareE@ on the event. Adjacent pairs whose
-- events compare 'EQ' are grouped, and @combine@ collapses each group
-- into a single pair. The result is ordered by @compareE@.
collectEventsBy :: Num p => (e -> e -> Ordering) -> ([(p,e)] -> (p,e))-> Categorical p e -> Categorical p e
collectEventsBy compareE combine cat = fromList (map combine grouped)
    where
        sorted  = sortBy (compareE `on` snd) (toList cat)
        grouped = groupBy ((\a b -> compareE a b == EQ) `on` snd) sorted