module Crypto.PubKey.RSA.PSS
( PSSParams(..)
, defaultPSSParams
, defaultPSSParamsSHA1
, signWithSalt
, signDigestWithSalt
, sign
, signDigest
, signSafer
, signDigestSafer
, verify
, verifyDigest
) where
import Crypto.Random.Types
import Crypto.PubKey.RSA.Types
import Crypto.PubKey.RSA.Prim
import Crypto.PubKey.RSA (generateBlinder)
import Crypto.PubKey.MaskGenFunction
import Crypto.Hash
import Crypto.Number.Basic (numBits)
import Data.Bits (xor, shiftR, (.&.))
import Data.Word
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray)
import qualified Crypto.Internal.ByteArray as B (convert, eq)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
-- | Parameters of the RSASSA-PSS signature scheme.
--
-- @seed@ and @output@ are the input and output byte-array types of the mask
-- generation function.
data PSSParams hash seed output = PSSParams
    { pssHash         :: hash
      -- ^ Hash algorithm used for the message and for H = Hash(M')
    , pssMaskGenAlg   :: MaskGenAlgorithm seed output
      -- ^ Mask generation function applied to H to mask the data block DB
    , pssSaltLength   :: Int
      -- ^ Length in bytes of the random salt generated by the signing helpers
      -- in this module ('sign', 'signDigest' and their @Safer@ variants)
    , pssTrailerField :: Word8
      -- ^ Value of the final octet of the encoded message
    }
-- | Default PSS parameters for a given hash algorithm:
--
-- * MGF1 built on the same hash,
-- * a salt as long as the hash digest,
-- * the trailer field 0xbc.
defaultPSSParams :: (ByteArrayAccess seed, ByteArray output, HashAlgorithm hash)
                 => hash
                 -> PSSParams hash seed output
defaultPSSParams alg = PSSParams alg (mgf1 alg) (hashDigestSize alg) 0xbc
-- | 'defaultPSSParams' specialised to SHA-1, with strict 'ByteString' used
-- for both the mask-generation seed and its output.
defaultPSSParamsSHA1 :: PSSParams SHA1 ByteString ByteString
defaultPSSParamsSHA1 = defaultPSSParams sha1
  where
    sha1 = SHA1
-- | Sign an already-computed digest with a caller-supplied salt.
--
-- This performs the PSS encoding (EMSA-PSS-ENCODE) and then the private-key
-- operation 'dp'.
--
-- Returns @Left InvalidParameters@ when the encoded message is too short to
-- hold the hash, the salt, the 0x01 separator and the trailer octet.
signDigestWithSalt :: HashAlgorithm hash
                   => ByteString     -- ^ Salt
                   -> Maybe Blinder  -- ^ Optional blinder for the private-key operation
                   -> PSSParams hash ByteString ByteString
                   -> PrivateKey
                   -> Digest hash    -- ^ Digest of the message to sign
                   -> Either Error ByteString
signDigestWithSalt salt blinder params pk digest =
    if encodedLen < digestLen + saltSize + 2
        then Left InvalidParameters
        else Right (dp blinder pk encoded)
  where
    hashAlg   = pssHash params
    digestLen = hashDigestSize hashAlg
    saltSize  = B.length salt
    modBits   = numBits (private_n pk)
    keyLen    = private_size pk
    -- The encoded message drops one octet when (modBits - 1) is a multiple of 8.
    encodedLen
        | emTruncate modBits = keyLen - 1
        | otherwise          = keyLen
    dataBlockLen = encodedLen - digestLen - 1
    -- M' = 8 zero octets || mHash || salt
    mPrime = B.concat [B.replicate 8 0, B.convert digest, salt]
    hPrime = B.convert (hashWith hashAlg mPrime)
    -- DB = zero padding || 0x01 || salt
    dataBlock = B.concat [B.replicate (dataBlockLen - saltSize - 1) 0, B.singleton 1, salt]
    dbMask    = pssMaskGenAlg params hPrime dataBlockLen
    -- Clear the first octet's high-order bits that exceed the encoded bit length.
    maskedDB  = B.pack (normalizeToKeySize modBits (B.zipWith xor dataBlock dbMask))
    encoded   = B.concat [maskedDB, hPrime, B.singleton (pssTrailerField params)]
-- | Hash the message with 'pssHash', then sign the resulting digest with the
-- given salt via 'signDigestWithSalt'.
signWithSalt :: HashAlgorithm hash
             => ByteString     -- ^ Salt
             -> Maybe Blinder  -- ^ Optional blinder for the private-key operation
             -> PSSParams hash ByteString ByteString
             -> PrivateKey
             -> ByteString     -- ^ Message to sign
             -> Either Error ByteString
signWithSalt salt blinder params pk =
    signDigestWithSalt salt blinder params pk . hashWith (pssHash params)
-- | Sign a message using a random salt of 'pssSaltLength' bytes drawn from
-- the 'MonadRandom' instance.
sign :: (HashAlgorithm hash, MonadRandom m)
     => Maybe Blinder  -- ^ Optional blinder for the private-key operation
     -> PSSParams hash ByteString ByteString
     -> PrivateKey
     -> ByteString     -- ^ Message to sign
     -> m (Either Error ByteString)
sign blinder params pk m =
    signUsing <$> getRandomBytes (pssSaltLength params)
  where
    signUsing salt = signWithSalt salt blinder params pk m
-- | Sign an already-computed digest using a random salt of 'pssSaltLength'
-- bytes drawn from the 'MonadRandom' instance.
signDigest :: (HashAlgorithm hash, MonadRandom m)
           => Maybe Blinder  -- ^ Optional blinder for the private-key operation
           -> PSSParams hash ByteString ByteString
           -> PrivateKey
           -> Digest hash    -- ^ Digest of the message to sign
           -> m (Either Error ByteString)
signDigest blinder params pk digest =
    signUsing <$> getRandomBytes (pssSaltLength params)
  where
    signUsing salt = signDigestWithSalt salt blinder params pk digest
-- | Like 'sign', but always generates a fresh blinder for the private
-- modulus (via 'generateBlinder') before signing.
signSafer :: (HashAlgorithm hash, MonadRandom m)
          => PSSParams hash ByteString ByteString
          -> PrivateKey
          -> ByteString  -- ^ Message to sign
          -> m (Either Error ByteString)
signSafer params pk m =
    generateBlinder (private_n pk) >>= \b -> sign (Just b) params pk m
-- | Like 'signDigest', but always generates a fresh blinder for the private
-- modulus (via 'generateBlinder') before signing.
signDigestSafer :: (HashAlgorithm hash, MonadRandom m)
                => PSSParams hash ByteString ByteString
                -> PrivateKey
                -> Digest hash  -- ^ Digest of the message to sign
                -> m (Either Error ByteString)
signDigestSafer params pk digest =
    generateBlinder (private_n pk) >>= \b -> signDigest (Just b) params pk digest
-- | Hash the message with 'pssHash', then check the signature against the
-- resulting digest via 'verifyDigest'.
verify :: HashAlgorithm hash
       => PSSParams hash ByteString ByteString
       -> PublicKey
       -> ByteString  -- ^ Message
       -> ByteString  -- ^ Signature
       -> Bool
verify params pk = verifyDigest params pk . hashWith (pssHash params)
-- | Check a PSS signature against an already-computed message digest.
--
-- The signature is opened with the public-key operation 'ep' and then checked
-- with EMSA-PSS-VERIFY (RFC 8017, section 9.1.2).
--
-- The salt length is not compared with 'pssSaltLength'. The salt is taken to
-- be everything after the first 0x01 octet of the unmasked data block, so a
-- signature made with any salt length is accepted.
--
-- Returns 'False' when any of the following holds:
--
-- * the signature does not have the modulus byte length;
-- * the encoded message cannot hold the hash plus the two fixed octets;
-- * the trailer octet differs from 'pssTrailerField';
-- * the unused high-order bits of the encoded message are not zero;
-- * the padding is malformed;
-- * the recomputed hash does not match.
verifyDigest :: HashAlgorithm hash
             => PSSParams hash ByteString ByteString
             -> PublicKey    -- ^ Public key
             -> Digest hash  -- ^ Digest of the signed message
             -> ByteString   -- ^ Signature
             -> Bool
verifyDigest params pk digest s
    | B.length s /= k                     = False
    -- Step 3 (with a salt length >= 0): EM must hold H, 0x01 and the trailer.
    | emLen < hashLen + 2                 = False
    -- Octets dropped in front of EM (when emLen = k - 1) must be zero.
    | B.any (/= 0) pre                    = False
    | B.null em                           = False
    | B.last em /= pssTrailerField params = False
    -- Step 6: the leftmost 8*emLen - emBits bits of maskedDB must be zero.
    -- 'normalizeToKeySize' clears them while unmasking, so they have to be
    -- rejected explicitly here.
    | maskedTopBitsSet                    = False
    | B.any (/= 0) ps0                    = False
    | b1 /= B.singleton 1                 = False
    | otherwise                           = B.eq h h'
  where
    -- parameters
    hashLen = hashDigestSize (pssHash params)
    mHash   = B.convert digest
    k       = public_size pk
    pubBits = numBits (public_n pk)
    emLen   = if emTruncate pubBits then k - 1 else k
    dbLen   = emLen - hashLen - 1
    -- unmarshall fields: drop 0..1 leading byte, then EM = maskedDB || H || trailer
    (pre, em) = B.splitAt (k - emLen) (ep pk s)
    maskedDB  = B.take dbLen em
    h         = B.take hashLen $ B.drop (B.length maskedDB) em
    -- True when masking the first octet to the key size would change it,
    -- i.e. some bit above the encoded bit length is set.
    maskedTopBitsSet = case B.uncons maskedDB of
        Just (w, _) -> normalizeToKeySize pubBits [w] /= [w]
        Nothing     -> False
    dbmask = pssMaskGenAlg params h dbLen
    db     = B.pack $ normalizeToKeySize pubBits $ B.zipWith xor maskedDB dbmask
    -- DB = zero padding || 0x01 || salt
    (ps0, z)   = B.break (== 1) db
    (b1, salt) = B.splitAt 1 z
    m'         = B.concat [B.replicate 8 0, mHash, salt]
    h'         = hashWith (pssHash params) m'
-- | Whether @bits - 1@ is a multiple of 8.
--
-- The callers in this module use it on the modulus bit size. When it is
-- 'True', the encoded message is one octet shorter than the modulus.
emTruncate :: Int -> Bool
emTruncate bits = (bits - 1) `mod` 8 == 0
-- | Mask the first octet so it keeps only its low @(bits - 1) mod 8@ bits.
--
-- When that count is 0, the octet is left unchanged. Every following octet
-- is passed through untouched, and an empty list stays empty.
normalizeToKeySize :: Int -> [Word8] -> [Word8]
normalizeToKeySize bits bytes = case bytes of
    []          -> []
    (b : rest)  -> (b .&. topMask) : rest
  where
    topBits = (bits - 1) .&. 0x7
    topMask
        | topBits == 0 = 0xff
        | otherwise    = 0xff `shiftR` (8 - topBits)