module OpenSSL.BN
(
BigNum
, BIGNUM
, allocaBN
, withBN
, newBN
, wrapBN
, unwrapBN
, peekBN
, integerToBN
, bnToInteger
, integerToMPI
, mpiToInteger
, modexp
, randIntegerUptoNMinusOneSuchThat
, prandIntegerUptoNMinusOneSuchThat
, randIntegerZeroToNMinusOne
, prandIntegerZeroToNMinusOne
, randIntegerOneToNMinusOne
, prandIntegerOneToNMinusOne
)
where
import Control.Exception hiding (try)
import qualified Data.ByteString as BS
import Foreign.Marshal
import Foreign.Ptr
import Foreign.Storable
import OpenSSL.Utils
import System.IO.Unsafe
import Foreign.C.Types
import GHC.Base
import GHC.Integer.GMP.Internals
-- | A handle to an OpenSSL @BIGNUM@, i.e. a thin wrapper around the raw C
-- pointer. The wrapper itself attaches no finalizer; whoever allocated the
-- pointer is responsible for freeing it.
newtype BigNum = BigNum (Ptr BIGNUM)
-- | Uninhabited tag type standing for the C @BIGNUM@ struct; it is only ever
-- used behind a 'Ptr'.
data BIGNUM
-- | @BN_new@: allocate a new BIGNUM on the OpenSSL side.
-- NOTE(review): a NULL result (allocation failure) is not checked by callers
-- in this module.
foreign import ccall unsafe "BN_new"
  _new :: IO (Ptr BIGNUM)
-- | @BN_free@: release a BIGNUM. Also used here to release the structs
-- built by hand in 'integerToBN'.
foreign import ccall unsafe "BN_free"
  _free :: Ptr BIGNUM -> IO ()
-- | Run an action with a BIGNUM freshly allocated by @BN_new@.
--
-- The BIGNUM is handed to the action as a 'BigNum' and is released with
-- @BN_free@ once the action finishes, whether it returns or throws.
allocaBN :: (BigNum -> IO a) -> IO a
allocaBN action = bracket _new _free $ \ptr -> action (wrapBN ptr)
-- | Extract the raw @BIGNUM@ pointer held by a 'BigNum'.
unwrapBN :: BigNum -> Ptr BIGNUM
unwrapBN bn = case bn of
  BigNum ptr -> ptr
-- | Wrap a raw @BIGNUM@ pointer as a 'BigNum'. No ownership is transferred
-- and no finalizer is attached.
wrapBN :: Ptr BIGNUM -> BigNum
wrapBN ptr = BigNum ptr
-- | C @memcpy@ (destination first) copying from C memory into a Haskell
-- byte array.
-- NOTE(review): this writes into an /immutable/ 'ByteArray#'. That is only
-- sound when the array is freshly allocated and not yet shared, as in
-- 'bnToInteger'. Passing an unlifted array is acceptable here because the
-- call is @unsafe@, so no GC can move it mid-call.
foreign import ccall unsafe "memcpy"
  _copy_in :: ByteArray# -> Ptr () -> CSize -> IO (Ptr ())
-- | C @memcpy@ (destination first) copying from a Haskell byte array into
-- C memory.
foreign import ccall unsafe "memcpy"
  _copy_out :: Ptr () -> ByteArray# -> CSize -> IO (Ptr ())
