{-# LANGUAGE FlexibleContexts #-}
-- |
-- Module    : System.Random.MWC.CondensedTable
-- Copyright : (c) 2012 Aleksey Khudyakov
-- License   : BSD3
--
-- Maintainer  : bos@serpentine.com
-- Stability   : experimental
-- Portability : portable
--
-- Table-driven generation of random variates.  This approach can
-- generate random variates in /O(1)/ time for the supported
-- distributions, at a modest cost in initialization time.
module System.Random.MWC.CondensedTable (
    -- * Condensed tables
    CondensedTable
  , CondensedTableV
  , CondensedTableU
  , genFromTable
    -- * Constructors for tables
  , tableFromProbabilities
  , tableFromWeights
  , tableFromIntWeights
    -- ** Discrete distributions
  , tablePoisson
  , tableBinomial
    -- * References
    -- $references
  ) where

import Control.Arrow           (second,(***))

import Data.Word
import Data.Int
import Data.Bits
import qualified Data.Vector.Generic         as G
import           Data.Vector.Generic           ((++))
import qualified Data.Vector.Generic.Mutable as M
import qualified Data.Vector.Unboxed         as U
import qualified Data.Vector                 as V
import Data.Vector.Generic (Vector)
import Numeric.SpecFunctions (logFactorial)
import System.Random.Stateful

import Prelude hiding ((++))



-- | A lookup table for arbitrary discrete distributions. It allows
-- the generation of random variates in /O(1)/. Note that probability
-- is quantized in units of @1/2^32@, and all distributions with
-- infinite support (e.g. Poisson) should be truncated.
data CondensedTable v a
  = CondensedTable
      {-# UNPACK #-} !Word64 !(v a) -- Lookup limit and first table
      {-# UNPACK #-} !Word64 !(v a) -- Lookup limit and second table
      {-# UNPACK #-} !Word64 !(v a) -- Lookup limit and third table
      !(v a)                        -- Last table, used at or above the third limit

-- Implementation note. Lookup limits are stored as Word64 because they
-- must accommodate two edge cases: the first table may hold no values
-- at all, or it may hold every element. Both are easy to reach. The
-- first happens when every outcome has probability less than 1/256; the
-- second happens, for example, when two outcomes have probabilities
-- [0.5,0.5]. In the latter case the first limit is 256 * 2^24 = 2^32,
-- which does not fit in a Word32.

-- | A 'CondensedTable' that uses unboxed vectors.
type CondensedTableU = CondensedTable U.Vector

-- | A 'CondensedTable' that uses boxed vectors, and is able to hold
-- any type of element.
type CondensedTableV = CondensedTable V.Vector



-- | Generate a random value using a condensed table.
--
-- A single uniformly distributed 'Word32' is drawn from the generator
-- and used as an index into the table, so each call is /O(1)/.
genFromTable :: (StatefulGen g m, Vector v a) => CondensedTable v a -> g -> m a
{-# INLINE genFromTable #-}
genFromTable table gen = do
  w <- uniformM gen
  -- Widen to Word64: lookup limits may be as large as 2^32
  let idx = fromIntegral (w :: Word32) :: Word64
  return $! lookupTable table idx

-- Map an index in [0, 2^32) to an outcome. The three lookup limits
-- split the index range into consecutive regions:
--   * indices below the first limit use the first table, one entry per 2^24 indices;
--   * the next region uses the second table, one entry per 2^16 indices;
--   * the next uses the third table, one entry per 2^8 indices;
--   * everything else uses the last table, one entry per index.
-- No bounds checks are performed.
lookupTable :: Vector v a => CondensedTable v a -> Word64 -> a
{-# INLINE lookupTable #-}
lookupTable (CondensedTable limA tabA limB tabB limC tabC tabD) i
  | i < limA  = pick tabA (i `shiftR` 24)
  | i < limB  = pick tabB ((i - limA) `shiftR` 16)
  | i < limC  = pick tabC ((i - limB) `shiftR` 8)
  | otherwise = pick tabD (i - limC)
  where
    pick tab j = G.unsafeIndex tab (fromIntegral j)


----------------------------------------------------------------
-- Table generation
----------------------------------------------------------------

-- | Generate a condensed lookup table from a list of outcomes with
-- given probabilities. The vector should be non-empty and the
-- probabilities should be non-negative and sum to 1. If this is not
-- the case, this algorithm will construct a table for some
-- distribution that may bear no resemblance to what you intended.
--
-- Outcomes whose probability is not strictly positive are dropped
-- before the table is built. If none remain, 'error' is called.
tableFromProbabilities
    :: (Vector v (a,Word32), Vector v (a,Double), Vector v a, Vector v Word32)
       => v (a, Double) -> CondensedTable v a
{-# INLINE tableFromProbabilities #-}
tableFromProbabilities v
  | G.null positive = pkgError "tableFromProbabilities" "empty vector of outcomes"
  | otherwise       = tableFromIntWeights (G.map quantize positive)
  where
    -- Keep only outcomes with strictly positive probability
    positive = G.filter (\(_, p) -> p > 0) v
    -- Rescale a probability to units of 2^-32 and round it to a weight
    quantize (x, p) = (x, toWeight (p * scale))
    -- 2^32. N.B. This number is exactly representable.
    scale :: Double
    scale = 4.294967296e9
    -- Convert a Double weight to Word32 without overflow. Saturating is
    -- especially important when one probability is approximately 1 and
    -- the others are 0.
    toWeight :: Double -> Word32
    toWeight w
      | w > scale - 1 = maxBound
      | otherwise     = round w


-- | Same as 'tableFromProbabilities' but treats numbers as weights, not
-- probabilities. Non-positive weights are discarded, and those
-- remaining are normalized to sum to 1. Calls 'error' if no positive
-- weights remain.
tableFromWeights
    :: (Vector v (a,Word32), Vector v (a,Double), Vector v a, Vector v Word32)
       => v (a, Double) -> CondensedTable v a
{-# INLINE tableFromWeights #-}
tableFromWeights v
  | G.null positive = pkgError "tableFromWeights" "no positive weights"
  | otherwise       =
      tableFromProbabilities (G.map (\(x, w) -> (x, w / total)) positive)
  where
    positive = G.filter (\(_, w) -> w > 0) v
    -- A hand-written fold avoids requiring a 'Vector v Double' constraint
    total = G.foldl' (\acc (_, w) -> w + acc) 0 positive


-- | Generate a condensed lookup table from integer weights. Weights
-- should sum to @2^32@ at least approximately. This function will
-- correct small deviations from @2^32@ such as arising from rounding
-- errors. But for large deviations it's likely to produce an incorrect
-- result with terrible performance.
--
-- Zero weights are discarded. Calls 'error' if no non-zero weight remains.
tableFromIntWeights :: (Vector v (a,Word32), Vector v a, Vector v Word32)
                    => v (a, Word32)
                    -> CondensedTable v a
{-# INLINE tableFromIntWeights #-}
tableFromIntWeights v
  | n == 0    = pkgError "tableFromIntWeights" "empty table"
    -- Single element tables should be treated separately. Otherwise
    -- they will confuse correctWeights.
    --
    -- The only outcome must cover every index in [0, 2^32). The first
    -- table has 256 entries, each covering 2^24 indices, so its lookup
    -- limit is exactly 2^32 (representable because limits are Word64).
    -- A limit of 2^32-1 would send index 2^32-1 past every limit into
    -- the empty last table, causing an out-of-bounds read.
  | n == 1    = let m = 2^(32::Int) :: Word64
                in CondensedTable
                   m (G.replicate 256 $ fst $ G.head tbl)
                   m  G.empty
                   m  G.empty
                      G.empty
  | otherwise = CondensedTable
                na aa
                nb bb
                nc cc
                   dd
  where
    -- We must filter out zero-probability outcomes because they may
    -- confuse the weight correction algorithm
    tbl   = G.filter ((/=0) . snd) v
    n     = G.length tbl
    -- Outcomes with weights corrected to sum to 2^32
    table = uncurry G.zip $ id *** correctWeights $ G.unzip tbl
    -- Table number d repeats each outcome (digit d w) times, where
    -- digit d w is the d-th base-256 digit of its weight w
    -- (d = 0 is the most significant digit)
    mkTable d =
      G.concatMap (\(x,w) -> G.replicate (fromIntegral $ digit d w) x) table
    len xs = fromIntegral (G.length xs) :: Word64
    -- Tables
    aa = mkTable 0
    bb = mkTable 1
    cc = mkTable 2
    dd = mkTable 3
    -- Lookup limits: each entry of table d covers 2^(24 - 8*d) indices
    na =       len aa `shiftL` 24
    nb = na + (len bb `shiftL` 16)
    nc = nb + (len cc `shiftL` 8)


-- Extract digit number d (0..3) of a Word32 written in base 256.
-- Digit 0 is the most significant byte and digit 3 the least
-- significant. Any other d is an internal error.
digit :: Int -> Word32 -> Word32
digit d x
  | d < 0 || d > 3 = pkgError "digit" "the impossible happened!?"
  | otherwise      = (x `shiftR` (24 - 8 * d)) .&. 0xff
{-# INLINE digit #-}

-- Correct integer weights so they sum up to 2^32. The array of weights
-- should contain at least 2 elements.
--
-- The difference between the current sum and 2^32 is spread one unit
-- at a time over the entries. On the first sweep only entries that are
-- at least 255 are touched. Every later sweep touches any entry that
-- is at least 1. Each touched entry is incremented if the sum is too
-- small and decremented if it is too large.
--
-- It is possible to make this algorithm loop endlessly if all weights
-- are 1 or 0.
correctWeights :: G.Vector v Word32 => v Word32 -> v Word32
{-# INLINE correctWeights #-}
correctWeights v = G.create $ do
  arr <- G.thaw v
  let size   = G.length v
      -- Sum in Int64 so that it cannot overflow
      total  = G.foldl' (\acc w -> fromIntegral w + acc) 0 v :: Int64
      excess = total - 2^(32::Int)
      go lim i delta
        | delta == 0 = return ()
        | i >= size  = go 1 0 delta
        | otherwise  = do
            w <- M.read arr i
            if w < lim
              then go lim (i + 1) delta
              else if delta < 0
                then M.write arr i (w + 1) >> go lim (i + 1) (delta + 1)
                else M.write arr i (w - 1) >> go lim (i + 1) (delta - 1)
  go 255 0 excess
  return arr


-- | Create a lookup table for the Poisson distribution. Note that
-- table construction may have significant cost. For λ < 100 it
-- takes as much time to build the table as to generate 1000-30000
-- variates.
--
-- Probabilities below @2^-33@ are truncated. Calls 'error' for a
-- negative λ.
tablePoisson :: Double -> CondensedTableU Int
tablePoisson lambda = tableFromProbabilities (probabilities lambda)
  where
    probabilities lam
      | lam < 0    = pkgError "tablePoisson" "negative lambda"
        -- For small λ, start at 0, where P(0) = exp(-λ) is still
        -- above the truncation threshold
      | lam < 22.8 = U.unfoldr stepUp (exp (-lam), 0)
        -- Otherwise start at the mode and walk in both directions.
        -- U.tail drops the duplicated mode from the downward walk.
      | otherwise  = U.unfoldr stepUp (pMax, nMax)
                  ++ U.tail (U.unfoldr stepDown (pMax, nMax))
      where
        -- Number with the highest probability and its probability
        --
        -- FIXME: this is not ideal precision-wise. Check if code
        --        from statistics gives better precision.
        nMax = floor lam :: Int
        pMax = exp $ fromIntegral nMax * log lam - lam - logFactorial nMax
        -- P(k+1) = P(k) * λ / (k+1)
        stepUp (p, k)
          | p < minP  = Nothing
          | otherwise = Just ((k, p), (p * lam / fromIntegral (k + 1), k + 1))
        -- P(k-1) = P(k) / λ * k
        stepDown (p, k)
          | p < minP  = Nothing
          | otherwise = Just ((k, p), (p / lam * fromIntegral k, k - 1))
    -- Minimal representable probability for condensed tables
    minP :: Double
    minP = 1.1641532182693481e-10 -- 2**(-33)

-- | Create a lookup table for the binomial distribution.
--
-- Probabilities are normally produced by the recurrence
-- @P(k+1) = P(k) * (n-k)/(k+1) * p/(1-p)@ seeded with @P(0) = (1-p)^n@.
-- When @(1-p)^n@ is not a normal 'Double' (e.g. @tableBinomial 400 0.9@,
-- where it underflows to 0), the recurrence would yield only zeros and
-- table construction would fail. In that case each probability is
-- computed independently in log space instead.
--
-- Calls 'error' if the number of tries is not positive or the
-- probability is outside @[0,1]@.
tableBinomial :: Int            -- ^ Number of tries
              -> Double         -- ^ Probability of success
              -> CondensedTableU Int
tableBinomial n p = tableFromProbabilities makeBinom
  where
    makeBinom
      | n <= 0         = pkgError "tableBinomial" "non-positive number of tries"
      | p == 0         = U.singleton (0,1)
      | p == 1         = U.singleton (n,1)
      | p > 0 && p < 1 =
          if p0 >= minNormal
            then U.unfoldrN (n + 1) unfolder (p0, 0)
            else U.generate (n + 1) logSpace
      | otherwise      = pkgError "tableBinomial" "probability is out of range"
      where
        -- Probability of zero successes: the seed of the recurrence
        p0 :: Double
        p0 = (1 - p)^n
        -- Smallest positive normal Double (2^-1022). Below it the seed is
        -- either zero or a denormal with too few significant bits.
        minNormal :: Double
        minNormal = 2.2250738585072014e-308
        h = p / (1 - p)
        unfolder (t,i) = Just ( (i,t)
                              , (t * (fromIntegral $ n + 1 - i1) * h / fromIntegral i1, i1) )
          where i1 = i + 1
        -- P(k) = exp (log C(n,k) + k log p + (n-k) log (1-p))
        logSpace k = ( k
                     , exp $ logFactorial n - logFactorial k - logFactorial (n - k)
                           + fromIntegral k * log p
                           + fromIntegral (n - k) * log (1 - p) )

-- Raise an 'error' whose message is prefixed with the fully qualified
-- name of the failing function, e.g.
-- @System.Random.MWC.CondensedTable.tablePoisson: negative lambda@.
pkgError :: String -> String -> a
pkgError func err = error message
  where
    -- Prelude's (++) is hidden in this module, so build the text with concat
    message = concat ["System.Random.MWC.CondensedTable.", func, ": ", err]

-- $references
--
-- * Wang, J.; Tsang, W. W.; G. Marsaglia (2004), Fast Generation of
--   Discrete Random Variables, /Journal of Statistical Software,
--   American Statistical Association/, vol. 11(i03).
--   <http://ideas.repec.org/a/jss/jstsof/11i03.html>