{-# LINE 1 "OpenSSL/BN.hsc" #-}

{-# LINE 2 "OpenSSL/BN.hsc" #-}
{-# LANGUAGE BangPatterns             #-}

{-# LINE 4 "OpenSSL/BN.hsc" #-}
{-# LANGUAGE EmptyDataDecls           #-}
{-# LANGUAGE ForeignFunctionInterface #-}

{-# LINE 7 "OpenSSL/BN.hsc" #-}
{-# LANGUAGE MagicHash                #-}
{-# LANGUAGE UnboxedTuples            #-}
{-# LANGUAGE UnliftedFFITypes         #-}

{-# LINE 11 "OpenSSL/BN.hsc" #-}
{-# OPTIONS_HADDOCK prune             #-}
-- |BN - multiprecision integer arithmetics
module OpenSSL.BN
    ( -- * Type
      BigNum
    , BIGNUM

      -- * Allocation
    , allocaBN
    , withBN

    , newBN
    , wrapBN -- private
    , unwrapBN -- private

      -- * Conversion from\/to Integer
    , peekBN
    , integerToBN
    , bnToInteger
    , integerToMPI
    , mpiToInteger

      -- * Computation
    , modexp

      -- * Random number generation
    , randIntegerUptoNMinusOneSuchThat
    , prandIntegerUptoNMinusOneSuchThat
    , randIntegerZeroToNMinusOne
    , prandIntegerZeroToNMinusOne
    , randIntegerOneToNMinusOne
    , prandIntegerOneToNMinusOne
    )
    where

{-# LINE 46 "OpenSSL/BN.hsc" #-}
import           Control.Exception hiding (try)
import qualified Data.ByteString as BS
import           Foreign.Marshal
import           Foreign.Ptr
import           Foreign.Storable
import           OpenSSL.Utils
import           System.IO.Unsafe


{-# LINE 55 "OpenSSL/BN.hsc" #-}
import           Foreign.C.Types
import           GHC.Base

{-# LINE 58 "OpenSSL/BN.hsc" #-}
import           GHC.Integer.GMP.Internals

{-# LINE 65 "OpenSSL/BN.hsc" #-}

{-# LINE 69 "OpenSSL/BN.hsc" #-}

-- |'BigNum' is an opaque object representing a big number.
newtype BigNum = BigNum (Ptr BIGNUM)
data BIGNUM


foreign import ccall unsafe "BN_new"
        _new :: IO (Ptr BIGNUM)

foreign import ccall unsafe "BN_free"
        _free :: Ptr BIGNUM -> IO ()

-- |@'allocaBN' f@ allocates a 'BigNum' and computes @f@. Then it
-- frees the 'BigNum'.
allocaBN :: (BigNum -> IO a) -> IO a
allocaBN m
    = bracket _new _free (m . wrapBN)


unwrapBN :: BigNum -> Ptr BIGNUM
unwrapBN (BigNum p) = p


wrapBN :: Ptr BIGNUM -> BigNum
wrapBN = BigNum



{-# LINE 128 "OpenSSL/BN.hsc" #-}

{- fast, dangerous functions ------------------------------------------------ -}

-- Both BN (the OpenSSL library) and GMP (used by GHC) use the same internal
-- representation for numbers: an array of words, least-significant first. Thus
-- we can move from Integer's to BIGNUMs very quickly: by copying in the worst
-- case and by just alloca'ing and pointing into the Integer in the fast case.
-- Note that, in the fast case, it's very important that any foreign function
-- calls be "unsafe", that is, they don't call back into Haskell. Otherwise the
-- GC could do nasty things to the data which we thought that we had a pointer
-- to

foreign import ccall unsafe "memcpy"
        _copy_in :: ByteArray# -> Ptr () -> CSize -> IO (Ptr ())

foreign import ccall unsafe "memcpy"
        _copy_out :: Ptr () -> ByteArray# -> CSize -> IO (Ptr ())

-- These are taken from Data.Binary's disabled fast Integer support
data ByteArray = BA  !ByteArray#
data MBA       = MBA !(MutableByteArray# RealWorld)

newByteArray :: Int# -> IO MBA
newByteArray sz = IO $ \s ->
  case newByteArray# sz s of { (# s', arr #) ->
  (# s', MBA arr #) }

freezeByteArray :: MutableByteArray# RealWorld -> IO ByteArray
freezeByteArray arr = IO $ \s ->
  case unsafeFreezeByteArray# arr s of { (# s', arr' #) ->
  (# s', BA arr' #) }

-- | Convert a BIGNUM to an Integer
bnToInteger :: BigNum -> IO Integer
bnToInteger bn = do
  nlimbs <- ((\hsc_ptr -> peekByteOff hsc_ptr 8)) (unwrapBN bn) :: IO CInt
{-# LINE 164 "OpenSSL/BN.hsc" #-}
  case nlimbs of
    0 -> return 0
    1 -> do (I# i) <- ((\hsc_ptr -> peekByteOff hsc_ptr 0)) (unwrapBN bn) >>= peek
{-# LINE 167 "OpenSSL/BN.hsc" #-}
            negative <- ((\hsc_ptr -> peekByteOff hsc_ptr 16)) (unwrapBN bn) :: IO CInt
{-# LINE 168 "OpenSSL/BN.hsc" #-}
            if negative == 0
               then return $ S# i
               else return $ 0 - (S# i)
    _ -> do
      let !(I# nlimbsi) = fromIntegral nlimbs
          !(I# limbsize) = ((8))
{-# LINE 174 "OpenSSL/BN.hsc" #-}
      (MBA arr) <- newByteArray (nlimbsi *# limbsize)
      (BA ba) <- freezeByteArray arr
      limbs <- ((\hsc_ptr -> peekByteOff hsc_ptr 0)) (unwrapBN bn)
{-# LINE 177 "OpenSSL/BN.hsc" #-}
      _ <- _copy_in ba limbs $ fromIntegral $ nlimbs * ((8))
{-# LINE 178 "OpenSSL/BN.hsc" #-}
      negative <- ((\hsc_ptr -> peekByteOff hsc_ptr 16)) (unwrapBN bn) :: IO CInt
{-# LINE 179 "OpenSSL/BN.hsc" #-}
      if negative == 0
         then return $ J# nlimbsi ba
         else return $ 0 - (J# nlimbsi ba)

-- | This is a GHC specific, fast conversion between Integers and OpenSSL
--   bignums. It returns a malloced BigNum.
integerToBN :: Integer -> IO BigNum
integerToBN (S# 0#) = do
  bnptr <- mallocBytes ((24))
{-# LINE 188 "OpenSSL/BN.hsc" #-}
  ((\hsc_ptr -> pokeByteOff hsc_ptr 0)) bnptr nullPtr
{-# LINE 189 "OpenSSL/BN.hsc" #-}
  -- This is needed to give GHC enough type information
  let one :: CInt
      one = 1
      zero :: CInt
      zero = 0
  ((\hsc_ptr -> pokeByteOff hsc_ptr 20)) bnptr one
{-# LINE 195 "OpenSSL/BN.hsc" #-}
  ((\hsc_ptr -> pokeByteOff hsc_ptr 8)) bnptr zero
{-# LINE 196 "OpenSSL/BN.hsc" #-}
  ((\hsc_ptr -> pokeByteOff hsc_ptr 12)) bnptr zero
{-# LINE 197 "OpenSSL/BN.hsc" #-}
  ((\hsc_ptr -> pokeByteOff hsc_ptr 16)) bnptr zero
{-# LINE 198 "OpenSSL/BN.hsc" #-}
  return (wrapBN bnptr)

integerToBN (S# v) = do
  bnptr <- mallocBytes ((24))
{-# LINE 202 "OpenSSL/BN.hsc" #-}
  limbs <- malloc :: IO (Ptr CULong)
  poke limbs $ fromIntegral $ abs $ I# v
  ((\hsc_ptr -> pokeByteOff hsc_ptr 0)) bnptr limbs
{-# LINE 205 "OpenSSL/BN.hsc" #-}
  -- This is needed to give GHC enough type information since #poke just
  -- uses an offset
  let one :: CInt
      one = 1
  ((\hsc_ptr -> pokeByteOff hsc_ptr 20)) bnptr one
{-# LINE 210 "OpenSSL/BN.hsc" #-}
  ((\hsc_ptr -> pokeByteOff hsc_ptr 8)) bnptr one
{-# LINE 211 "OpenSSL/BN.hsc" #-}
  ((\hsc_ptr -> pokeByteOff hsc_ptr 12)) bnptr one
{-# LINE 212 "OpenSSL/BN.hsc" #-}
  ((\hsc_ptr -> pokeByteOff hsc_ptr 16)) bnptr (if (I# v) < 0 then one else 0)
{-# LINE 213 "OpenSSL/BN.hsc" #-}
  return (wrapBN bnptr)

integerToBN v@(J# nlimbs_ bytearray)
  | v >= 0 = do
      let nlimbs = (I# nlimbs_)
      bnptr <- mallocBytes ((24))
{-# LINE 219 "OpenSSL/BN.hsc" #-}
      limbs <- mallocBytes (((8)) * nlimbs)
{-# LINE 220 "OpenSSL/BN.hsc" #-}
      ((\hsc_ptr -> pokeByteOff hsc_ptr 0)) bnptr limbs
{-# LINE 221 "OpenSSL/BN.hsc" #-}
      ((\hsc_ptr -> pokeByteOff hsc_ptr 20)) bnptr (1 :: CInt)
{-# LINE 222 "OpenSSL/BN.hsc" #-}
      _ <- _copy_out limbs bytearray (fromIntegral $ ((8)) * nlimbs)
{-# LINE 223 "OpenSSL/BN.hsc" #-}
      ((\hsc_ptr -> pokeByteOff hsc_ptr 8)) bnptr ((fromIntegral nlimbs) :: CInt)
{-# LINE 224 "OpenSSL/BN.hsc" #-}
      ((\hsc_ptr -> pokeByteOff hsc_ptr 12)) bnptr ((fromIntegral nlimbs) :: CInt)
{-# LINE 225 "OpenSSL/BN.hsc" #-}
      ((\hsc_ptr -> pokeByteOff hsc_ptr 16)) bnptr (0 :: CInt)
{-# LINE 226 "OpenSSL/BN.hsc" #-}
      return (wrapBN bnptr)
  | otherwise = do bnptr <- integerToBN (0-v)
                   ((\hsc_ptr -> pokeByteOff hsc_ptr 16)) (unwrapBN bnptr) (1 :: CInt)
{-# LINE 229 "OpenSSL/BN.hsc" #-}
                   return bnptr


{-# LINE 232 "OpenSSL/BN.hsc" #-}

-- TODO: we could make a function which doesn't even allocate BN data if we
-- wanted to be very fast and dangerout. The BIGNUM could point right into the
-- Integer's data. However, I'm not sure about the semantics of the GC; which
-- might move the Integer data around.

-- |@'withBN' n f@ converts n to a 'BigNum' and computes @f@. Then it
-- frees the 'BigNum'.
withBN :: Integer -> (BigNum -> IO a) -> IO a
withBN dec m = bracket (integerToBN dec) (_free . unwrapBN) m

foreign import ccall unsafe "BN_bn2mpi"
        _bn2mpi :: Ptr BIGNUM -> Ptr CChar -> IO CInt

foreign import ccall unsafe "BN_mpi2bn"
        _mpi2bn :: Ptr CChar -> CInt -> Ptr BIGNUM -> IO (Ptr BIGNUM)

-- |This is an alias to 'bnToInteger'.
peekBN :: BigNum -> IO Integer
peekBN = bnToInteger

-- |This is an alias to 'integerToBN'.
newBN :: Integer -> IO BigNum
newBN = integerToBN

-- | Convert a BigNum to an MPI: a serialisation of large ints which has a
--   4-byte, big endian length followed by the bytes of the number in
--   most-significant-first order.
bnToMPI :: BigNum -> IO BS.ByteString
bnToMPI bn = do
  bytes <- _bn2mpi (unwrapBN bn) nullPtr
  allocaBytes (fromIntegral bytes) (\buffer -> do
    _ <- _bn2mpi (unwrapBN bn) buffer
    BS.packCStringLen (buffer, fromIntegral bytes))

-- | Convert an MPI into a BigNum. See bnToMPI for details of the format
mpiToBN :: BS.ByteString -> IO BigNum
mpiToBN mpi = do
  BS.useAsCStringLen mpi (\(ptr, len) -> do
    _mpi2bn ptr (fromIntegral len) nullPtr) >>= return . wrapBN

-- | Convert an Integer to an MPI. See bnToMPI for the format
integerToMPI :: Integer -> IO BS.ByteString
integerToMPI v = bracket (integerToBN v) (_free . unwrapBN) bnToMPI

-- | Convert an MPI to an Integer. See bnToMPI for the format
mpiToInteger :: BS.ByteString -> IO Integer
mpiToInteger mpi = do
  bn <- mpiToBN mpi
  v <- bnToInteger bn
  _free (unwrapBN bn)
  return v

foreign import ccall unsafe "BN_mod_exp"
        _mod_exp :: Ptr BIGNUM -> Ptr BIGNUM -> Ptr BIGNUM -> Ptr BIGNUM -> BNCtx -> IO (Ptr BIGNUM)

type BNCtx = Ptr BNCTX
data BNCTX

foreign import ccall unsafe "BN_CTX_new"
        _BN_ctx_new :: IO BNCtx

foreign import ccall unsafe "BN_CTX_free"
        _BN_ctx_free :: BNCtx -> IO ()

withBNCtx :: (BNCtx -> IO a) -> IO a
withBNCtx f = bracket _BN_ctx_new _BN_ctx_free f

-- |@'modexp' a p m@ computes @a@ to the @p@-th power modulo @m@.
modexp :: Integer -> Integer -> Integer -> Integer
modexp a p m = unsafePerformIO (do
  withBN a (\bnA -> (do
    withBN p (\bnP -> (do
      withBN m (\bnM -> (do
        withBNCtx (\ctx -> (do
          r <- newBN 0
          _ <- _mod_exp (unwrapBN r) (unwrapBN bnA) (unwrapBN bnP) (unwrapBN bnM) ctx
          bnToInteger r >>= return)))))))))

{- Random Integer generation ------------------------------------------------ -}

foreign import ccall unsafe "BN_rand_range"
        _BN_rand_range :: Ptr BIGNUM -> Ptr BIGNUM -> IO CInt

foreign import ccall unsafe "BN_pseudo_rand_range"
        _BN_pseudo_rand_range :: Ptr BIGNUM -> Ptr BIGNUM -> IO CInt

-- | Return a strongly random number in the range 0 <= x < n where the given
--   filter function returns true.
randIntegerUptoNMinusOneSuchThat :: (Integer -> Bool)  -- ^ a filter function
                                 -> Integer  -- ^ one plus the upper limit
                                 -> IO Integer
randIntegerUptoNMinusOneSuchThat f range = withBN range (\bnRange -> (do
  r <- newBN 0
  let try = do
        _BN_rand_range (unwrapBN r) (unwrapBN bnRange) >>= failIf_ (/= 1)
        i <- bnToInteger r
        if f i
           then return i
           else try
  try))

-- | Return a random number in the range 0 <= x < n where the given
--   filter function returns true.
prandIntegerUptoNMinusOneSuchThat :: (Integer -> Bool)  -- ^ a filter function
                                  -> Integer  -- ^ one plus the upper limit
                                  -> IO Integer
prandIntegerUptoNMinusOneSuchThat f range = withBN range (\bnRange -> (do
  r <- newBN 0
  let try = do
        _BN_rand_range (unwrapBN r) (unwrapBN bnRange) >>= failIf_ (/= 1)
        i <- bnToInteger r
        if f i
           then return i
           else try
  try))

-- | Return a strongly random number in the range 0 <= x < n
randIntegerZeroToNMinusOne :: Integer -> IO Integer
randIntegerZeroToNMinusOne = randIntegerUptoNMinusOneSuchThat (const True)
-- | Return a strongly random number in the range 0 < x < n
randIntegerOneToNMinusOne :: Integer -> IO Integer
randIntegerOneToNMinusOne = randIntegerUptoNMinusOneSuchThat (/= 0)

-- | Return a random number in the range 0 <= x < n
prandIntegerZeroToNMinusOne :: Integer -> IO Integer
prandIntegerZeroToNMinusOne = prandIntegerUptoNMinusOneSuchThat (const True)
-- | Return a random number in the range 0 < x < n
prandIntegerOneToNMinusOne :: Integer -> IO Integer
prandIntegerOneToNMinusOne = prandIntegerUptoNMinusOneSuchThat (/= 0)