-- | Boxed wrapper so an unlifted immutable 'ByteArray#' can be returned
-- from an 'IO' action.
data ByteArray = BA !ByteArray#
-- | Boxed wrapper so an unlifted 'MutableByteArray#' can be returned from
-- an 'IO' action.
data MBA = MBA !(MutableByteArray# RealWorld)
-- | Allocate a new mutable byte array of the given size in bytes, returned
-- boxed in an 'MBA'. The contents are not initialised.
newByteArray :: Int# -> IO MBA
newByteArray size = IO $ \state0 ->
  case newByteArray# size state0 of
    (# state1, marr #) -> (# state1, MBA marr #)
-- | Freeze a mutable byte array in place (no copy) and box the result.
-- The mutable array must not be written through afterwards.
freezeByteArray :: MutableByteArray# RealWorld -> IO ByteArray
freezeByteArray marr = IO $ \state0 ->
  case unsafeFreezeByteArray# marr state0 of
    (# state1, frozen #) -> (# state1, BA frozen #)
-- | Convert an OpenSSL 'BigNum' into a Haskell 'Integer'.
--
-- The struct is read at fixed byte offsets generated by hsc2hs:
--
-- * offset 0: pointer to the limb array
-- * offset 8: number of limbs in use
-- * offset 16: sign flag (non-zero means negative)
--
-- (presumably OpenSSL's @d@, @top@ and @neg@ fields). Limbs are 8 bytes.
bnToInteger :: BigNum -> IO Integer
bnToInteger bn = do
  nlimbs <- peekByteOff bnPtr 8 :: IO CInt
  case nlimbs of
    0 -> return 0
    1 -> do
      -- A limb is an unsigned machine word. Reading it as a signed Int would
      -- turn every limb >= 2^63 into a negative number, and would also break
      -- the S#/Jp#/Jn# representation invariant. Read it as Word and let
      -- toInteger/negate choose the correct Integer representation.
      limbs <- peekByteOff bnPtr 0 :: IO (Ptr Word)
      limb <- peek limbs
      negative <- peekByteOff bnPtr 16 :: IO CInt
      let magnitude = toInteger limb
      return $ if negative == 0 then magnitude else negate magnitude
    _ -> do
      let !(I# nlimbsi) = fromIntegral nlimbs
          !(I# limbsize) = 8
      -- The array is allocated, frozen, and then filled by memcpy. This is
      -- acceptable only because nothing has observed the frozen array yet.
      (MBA arr) <- newByteArray (nlimbsi *# limbsize)
      (BA ba) <- freezeByteArray arr
      limbs <- peekByteOff bnPtr 0
      _ <- _copy_in ba limbs $ fromIntegral $ nlimbs * 8
      negative <- peekByteOff bnPtr 16 :: IO CInt
      -- Two or more limbs means the magnitude is >= 2^64, so the big
      -- constructors are always the right representation here.
      if negative == 0
        then return $ Jp# (byteArrayToBigNat# ba nlimbsi)
        else return $ Jn# (byteArrayToBigNat# ba nlimbsi)
  where
    bnPtr = unwrapBN bn
-- | Build an OpenSSL 'BigNum' from an 'Integer'.
--
-- The 24-byte struct is allocated with 'mallocBytes' and its fields are
-- written directly at fixed offsets:
--
-- * 0: limb pointer
-- * 8: used-limb count
-- * 12: allocated-limb count
-- * 16: sign flag
-- * 20: flags word
--
-- The flags word is set to 1. This is presumably OpenSSL's
-- @BN_FLG_MALLOCED@, so that @BN_free@ also releases the struct — TODO
-- confirm against the targeted OpenSSL headers. Limb buffers are allocated
-- with 'malloc'/'mallocBytes'.
integerToBN :: Integer -> IO BigNum
integerToBN v =
  case v of
    S# 0# -> newWithLimbs nullPtr 0 0
    S# i -> do
      limbs <- malloc :: IO (Ptr CULong)
      -- abs minBound wraps to minBound, whose CULong image is 2^63: still
      -- the correct magnitude.
      poke limbs $ fromIntegral $ abs $ I# i
      newWithLimbs limbs 1 (if I# i < 0 then 1 else 0)
    Jp# bn -> convert 0 bn
    Jn# bn -> convert 1 bn
  where
    -- Allocate a BIGNUM struct around an existing limb buffer. The used and
    -- allocated limb counts are both set to nlimbs.
    newWithLimbs :: Ptr a -> CInt -> CInt -> IO BigNum
    newWithLimbs limbs nlimbs negValue = do
      bnptr <- mallocBytes 24
      pokeByteOff bnptr 0 limbs
      pokeByteOff bnptr 20 (1 :: CInt)
      pokeByteOff bnptr 8 nlimbs
      pokeByteOff bnptr 12 nlimbs
      pokeByteOff bnptr 16 negValue
      return (wrapBN bnptr)

    -- Copy the 8-byte limbs of a GMP big natural into a fresh C buffer and
    -- wrap them with the given sign flag.
    convert :: CInt -> BigNat -> IO BigNum
    convert negValue bn@(BN# bytearray) = do
      let nlimbs = I# (sizeofBigNat# bn)
      limbs <- mallocBytes (8 * nlimbs)
      _ <- _copy_out limbs bytearray (fromIntegral $ 8 * nlimbs)
      newWithLimbs limbs (fromIntegral nlimbs) negValue
-- | Convert an 'Integer' to a temporary 'BigNum' and run an action on it.
--
-- The 'BigNum' is released with @BN_free@ afterwards, including when the
-- action throws.
withBN :: Integer -> (BigNum -> IO a) -> IO a
withBN dec m = bracket acquire release m
  where
    acquire = integerToBN dec
    release bn = _free (unwrapBN bn)
-- | @BN_bn2mpi@: serialise a BIGNUM into MPI format. With a NULL output
-- buffer it only returns the number of bytes required ('bnToMPI' relies on
-- this to size its buffer).
foreign import ccall unsafe "BN_bn2mpi"
  _bn2mpi :: Ptr BIGNUM -> Ptr CChar -> IO CInt
-- | @BN_mpi2bn@: parse an MPI-format buffer of the given length. With a NULL
-- target it allocates a new BIGNUM. Per the OpenSSL documentation it returns
-- NULL on error.
foreign import ccall unsafe "BN_mpi2bn"
  _mpi2bn :: Ptr CChar -> CInt -> Ptr BIGNUM -> IO (Ptr BIGNUM)
-- | Read a 'BigNum' as an 'Integer'. Alias for 'bnToInteger'.
peekBN :: BigNum -> IO Integer
peekBN bn = bnToInteger bn
-- | Allocate a 'BigNum' holding the given 'Integer'. Alias for
-- 'integerToBN'. The result is not freed automatically; release it with
-- @BN_free@ or prefer 'withBN'.
newBN :: Integer -> IO BigNum
newBN n = integerToBN n
-- | Serialise a 'BigNum' into OpenSSL's MPI format.
--
-- @BN_bn2mpi@ is called twice:
--
-- 1. with a NULL buffer, to learn the encoded size;
-- 2. with a temporary buffer of that size.
--
-- The buffer contents are then copied into a 'BS.ByteString'.
bnToMPI :: BigNum -> IO BS.ByteString
bnToMPI bn = do
  let ptr = unwrapBN bn
  size <- _bn2mpi ptr nullPtr
  let len = fromIntegral size
  allocaBytes len $ \buf -> do
    _ <- _bn2mpi ptr buf
    BS.packCStringLen (buf, len)
-- | Parse an MPI-format 'BS.ByteString' into a newly allocated 'BigNum'.
--
-- @BN_mpi2bn@ returns NULL on malformed input. 'failIfNull' raises that as
-- an error here, instead of wrapping a NULL pointer that 'bnToInteger'
-- would later dereference. The caller owns the result and must release it
-- with @BN_free@.
mpiToBN :: BS.ByteString -> IO BigNum
mpiToBN mpi =
  BS.useAsCStringLen mpi $ \(ptr, len) ->
    wrapBN <$> (_mpi2bn ptr (fromIntegral len) nullPtr >>= failIfNull)
-- | Serialise an 'Integer' into OpenSSL's MPI format. A temporary
-- 'BigNum' is used and released afterwards.
integerToMPI :: Integer -> IO BS.ByteString
integerToMPI v = withBN v bnToMPI
-- | Parse an MPI-format 'BS.ByteString' into an 'Integer'.
--
-- The intermediate 'BigNum' is released with @BN_free@ via 'bracket'. It is
-- therefore not leaked if the conversion to 'Integer' throws.
mpiToInteger :: BS.ByteString -> IO Integer
mpiToInteger mpi = bracket (mpiToBN mpi) (_free . unwrapBN) bnToInteger
-- | @BN_mod_exp(r, a, p, m, ctx)@: compute @a^p mod m@ and store it in @r@.
--
-- The C function returns an @int@ status (1 on success, 0 on error). It is
-- therefore imported as 'CInt' rather than as a pointer.
foreign import ccall unsafe "BN_mod_exp"
  _mod_exp :: Ptr BIGNUM -> Ptr BIGNUM -> Ptr BIGNUM -> Ptr BIGNUM -> BNCtx -> IO CInt
-- | Pointer to an OpenSSL @BN_CTX@, which holds temporary BIGNUMs for
-- library routines such as @BN_mod_exp@.
type BNCtx = Ptr BNCTX
-- | Uninhabited tag type for the C @BN_CTX@ struct.
data BNCTX
-- | @BN_CTX_new@: allocate a new context.
-- NOTE(review): a NULL result (allocation failure) is not checked by
-- 'withBNCtx'.
foreign import ccall unsafe "BN_CTX_new"
  _BN_ctx_new :: IO BNCtx
-- | @BN_CTX_free@: release a context obtained from '_BN_ctx_new'.
foreign import ccall unsafe "BN_CTX_free"
  _BN_ctx_free :: BNCtx -> IO ()
-- | Run an action with a fresh 'BNCtx'. The context is freed afterwards,
-- even if the action throws.
withBNCtx :: (BNCtx -> IO a) -> IO a
withBNCtx action = bracket _BN_ctx_new _BN_ctx_free $ \ctx -> action ctx
-- | Modular exponentiation @a^p mod m@ computed by OpenSSL's @BN_mod_exp@.
--
-- 'unsafePerformIO' is used because the result depends only on the
-- arguments. Every OpenSSL resource (operands, context and result) is scoped
-- with a bracket-based helper and freed before returning. The result BIGNUM
-- used to be created with 'newBN' and never freed, which leaked memory on
-- every call.
--
-- NOTE(review): the status returned by @BN_mod_exp@ is ignored. On failure
-- (e.g. @m == 0@) the returned value is whatever the result BIGNUM holds.
modexp :: Integer -> Integer -> Integer -> Integer
modexp a p m = unsafePerformIO $
  withBN a $ \bnA ->
  withBN p $ \bnP ->
  withBN m $ \bnM ->
  withBNCtx $ \ctx ->
  withBN 0 $ \r -> do
    _ <- _mod_exp (unwrapBN r) (unwrapBN bnA) (unwrapBN bnP) (unwrapBN bnM) ctx
    bnToInteger r
-- | @BN_rand_range(rnd, range)@: write a random number @0 <= rnd < range@
-- into @rnd@. Callers in this module treat any result other than 1 as
-- failure.
foreign import ccall unsafe "BN_rand_range"
  _BN_rand_range :: Ptr BIGNUM -> Ptr BIGNUM -> IO CInt
-- | @BN_pseudo_rand_range@: the pseudo-random counterpart of
-- '_BN_rand_range', with the same argument order and status convention.
foreign import ccall unsafe "BN_pseudo_rand_range"
  _BN_pseudo_rand_range :: Ptr BIGNUM -> Ptr BIGNUM -> IO CInt
-- | Draw random integers in @[0, range)@ with @BN_rand_range@ until one
-- satisfies the predicate, then return it.
--
-- A failure status from @BN_rand_range@ is raised via 'failIf_'. The loop
-- never terminates if no value in the range satisfies the predicate.
randIntegerUptoNMinusOneSuchThat :: (Integer -> Bool)
                                 -> Integer
                                 -> IO Integer
randIntegerUptoNMinusOneSuchThat f range =
  withBN range $ \bnRange ->
  -- The scratch BIGNUM is scoped with withBN so it is freed when we return,
  -- rather than being leaked on every call.
  withBN 0 $ \r ->
    let loop = do
          _BN_rand_range (unwrapBN r) (unwrapBN bnRange) >>= failIf_ (/= 1)
          candidate <- bnToInteger r
          if f candidate then return candidate else loop
    in loop
-- | Draw pseudo-random integers in @[0, range)@ with @BN_pseudo_rand_range@
-- until one satisfies the predicate, then return it.
--
-- A failure status from @BN_pseudo_rand_range@ is raised via 'failIf_'. The
-- loop never terminates if no value in the range satisfies the predicate.
prandIntegerUptoNMinusOneSuchThat :: (Integer -> Bool)
                                  -> Integer
                                  -> IO Integer
prandIntegerUptoNMinusOneSuchThat f range =
  withBN range $ \bnRange ->
  -- The scratch BIGNUM is scoped with withBN so it is freed when we return,
  -- rather than being leaked on every call.
  withBN 0 $ \r ->
    let loop = do
          _BN_pseudo_rand_range (unwrapBN r) (unwrapBN bnRange) >>= failIf_ (/= 1)
          candidate <- bnToInteger r
          if f candidate then return candidate else loop
    in loop
-- | Random 'Integer' @x@ with @0 <= x < n@, drawn with @BN_rand_range@.
randIntegerZeroToNMinusOne :: Integer -> IO Integer
randIntegerZeroToNMinusOne = randIntegerUptoNMinusOneSuchThat (\_ -> True)
-- | Random 'Integer' @x@ with @1 <= x < n@, drawn with @BN_rand_range@.
-- Zero draws are rejected and redrawn.
randIntegerOneToNMinusOne :: Integer -> IO Integer
randIntegerOneToNMinusOne = randIntegerUptoNMinusOneSuchThat (0 /=)
-- | Pseudo-random 'Integer' @x@ with @0 <= x < n@, drawn with
-- @BN_pseudo_rand_range@.
prandIntegerZeroToNMinusOne :: Integer -> IO Integer
prandIntegerZeroToNMinusOne = prandIntegerUptoNMinusOneSuchThat (\_ -> True)
-- | Pseudo-random 'Integer' @x@ with @1 <= x < n@, drawn with
-- @BN_pseudo_rand_range@. Zero draws are rejected and redrawn.
prandIntegerOneToNMinusOne :: Integer -> IO Integer
prandIntegerOneToNMinusOne = prandIntegerUptoNMinusOneSuchThat (0 /=)