{-# LANGUAGE CPP #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
#if __GLASGOW_HASKELL__ < 710
{-# LANGUAGE OverlappingInstances #-}
#endif
module Crypto.RNG (
module Crypto.RNG.Class
, CryptoRNGState
, newCryptoRNGState
, unsafeCryptoRNGState
, randomBytesIO
, randomR
, Random(..)
, boundedIntegralRandom
, CryptoRNGT
, mapCryptoRNGT
, runCryptoRNGT
, withCryptoRNGState
) where
import Control.Applicative
import Control.Concurrent
import Control.Monad.Base
import Control.Monad.Catch
import Control.Monad.Cont
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.Trans.Control
import Crypto.Random
import Crypto.Random.DRBG
import Data.Bits
import Data.ByteString (ByteString, unpack)
import Data.Int
import Data.List
import Data.Word
import Crypto.RNG.Class
-- | Shared state of the cryptographically secure random number generator:
-- an auto-reseeding 'HashDRBG' generator held in an 'MVar', so that it can
-- be replaced in place each time bytes are drawn from it.
--
-- The constructor is not exported; values are obtained through
-- 'newCryptoRNGState' or 'unsafeCryptoRNGState'.
newtype CryptoRNGState = CryptoRNGState (MVar (GenAutoReseed HashDRBG HashDRBG))
-- | Create a fresh generator state, seeding the generator with crypto-api's
-- 'newGenIO' and storing it in a new 'MVar'.
newCryptoRNGState :: MonadIO m => m CryptoRNGState
newCryptoRNGState = liftIO $ do
  gen <- newGenIO
  CryptoRNGState <$> newMVar gen
-- | Create a generator state from an explicitly supplied seed.
--
-- The generator is derived entirely from @seed@, so its output is only as
-- unpredictable as the seed itself (presumably the reason for the "unsafe"
-- name; it looks intended for reproducible runs such as tests).
--
-- If 'newGen' rejects the seed, the action fails in 'IO' with the 'show'n
-- generator error.
unsafeCryptoRNGState :: MonadIO m => ByteString -> m CryptoRNGState
unsafeCryptoRNGState seed = liftIO $
  case newGen seed of
    Left err  -> fail (show err)
    Right gen -> CryptoRNGState <$> newMVar gen
-- | Draw @n@ random bytes from the shared generator state.
--
-- The generator is taken out of its 'MVar' with 'modifyMVar' and the
-- advanced generator is put back. Callers therefore take turns, and if
-- generation fails, 'modifyMVar' restores the previous generator before the
-- exception propagates.
--
-- A generator error is reported with 'fail' in 'IO' (an 'IOError'), and the
-- message includes the 'show'n error so the cause is not lost.
randomBytesIO :: ByteLength     -- ^ number of bytes to generate
              -> CryptoRNGState -- ^ shared generator state
              -> IO ByteString
randomBytesIO n (CryptoRNGState gv) =
  modifyMVar gv $ \g ->
    case genBytes n g of
      -- 'fail' must run in IO here. Passing it directly as the first
      -- argument of 'either' would instead select the function monad's
      -- 'fail', which does not exist on current base.
      Left err       -> fail ("Crypto.GlobalRandom.genBytes: " ++ show err)
      Right (bs, g') -> return (g', bs)
-- | Generate a random value in the closed interval @[lo, hi]@. As with
-- @System.Random.randomR@, the bounds are swapped when @lo > hi@.
--
-- The function draws the smallest number of random bytes whose value space
-- covers the interval and reads them as a big-endian unsigned integer.
-- Draws that fall into the incomplete final "bucket" of that space are
-- rejected and redrawn, so reducing modulo the interval size adds no bias.
-- The result is then uniform, assuming 'randomBytes' yields uniform bytes.
--
-- When the interval size is a power of two, no draw is ever rejected. A
-- single-value interval consumes no randomness at all.
randomR :: (CryptoRNG m, Integral a) => (a, a) -> m a
randomR (minb', maxb')
  | range == 1 = return minb'
  | otherwise  = do
      offset <- draw
      return (fromInteger (lo + offset))
  where
    lo, hi, range :: Integer
    lo    = min (toInteger minb') (toInteger maxb')
    hi    = max (toInteger minb') (toInteger maxb')
    range = hi - lo + 1

    -- Smallest byte count whose value space (256 ^ byteLen) covers 'range'.
    -- This is computed with exact Integer arithmetic; 'logBase' on 'Double'
    -- loses precision for large ranges and can under-count the bytes needed.
    byteLen :: ByteLength
    space   :: Integer
    (byteLen, space) = grow 0 1
      where
        grow bytes cap
          | cap >= range = (bytes, cap)
          | otherwise    = grow (bytes + 1) (cap * 256)

    -- Largest multiple of 'range' not exceeding 'space'. Accepting draws at
    -- or above it would make small offsets more likely than large ones.
    limit :: Integer
    limit = space - space `mod` range

    -- Read bytes as a big-endian unsigned integer (0 for an empty string).
    bytesToInteger :: ByteString -> Integer
    bytesToInteger = foldl' (\acc w -> acc `shiftL` 8 .|. toInteger w) 0 . unpack

    -- Offset in [0, range), using rejection sampling to avoid modulo bias.
    draw = do
      bs <- randomBytes byteLen
      let n = bytesToInteger bs
      if n < limit
        then return (n `mod` range)
        else draw
-- | Draw a random value over the whole range of a 'Bounded' integral type,
-- i.e. over @['minBound', 'maxBound']@, by delegating to 'randomR'.
boundedIntegralRandom :: forall m a. (CryptoRNG m, Integral a, Bounded a) => m a
boundedIntegralRandom = randomR bounds
  where
    bounds :: (a, a)
    bounds = (minBound, maxBound)
-- | Types whose values can be drawn at random inside any 'CryptoRNG' monad.
class Random a where
  -- | Draw a random value of type @a@.
  random :: forall m. CryptoRNG m => m a
-- Instances for the fixed-width and machine-sized integral types from
-- "Data.Int" and "Data.Word". Each draws over the type's full range via
-- 'boundedIntegralRandom'. 'Int8' is included so that the signed types
-- mirror the unsigned ones ('Word8' has an instance).
instance Random Int8 where
  random = boundedIntegralRandom

instance Random Int16 where
  random = boundedIntegralRandom

instance Random Int32 where
  random = boundedIntegralRandom

instance Random Int64 where
  random = boundedIntegralRandom

instance Random Int where
  random = boundedIntegralRandom

instance Random Word8 where
  random = boundedIntegralRandom

instance Random Word16 where
  random = boundedIntegralRandom

instance Random Word32 where
  random = boundedIntegralRandom

instance Random Word64 where
  random = boundedIntegralRandom

instance Random Word where
  random = boundedIntegralRandom
-- | The representation underlying 'CryptoRNGT': a reader of the shared
-- generator state. It is deliberately eta-reduced (no @m@ parameter) so that
-- it can appear unsaturated, e.g. as the argument of 'StT'.
type InnerCryptoRNGT = ReaderT CryptoRNGState
-- | Monad transformer that carries a 'CryptoRNGState' through a computation.
--
-- The classes listed below are derived through the underlying
-- 'InnerCryptoRNGT' (a 'ReaderT' over the generator state).
newtype CryptoRNGT m a = CryptoRNGT { unCryptoRNGT :: InnerCryptoRNGT m a }
  deriving
    ( Functor
    , Applicative
    , Alternative
    , Monad
    , MonadPlus
    , MonadIO
    , MonadTrans
    , MonadBase b
    , MonadThrow
    , MonadCatch
    , MonadMask
    , MonadError e
    )
-- | Transform the computation wrapped by a 'CryptoRNGT'. The transformed
-- computation receives the same generator state as the original.
mapCryptoRNGT :: (m a -> n b) -> CryptoRNGT m a -> CryptoRNGT n b
mapCryptoRNGT f (CryptoRNGT inner) =
  CryptoRNGT . ReaderT $ \st -> f (runReaderT inner st)
-- | Run a 'CryptoRNGT' computation against the given generator state.
runCryptoRNGT :: CryptoRNGState -> CryptoRNGT m a -> m a
runCryptoRNGT st = flip runReaderT st . unCryptoRNGT
-- | Build a 'CryptoRNGT' computation from a function of the generator state,
-- so that @runCryptoRNGT s (withCryptoRNGState f)@ is @f s@.
withCryptoRNGState :: (CryptoRNGState -> m a) -> CryptoRNGT m a
withCryptoRNGState f = CryptoRNGT (ReaderT f)
-- | 'MonadTransControl' for 'CryptoRNGT', obtained from the underlying
-- 'InnerCryptoRNGT' through monad-control's default helpers by wrapping with
-- 'CryptoRNGT' and unwrapping with 'unCryptoRNGT'.
instance MonadTransControl CryptoRNGT where
  -- The state captured when running one layer is that of the underlying reader.
  type StT CryptoRNGT a = StT InnerCryptoRNGT a
  liftWith = defaultLiftWith CryptoRNGT unCryptoRNGT
  restoreT = defaultRestoreT CryptoRNGT
  {-# INLINE liftWith #-}
  {-# INLINE restoreT #-}
-- | 'MonadBaseControl' through 'CryptoRNGT', built with monad-control's
-- default helpers from the inner monad's instance. This lets base-monad
-- control operations be used inside a 'CryptoRNGT' stack.
instance MonadBaseControl b m => MonadBaseControl b (CryptoRNGT m) where
  -- Captured state composes this layer's 'StT' with the inner monad's 'StM'.
  type StM (CryptoRNGT m) a = ComposeSt CryptoRNGT m a
  liftBaseWith = defaultLiftBaseWith
  restoreM = defaultRestoreM
  {-# INLINE liftBaseWith #-}
  {-# INLINE restoreM #-}
-- | A 'CryptoRNGT' over any 'MonadIO' base produces random bytes by handing
-- the generator state it carries to 'randomBytesIO'. The instance is marked
-- overlappable, so more specific instances take precedence over it.
instance {-# OVERLAPPABLE #-} MonadIO m => CryptoRNG (CryptoRNGT m) where
  randomBytes n = CryptoRNGT . ReaderT $ \st -> liftIO (randomBytesIO n st)