{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Crypto.PubKey.ECDSA
( EllipticCurveECDSA (..)
, PublicKey
, encodePublic
, decodePublic
, toPublic
, PrivateKey
, encodePrivate
, decodePrivate
, Signature(..)
, signatureFromIntegers
, signatureToIntegers
, signWith
, signDigestWith
, sign
, signDigest
, verify
, verifyDigest
) where
import Control.Monad
import Crypto.ECC
import qualified Crypto.ECC.Simple.Types as Simple
import Crypto.Error
import Crypto.Hash
import Crypto.Hash.Types
import Crypto.Internal.ByteArray (ByteArray, ByteArrayAccess)
import Crypto.Internal.Imports
import Crypto.Number.ModArithmetic (inverseFermat)
import qualified Crypto.PubKey.ECC.P256 as P256
import Crypto.Random.Types
import Data.Bits
import qualified Data.ByteArray as B
import Data.Data
import Foreign.Ptr (Ptr)
import Foreign.Storable (peekByteOff, pokeByteOff)
-- | An ECDSA signature over @curve@: the pair of scalars @(r, s)@.
data Signature curve = Signature
    { sign_r :: Scalar curve  -- ^ the @r@ component of the signature
    , sign_s :: Scalar curve  -- ^ the @s@ component of the signature
    }

deriving instance Eq (Scalar curve) => Eq (Signature curve)
deriving instance Show (Scalar curve) => Show (Signature curve)

-- | Deep evaluation forces both scalar components.
instance NFData (Scalar curve) => NFData (Signature curve) where
    rnf (Signature r s) = rnf r `seq` rnf s
-- | An ECDSA public key is a point on the curve.
type PublicKey curve = Point curve
-- | An ECDSA private key is a scalar of the curve.
type PrivateKey curve = Scalar curve
-- | Curves on which ECDSA can be performed.  On top of the base-point
-- arithmetic, ECDSA needs scalar validation, inversion, and a way to turn
-- a point's x-coordinate into a scalar.
class EllipticCurveBasepointArith curve => EllipticCurveECDSA curve where
    -- | Whether a scalar is acceptable as a signature component.  The
    -- instances in this module accept exactly the scalars strictly between
    -- zero and the group order.
    scalarIsValid :: proxy curve -> Scalar curve -> Bool

    -- | Whether a scalar equals zero.  The default compares against the
    -- scalar built from the integer @0@.
    scalarIsZero :: proxy curve -> Scalar curve -> Bool
    scalarIsZero prx s = s == zero
      where
        zero = throwCryptoError (scalarFromInteger prx 0)

    -- | Modular inverse of a scalar, or 'Nothing' when it cannot be
    -- inverted (the instances here report this for a zero result).
    scalarInv :: proxy curve -> Scalar curve -> Maybe (Scalar curve)

    -- | The x-coordinate of a point as a scalar, or 'Nothing' when the
    -- point has no x-coordinate (e.g. the point at infinity).
    pointX :: proxy curve -> Point curve -> Maybe (Scalar curve)
-- | P-256 is backed by the dedicated "Crypto.PubKey.ECC.P256" primitives.
instance EllipticCurveECDSA Curve_P256R1 where
    -- Valid scalars are non-zero and strictly below 'P256.scalarN'.
    scalarIsValid _ s
        | P256.scalarIsZero s = False
        | otherwise           = P256.scalarCmp s P256.scalarN == LT

    scalarIsZero _ s = P256.scalarIsZero s

    -- A zero result from the inversion primitive is treated as
    -- "not invertible".
    scalarInv _ s
        | P256.scalarIsZero inv = Nothing
        | otherwise             = Just inv
      where
        inv = P256.scalarInvSafe s

    pointX _ p = P256.pointX p
-- | P-384 delegates to the generic helpers over 'Simple.SEC_p384r1'.
instance EllipticCurveECDSA Curve_P384R1 where
    scalarIsValid _ s = ecScalarIsValid (Proxy :: Proxy Simple.SEC_p384r1) s
    scalarIsZero _ s  = ecScalarIsZero s
    scalarInv _ s     = ecScalarInv (Proxy :: Proxy Simple.SEC_p384r1) s
    pointX _ p        = ecPointX (Proxy :: Proxy Simple.SEC_p384r1) p
-- | P-521 delegates to the generic helpers over 'Simple.SEC_p521r1'.
instance EllipticCurveECDSA Curve_P521R1 where
    scalarIsValid _ s = ecScalarIsValid (Proxy :: Proxy Simple.SEC_p521r1) s
    scalarIsZero _ s  = ecScalarIsZero s
    scalarInv _ s     = ecScalarInv (Proxy :: Proxy Simple.SEC_p521r1) s
    pointX _ p        = ecPointX (Proxy :: Proxy Simple.SEC_p521r1) p
-- | Build a 'Signature' from the integer pair @(r, s)@.  Fails when either
-- integer is rejected by 'scalarFromInteger'.
signatureFromIntegers :: EllipticCurveECDSA curve
                      => proxy curve -> (Integer, Integer) -> CryptoFailable (Signature curve)
signatureFromIntegers prx (r, s) =
    Signature <$> scalarFromInteger prx r <*> scalarFromInteger prx s
-- | Extract the @(r, s)@ components of a 'Signature' as integers.
signatureToIntegers :: EllipticCurveECDSA curve
                    => proxy curve -> Signature curve -> (Integer, Integer)
signatureToIntegers prx sig = (component sign_r, component sign_s)
  where
    -- Apply a field selector lazily, so the signature is only inspected
    -- when a component is demanded.
    component field = scalarToInteger prx (field sig)
-- | Serialize a public key using the curve's 'encodePoint'.
encodePublic :: (EllipticCurve curve, ByteArray bs)
             => proxy curve -> PublicKey curve -> bs
encodePublic prx pub = encodePoint prx pub
-- | Parse a public key using the curve's 'decodePoint'; fails when the
-- bytes are rejected by 'decodePoint'.
decodePublic :: (EllipticCurve curve, ByteArray bs)
             => proxy curve -> bs -> CryptoFailable (PublicKey curve)
decodePublic prx bytes = decodePoint prx bytes
-- | Serialize a private key using the curve's 'encodeScalar'.
encodePrivate :: (EllipticCurveECDSA curve, ByteArray bs)
              => proxy curve -> PrivateKey curve -> bs
encodePrivate prx priv = encodeScalar prx priv
-- | Parse a private key using the curve's 'decodeScalar'; fails when the
-- bytes are rejected by 'decodeScalar'.
decodePrivate :: (EllipticCurveECDSA curve, ByteArray bs)
              => proxy curve -> bs -> CryptoFailable (PrivateKey curve)
decodePrivate prx bytes = decodeScalar prx bytes
-- | Derive the public key from a private key using 'pointBaseSmul'
-- (multiplication of the curve's base point by the private scalar).
toPublic :: EllipticCurveECDSA curve
         => proxy curve -> PrivateKey curve -> PublicKey curve
toPublic prx priv = pointBaseSmul prx priv
-- | Sign a digest with an explicitly supplied nonce @k@ and private key @d@.
--
-- Computes @r@ from the x-coordinate of the base point multiplied by @k@
-- and @s = k^-1 * (z + r * d)@, where @z@ is the digest converted by
-- 'tHashDigest'.  Returns 'Nothing' in these cases:
--
-- * the point has no x-coordinate;
-- * @k@ cannot be inverted;
-- * either @r@ or @s@ is zero.
--
-- The caller should then retry with another nonce.
signDigestWith :: (EllipticCurveECDSA curve, HashAlgorithm hash)
               => proxy curve -> Scalar curve -> PrivateKey curve -> Digest hash -> Maybe (Signature curve)
signDigestWith prx k d digest = do
    r    <- pointX prx (pointBaseSmul prx k)
    kInv <- scalarInv prx k
    let z = tHashDigest prx digest
        s = scalarMul prx kInv (scalarAdd prx z (scalarMul prx r d))
    -- A zero component would make the signature degenerate.
    guard (not (scalarIsZero prx r || scalarIsZero prx s))
    pure (Signature r s)
-- | Hash a message with the given algorithm, then sign it with the
-- explicit nonce @k@ via 'signDigestWith'.
signWith :: (EllipticCurveECDSA curve, ByteArrayAccess msg, HashAlgorithm hash)
         => proxy curve -> Scalar curve -> PrivateKey curve -> hash -> msg -> Maybe (Signature curve)
signWith prx k d hashAlg = signDigestWith prx k d . hashWith hashAlg
-- | Sign a digest with a nonce drawn from 'curveGenerateScalar'.  Whenever
-- 'signDigestWith' rejects the nonce, a fresh one is drawn and signing is
-- attempted again.
signDigest :: (EllipticCurveECDSA curve, MonadRandom m, HashAlgorithm hash)
           => proxy curve -> PrivateKey curve -> Digest hash -> m (Signature curve)
signDigest prx pk digest = attempt
  where
    attempt = do
        k <- curveGenerateScalar prx
        maybe attempt pure (signDigestWith prx k pk digest)
-- | Hash a message with the given algorithm, then sign it with a randomly
-- generated nonce via 'signDigest'.
sign :: (EllipticCurveECDSA curve, MonadRandom m, ByteArrayAccess msg, HashAlgorithm hash)
     => proxy curve -> PrivateKey curve -> hash -> msg -> m (Signature curve)
sign prx pk hashAlg = signDigest prx pk . hashWith hashAlg
-- | Verify a signature over a digest against a public key @q@.
--
-- Both components must pass 'scalarIsValid'.  Then, with @w = s^-1@, the
-- point @u1*G + u2*q@ is computed, where @u1 = z*w@ and @u2 = r*w@.  Its
-- x-coordinate (as a scalar) must equal @r@.  Any failed step yields
-- 'False'.
verifyDigest :: (EllipticCurveECDSA curve, HashAlgorithm hash)
             => proxy curve -> PublicKey curve -> Signature curve -> Digest hash -> Bool
verifyDigest prx q (Signature r s) digest =
    scalarIsValid prx r && scalarIsValid prx s && maybe False (r ==) recoveredX
  where
    -- Public inputs only, so the variable-time multi-scalar multiplication
    -- is used here.
    recoveredX = do
        w <- scalarInv prx s
        let z  = tHashDigest prx digest
            u1 = scalarMul prx z w
            u2 = scalarMul prx r w
        pointX prx (pointsSmulVarTime prx u1 u2 q)
-- | Hash a message with the given algorithm, then check the signature with
-- 'verifyDigest'.
verify :: (EllipticCurveECDSA curve, ByteArrayAccess msg, HashAlgorithm hash)
       => proxy curve -> hash -> PublicKey curve -> Signature curve -> msg -> Bool
verify prx hashAlg q sig = verifyDigest prx q sig . hashWith hashAlg
-- | Convert a message digest into the scalar @z@ used by ECDSA signing and
-- verification.
--
-- Let @m = curveOrderBits prx@ and @n' = ceil(m / 8)@.  The digest bytes
-- are reshaped into exactly @n'@ bytes and then decoded with
-- 'decodeScalar':
--
-- * shorter than @m@ bits: left-padded with zero bytes;
-- * exactly @m@ bits: used as-is;
-- * longer than @m@ bits: only the leftmost @m@ bits are kept.
--
-- The leftmost-bits truncation treats byte 0 as the most significant,
-- presumably matching a big-endian 'decodeScalar'.
--
-- NOTE(review): every branch produces @n'@ bytes.  The 'throwCryptoError'
-- therefore presumably never fires, assuming 'decodeScalar' accepts
-- @n'@-byte inputs for every curve — confirm per instance.
tHashDigest :: (EllipticCurveECDSA curve, HashAlgorithm hash)
            => proxy curve -> Digest hash -> Scalar curve
tHashDigest prx (Digest digest) = throwCryptoError $ decodeScalar prx encoded
  where m = curveOrderBits prx
        -- d > 0: digest shorter than the order; d < 0: digest longer.
        d = m - B.length digest * 8
        -- n whole bytes plus r leftover bits make up the order's bit length.
        (n, r) = m `divMod` 8
        n' = if r > 0 then succ n else n
        -- Guard order matters: the last two branches are only reached
        -- when d < 0.
        encoded
            | d > 0 = B.zero (n' - B.length digest) `B.append` digest
            | d == 0 = digest
            -- Byte-aligned order: truncation is a plain prefix.
            | r == 0 = B.take n digest
            | otherwise = shiftBytes digest
        -- Non-byte-aligned truncation: shift the digest right by (8 - r)
        -- bits into an n'-byte buffer.  The result holds the leftmost
        -- m = 8*n + r bits: r bits in byte 0, 8 bits in each later byte.
        -- Reading n' source bytes is safe: this branch only runs when the
        -- digest is longer than m bits, hence at least n' bytes long.
        shiftBytes bs = B.allocAndFreeze n' $ \dst ->
            B.withByteArray bs $ \src -> go dst src 0 0
        -- dst[i] = (src[i] >> (8 - r)) .|. (src[i-1] << r).  The
        -- accumulator 'a' carries src[i-1] (0 for the first byte).  The
        -- Word8 left shift drops the bits already emitted in dst[i-1].
        go :: Ptr Word8 -> Ptr Word8 -> Word8 -> Int -> IO ()
        go dst src !a i
            | i >= n' = return ()
            | otherwise = do
                b <- peekByteOff src i
                pokeByteOff dst i (unsafeShiftR b (8 - r) .|. unsafeShiftL a r)
                go dst src b (succ i)
-- | Range check for scalars of the generic ('Simple') curves.  Accepts
-- exactly the integers strictly between 0 and the curve parameter
-- 'Simple.curveEccN' (the group order).
ecScalarIsValid :: Simple.Curve c => proxy c -> Simple.Scalar c -> Bool
ecScalarIsValid prx (Simple.Scalar s) = 0 < s && s < order
  where
    order = Simple.curveEccN (Simple.curveParameters prx)
-- | Whether a generic-curve scalar is the integer zero.
ecScalarIsZero :: forall curve . Simple.Curve curve
               => Simple.Scalar curve -> Bool
ecScalarIsZero (Simple.Scalar 0) = True
ecScalarIsZero _                 = False
-- | Inverse of a generic-curve scalar modulo 'Simple.curveEccN', computed
-- with 'inverseFermat'.  A zero result is reported as 'Nothing'.
ecScalarInv :: Simple.Curve c
            => proxy c -> Simple.Scalar c -> Maybe (Simple.Scalar c)
ecScalarInv prx (Simple.Scalar s) =
    case inverseFermat s order of
        0   -> Nothing
        inv -> Just (Simple.Scalar inv)
  where
    order = Simple.curveEccN (Simple.curveParameters prx)
-- | The x-coordinate of a generic-curve point, reduced modulo
-- 'Simple.curveEccN' and wrapped as a scalar.  Returns 'Nothing' for the
-- point at infinity ('Simple.PointO').
ecPointX :: Simple.Curve c
         => proxy c -> Simple.Point c -> Maybe (Simple.Scalar c)
ecPointX prx pt =
    case pt of
        Simple.PointO    -> Nothing
        Simple.Point x _ -> Just (Simple.Scalar (x `mod` order))
  where
    order = Simple.curveEccN (Simple.curveParameters prx)