module Data.Array.Accelerate.TypeLits.System.Random.MWC
( rndMatrixWith
, rndVectorWith
, module Distributions
) where
import Data.Array.Accelerate.TypeLits.Internal
import GHC.TypeLits
import Data.Proxy
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate (Elt, Z(..), (:.)(..))
import Data.Array.Accelerate.System.Random.MWC
import System.Random.MWC.Distributions as Distributions
-- | Build a random 'AccMatrix' whose dimensions come from the type-level
-- naturals @m@ and @n@ (array shape @Z :. m :. n@).
--
-- Every entry is produced by running the supplied sampler, for example a
-- partially applied function from the re-exported
-- "System.Random.MWC.Distributions" module. The element's index is ignored.
-- The host array built by 'randomArray' is then embedded into an Accelerate
-- computation with 'A.use'.
rndMatrixWith :: forall m n e. (KnownNat m, KnownNat n, Elt e)
              => (GenIO -> IO e)      -- ^ sampler run for each element
              -> IO (AccMatrix m n e)
rndMatrixWith draw = AccMatrix . A.use <$> randomArray (\_ -> draw) dims
  where
    -- Reify the type-level sizes as value-level extents.
    rows = fromInteger (natVal (Proxy :: Proxy m))
    cols = fromInteger (natVal (Proxy :: Proxy n))
    dims = Z :. rows :. cols
-- | Build a random 'AccVector' whose length comes from the type-level
-- natural @n@ (array shape @Z :. n@).
--
-- Every entry is produced by running the supplied sampler. The element's
-- index is ignored. The host array built by 'randomArray' is then embedded
-- into an Accelerate computation with 'A.use'.
rndVectorWith :: forall n e. (KnownNat n, Elt e)
              => (GenIO -> IO e)      -- ^ sampler run for each element
              -> IO (AccVector n e)
rndVectorWith draw = AccVector . A.use <$> randomArray (\_ -> draw) extent
  where
    -- Reify the type-level length as a value-level extent.
    len    = fromInteger (natVal (Proxy :: Proxy n))
    extent = Z :. len