{-# LANGUAGE FlexibleContexts #-}
module System.Random.MWC.CondensedTable (
CondensedTable
, CondensedTableV
, CondensedTableU
, genFromTable
, tableFromProbabilities
, tableFromWeights
, tableFromIntWeights
, tablePoisson
, tableBinomial
) where
import Control.Arrow (second,(***))
import Control.Monad.Primitive (PrimMonad(..))
import Data.Word
import Data.Int
import Data.Bits
import qualified Data.Vector.Generic as G
import Data.Vector.Generic ((++))
import qualified Data.Vector.Generic.Mutable as M
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector as V
import Data.Vector.Generic (Vector)
import Numeric.SpecFunctions (logFactorial)
import Prelude hiding ((++))
import System.Random.MWC
-- | Condensed lookup table for sampling a discrete distribution with a
--   single 32-bit uniform index.
--
--   The table holds three (bound, sub-table) pairs followed by a final
--   sub-table. The bounds are cumulative, exclusive thresholds on the
--   index: an index below the first bound selects from the first
--   sub-table, where each element covers 2^24 consecutive indices; below
--   the second bound, the second sub-table (2^16 indices per element);
--   below the third bound, the third sub-table (2^8 per element); any
--   other index selects from the last sub-table one-to-one.
data CondensedTable v a =
  CondensedTable
    {-# UNPACK #-} !Word64 !(v a) -- bound and sub-table with 2^24-wide slots
    {-# UNPACK #-} !Word64 !(v a) -- bound and sub-table with 2^16-wide slots
    {-# UNPACK #-} !Word64 !(v a) -- bound and sub-table with 2^8-wide slots
    !(v a)                        -- sub-table indexed one-to-one
-- | Condensed table stored in unboxed vectors ("Data.Vector.Unboxed").
type CondensedTableU = CondensedTable U.Vector
-- | Condensed table stored in boxed vectors ("Data.Vector").
type CondensedTableV = CondensedTable V.Vector
-- | Sample one value from a condensed table.
--
--   Draws a single uniformly distributed 'Word32' from the generator and
--   uses it as the lookup index into the table.
genFromTable :: (PrimMonad m, Vector v a) =>
                CondensedTable v a -> Gen (PrimState m) -> m a
{-# INLINE genFromTable #-}
genFromTable table gen = uniform gen >>= return . lookupTable table . widen
  where
    -- Pin the drawn type to Word32 and widen it to the table's index type.
    widen :: Word32 -> Word64
    widen = fromIntegral
-- | Map an index to an outcome.
--
--   The index is compared against the three cumulative bounds in turn.
--   Inside the chosen range, the offset from the range start is shifted
--   right so that each element of that sub-table covers 2^24, 2^16, 2^8
--   or 1 consecutive indices respectively. No bounds checks are performed.
lookupTable :: Vector v a => CondensedTable v a -> Word64 -> a
{-# INLINE lookupTable #-}
lookupTable (CondensedTable na aa nb bb nc cc dd) i =
  if      i < na then pick aa 0  24
  else if i < nb then pick bb na 16
  else if i < nc then pick cc nb 8
  else                pick dd nc 0
  where
    -- Element of @vec@ covering index @i@, for a range starting at @base@
    -- in which every element spans 2^sh indices.
    pick vec base sh = G.unsafeIndex vec (fromIntegral ((i - base) `shiftR` sh))
-- | Build a condensed table from outcome/probability pairs.
--
--   Pairs with a non-positive probability are dropped. Each remaining
--   probability is scaled by 2^32 and rounded to a 'Word32' weight,
--   saturating at 'maxBound'. The weights are then handed to
--   'tableFromIntWeights'. Errors if no outcome has positive probability.
tableFromProbabilities
  :: (Vector v (a,Word32), Vector v (a,Double), Vector v a, Vector v Word32)
  => v (a, Double) -> CondensedTable v a
{-# INLINE tableFromProbabilities #-}
tableFromProbabilities v =
  if G.null positive
    then pkgError "tableFromProbabilities" "empty vector of outcomes"
    else tableFromIntWeights (G.map (second quantize) positive)
  where
    positive = G.filter (\(_, p) -> p > 0) v
    -- 2^32: the size of the index space used by 'genFromTable'.
    twoTo32 = 4.294967296e9 :: Double
    quantize :: Double -> Word32
    quantize p
      | scaled > twoTo32 - 1 = maxBound
      | otherwise            = round scaled
      where
        scaled = p * twoTo32
-- | Build a condensed table from outcome/weight pairs.
--
--   Pairs with a non-positive weight are discarded. The rest are divided
--   by their total so they sum to one, and then passed to
--   'tableFromProbabilities'. Errors if no weight is positive.
tableFromWeights
  :: (Vector v (a,Word32), Vector v (a,Double), Vector v a, Vector v Word32)
  => v (a, Double) -> CondensedTable v a
{-# INLINE tableFromWeights #-}
tableFromWeights v
  | G.null positive = pkgError "tableFromWeights" "no positive weights"
  | otherwise       = tableFromProbabilities $ G.map (second (/ total)) positive
  where
    positive = G.filter ((> 0) . snd) v
    total    = G.foldl' (\acc (_, w) -> acc + w) 0 positive
-- | Build a condensed table from outcome/integer-weight pairs.
--
--   Pairs with zero weight are dropped. With two or more outcomes left,
--   the weights are adjusted by 'correctWeights' so that they sum to
--   exactly 2^32; they are expected to be close to that already. The
--   weights are then split into base-256 digits, and the table for digit
--   @d@ repeats each outcome as many times as its @d@-th digit.
--   Errors if no weight is non-zero.
tableFromIntWeights :: (Vector v (a,Word32), Vector v a, Vector v Word32)
                    => v (a, Word32)
                    -> CondensedTable v a
{-# INLINE tableFromIntWeights #-}
tableFromIntWeights v
  | n == 0    = pkgError "tableFromIntWeights" "empty table"
    -- A single outcome is special-cased (it would confuse
    -- 'correctWeights') and takes the whole index space in the first
    -- sub-table: 256 slots of 2^24 indices each. The bound must be 2^32,
    -- not 2^32 - 1: lookup tests @i < bound@, so with 2^32 - 1 the
    -- largest index would fall through to the empty last sub-table and
    -- be read out of bounds.
  | n == 1    = let m = 2^(32::Int)
                in CondensedTable
                     m (G.replicate 256 $ fst $ G.head tbl)
                     m G.empty
                     m G.empty
                       G.empty
  | otherwise = CondensedTable
                  na aa
                  nb bb
                  nc cc
                  dd
  where
    tbl   = G.filter ((/= 0) . snd) v
    n     = G.length tbl
    table = uncurry G.zip $ second correctWeights $ G.unzip tbl
    -- Sub-table for base-256 digit d (0 = most significant).
    mkTable d =
      G.concatMap (\(x,w) -> G.replicate (fromIntegral $ digit d w) x) table
    len = fromIntegral . G.length
    aa  = mkTable 0
    bb  = mkTable 1
    cc  = mkTable 2
    dd  = mkTable 3
    -- Cumulative exclusive bounds; an element of sub-table k spans
    -- 2^(24 - 8k) indices.
    na  = len aa `shiftL` 24
    nb  = na + (len bb `shiftL` 16)
    nc  = nb + (len cc `shiftL` 8)
-- | Extract base-256 digit @d@ of a 32-bit word, where digit 0 is the
--   most significant byte and digit 3 the least significant. Any other
--   digit index is an internal error.
digit :: Int -> Word32 -> Word32
digit d x
  | d >= 0 && d <= 3 = (x `shiftR` (24 - 8 * d)) .&. 0xff
  | otherwise        = pkgError "digit" "the impossible happened!?"
{-# INLINE digit #-}
-- | Nudge integer weights so that they sum to exactly 2^32.
--
--   The excess (sum minus 2^32, computed in 'Int64') is removed one unit
--   at a time. The array is walked round-robin; each eligible weight is
--   decremented when the sum is too large, or incremented when it is too
--   small. The first pass only touches weights of at least 255; every
--   later pass touches any weight of at least 1. Intended for weights
--   that are already close to 2^32 in total.
correctWeights :: G.Vector v Word32 => v Word32 -> v Word32
{-# INLINE correctWeights #-}
correctWeights v = G.create $ do
  marr <- G.thaw v
  let total = G.foldl' (\acc w -> acc + fromIntegral w) 0 v :: Int64
      size  = G.length v
      go thr idx excess
        | excess == 0 = return ()
        | idx >= size = go 1 0 excess
        | otherwise   = do
            w <- M.read marr idx
            let next = idx + 1
            if w < thr
              then go thr next excess
              else if excess < 0
                then M.write marr idx (w + 1) >> go thr next (excess + 1)
                else M.write marr idx (w - 1) >> go thr next (excess - 1)
  go 255 0 (total - 2 ^ (32 :: Int))
  return marr
-- | Create a lookup table for the Poisson distribution with mean @lam@.
--
--   For @lam < 22.8@, probabilities are generated upward from 0 using
--   the recurrence P(i+1) = P(i) * lam / (i+1). For larger @lam@, they
--   are generated outward in both directions from @floor lam@, whose
--   probability is computed in log space. Outcomes with probability
--   below 2^-33 are omitted.
--
--   Throws if @lam@ is negative, NaN or infinite.
tablePoisson :: Double -> CondensedTableU Int
tablePoisson = tableFromProbabilities . make
  where
    make lam
      | lam < 0    = pkgError "tablePoisson" "negative lambda"
        -- NaN and +Infinity slip past the check above. 'floor' of them
        -- to Int is meaningless, so reject them explicitly.
      | isNaN lam || isInfinite lam
                   = pkgError "tablePoisson" "lambda must be finite"
      | lam < 22.8 = U.unfoldr unfoldForward (exp (-lam), 0)
      | otherwise  = U.unfoldr unfoldForward (pMax, nMax)
                  ++ U.drop 1 (U.unfoldr unfoldBackward (pMax, nMax))
      where
        -- Most probable outcome and its probability.
        nMax = floor lam :: Int
        pMax = exp $ fromIntegral nMax * log lam - lam - logFactorial nMax
        -- Walk upward: P(i+1) = P(i) * lam / (i+1).
        unfoldForward (p,i)
          | p < minP  = Nothing
          | otherwise = Just ( (i,p)
                             , (p * lam / fromIntegral (i+1), i+1)
                             )
        -- Walk downward: P(i-1) = P(i) * i / lam. This reaches 0 at i = 0.
        unfoldBackward (p,i)
          | p < minP  = Nothing
          | otherwise = Just ( (i,p)
                             , (p / lam * fromIntegral i, i-1)
                             )
    -- 2^-33: smaller probabilities round to a zero weight after scaling
    -- by 2^32, so they never appear in the table.
    minP = 1.1641532182693481e-10
-- | Create a lookup table for the binomial distribution with @n@ tries
--   and success probability @p@.
--
--   Probabilities are generated outward in both directions from the mode
--   @floor ((n+1) p)@, whose probability is computed in log space via
--   'logFactorial'. Starting from P(0) = (1-p)^n instead underflows to 0
--   for moderately large @n@ (for example @n = 1100, p = 0.5@), which
--   makes every probability 0 and table construction fail. Outcomes with
--   probability below 2^-33 are omitted; they would be quantized to a
--   zero weight anyway.
--
--   Throws if @n <= 0@ or if @p@ lies outside [0, 1].
tableBinomial :: Int    -- ^ Number of tries
              -> Double -- ^ Probability of success
              -> CondensedTableU Int
tableBinomial n p = tableFromProbabilities makeBinom
  where
    makeBinom
      | n <= 0         = pkgError "tableBinomial" "non-positive number of tries"
      | p == 0         = U.singleton (0,1)
      | p == 1         = U.singleton (n,1)
      | p > 0 && p < 1 = U.unfoldr stepUp (pMode, iMode)
                      ++ U.drop 1 (U.unfoldr stepDown (pMode, iMode))
      | otherwise      = pkgError "tableBinomial" "probability is out of range"
    -- P(i+1) / P(i) = (n - i) / (i + 1) * h
    h = p / (1 - p)
    -- floor((n+1)p) is a mode. The conversion to Double happens before
    -- adding 1 to avoid Int overflow; the result is clamped against
    -- rounding when p is very close to 1.
    iMode = min n (floor ((fromIntegral n + 1) * p)) :: Int
    pMode = exp $ logFactorial n - logFactorial iMode - logFactorial (n - iMode)
                + fromIntegral iMode       * log p
                + fromIntegral (n - iMode) * log (1 - p)
    -- The distribution is unimodal, so once a walk drops below minP it
    -- stays below. The index guards keep the walks finite regardless.
    stepUp :: (Double, Int) -> Maybe ((Int, Double), (Double, Int))
    stepUp (t, i)
      | i > n || t < minP = Nothing
      | otherwise = Just ( (i, t)
                         , (t * fromIntegral (n - i) * h / fromIntegral (i + 1), i + 1)
                         )
    stepDown :: (Double, Int) -> Maybe ((Int, Double), (Double, Int))
    stepDown (t, i)
      | i < 0 || t < minP = Nothing
      | otherwise = Just ( (i, t)
                         , (t * fromIntegral i / (fromIntegral (n - i + 1) * h), i - 1)
                         )
    -- 2^-33: smaller probabilities round to a zero weight after scaling
    -- by 2^32.
    minP = 1.1641532182693481e-10 :: Double
-- | Abort with an error whose message is prefixed by this module's name
--   and the name of the function that failed.
pkgError :: String -> String -> a
pkgError func err =
  error $ showString "System.Random.MWC.CondensedTable."
        . showString func
        . showString ": "
        $ err