{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_HADDOCK prune #-}
module Jose.Internal.Crypto
( hmacSign
, hmacVerify
, rsaSign
, rsaVerify
, rsaEncrypt
, rsaDecrypt
, ecVerify
, encryptPayload
, decryptPayload
, generateCmkAndIV
, keyWrap
, keyUnwrap
, pad
, unpad
)
where
import Control.Applicative
import Control.Monad (when, unless)
import Crypto.Error
import Crypto.Cipher.AES
import Crypto.Cipher.Types hiding (IV)
import Crypto.Hash.Algorithms
import Crypto.Number.Serialize (os2ip)
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import qualified Crypto.PubKey.RSA as RSA
import qualified Crypto.PubKey.RSA.PKCS15 as PKCS15
import qualified Crypto.PubKey.RSA.OAEP as OAEP
import Crypto.Random (MonadRandom, getRandomBytes)
import Crypto.MAC.HMAC (HMAC (..), hmac)
import Data.Bits (xor)
import Data.ByteArray (ByteArray, ScrubbedBytes)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Either.Combinators
import qualified Data.Serialize as Serialize
import qualified Data.Text as T
import Data.Word (Word64, Word8)
import Jose.Jwa
import Jose.Types (JwtError(..))
import Jose.Internal.Parser (IV(..), Tag(..))
hmacSign :: JwsAlg
-> ByteString
-> ByteString
-> Either JwtError ByteString
hmacSign a k m = case a of
HS256 -> Right $ BA.convert (hmac k m :: HMAC SHA256)
HS384 -> Right $ BA.convert (hmac k m :: HMAC SHA384)
HS512 -> Right $ BA.convert (hmac k m :: HMAC SHA512)
_ -> Left $ BadAlgorithm $ T.pack $ "Not an HMAC algorithm: " ++ show a
hmacVerify :: JwsAlg
-> ByteString
-> ByteString
-> ByteString
-> Bool
hmacVerify a key msg sig = either (const False) (`BA.constEq` sig) $ hmacSign a key msg
rsaSign :: Maybe RSA.Blinder
-> JwsAlg
-> RSA.PrivateKey
-> ByteString
-> Either JwtError ByteString
rsaSign blinder a key msg = case a of
RS256 -> go SHA256
RS384 -> go SHA384
RS512 -> go SHA512
_ -> Left . BadAlgorithm . T.pack $ "Not an RSA algorithm: " ++ show a
where
go h = either (const $ Left BadCrypto) Right $ PKCS15.sign blinder (Just h) key msg
rsaVerify :: JwsAlg
-> RSA.PublicKey
-> ByteString
-> ByteString
-> Bool
rsaVerify a key msg sig = case a of
RS256 -> go SHA256
RS384 -> go SHA384
RS512 -> go SHA512
_ -> False
where
go h = PKCS15.verify (Just h) key msg sig
ecVerify :: JwsAlg
-> ECDSA.PublicKey
-> ByteString
-> ByteString
-> Bool
ecVerify a key msg sig = case a of
ES256 -> go SHA256
ES384 -> go SHA384
ES512 -> go SHA512
_ -> False
where
(r, s) = B.splitAt (B.length sig `div` 2) sig
ecSig = ECDSA.Signature (os2ip r) (os2ip s)
go h = ECDSA.verify h key ecSig msg
generateCmkAndIV :: MonadRandom m
=> Enc
-> m (ScrubbedBytes, ScrubbedBytes)
generateCmkAndIV e = do
cmk <- getRandomBytes (keySize e)
iv <- getRandomBytes (ivSize e)
return (cmk, iv)
where
keySize A128GCM = 16
keySize A192GCM = 24
keySize A256GCM = 32
keySize A128CBC_HS256 = 32
keySize A192CBC_HS384 = 48
keySize A256CBC_HS512 = 64
ivSize A128GCM = 12
ivSize A192GCM = 12
ivSize A256GCM = 12
ivSize _ = 16
rsaEncrypt :: (MonadRandom m, ByteArray msg, ByteArray out)
=> RSA.PublicKey
-> JweAlg
-> msg
-> m (Either JwtError out)
rsaEncrypt k a msg = fmap BA.convert <$> case a of
RSA1_5 -> mapErr (PKCS15.encrypt k bs)
RSA_OAEP -> mapErr (OAEP.encrypt (OAEP.defaultOAEPParams SHA1) k bs)
RSA_OAEP_256 -> mapErr (OAEP.encrypt (OAEP.defaultOAEPParams SHA256) k bs)
_ -> return (Left (BadAlgorithm "Not an RSA algorithm"))
where
bs = BA.convert msg
mapErr = fmap (mapLeft (const BadCrypto))
rsaDecrypt :: ByteArray ct
=> Maybe RSA.Blinder
-> RSA.PrivateKey
-> JweAlg
-> ct
-> Either JwtError ScrubbedBytes
rsaDecrypt blinder rsaKey a ct = BA.convert <$> case a of
RSA1_5 -> mapErr (PKCS15.decrypt blinder rsaKey bs)
RSA_OAEP -> mapErr (OAEP.decrypt blinder (OAEP.defaultOAEPParams SHA1) rsaKey bs)
RSA_OAEP_256 -> mapErr (OAEP.decrypt blinder (OAEP.defaultOAEPParams SHA256) rsaKey bs)
_ -> Left (BadAlgorithm "Not an RSA algorithm")
where
bs = BA.convert ct
mapErr = mapLeft (const BadCrypto)
data C c = C
initCipher :: BlockCipher c => C c -> ScrubbedBytes -> Either JwtError c
initCipher _ k = mapFail (cipherInit k)
mapFail :: CryptoFailable a -> Either JwtError a
mapFail (CryptoPassed a) = return a
mapFail (CryptoFailed e) = Left $ case e of
CryptoError_KeySizeInvalid -> KeyError "cipher key length is invalid"
_ -> BadCrypto
decryptPayload :: forall ba. (ByteArray ba)
=> Enc
-> ScrubbedBytes
-> IV
-> ba
-> Tag
-> ba
-> Maybe ba
decryptPayload enc cek iv_ aad tag_ ct = case (enc, iv_, tag_) of
(A128GCM, IV12 b, Tag16 t) -> doGCM (C :: C AES128) b t
(A192GCM, IV12 b, Tag16 t) -> doGCM (C :: C AES192) b t
(A256GCM, IV12 b, Tag16 t) -> doGCM (C :: C AES256) b t
(A128CBC_HS256, IV16 b, Tag16 t) -> doCBC (C :: C AES128) b t SHA256 16
(A192CBC_HS384, IV16 b, Tag24 t) -> doCBC (C :: C AES192) b t SHA384 24
(A256CBC_HS512, IV16 b, Tag32 t) -> doCBC (C :: C AES256) b t SHA512 32
_ -> Nothing
where
(cbcMacKey, cbcEncKey) = BA.splitAt (BA.length cek `div` 2) cek :: (ScrubbedBytes, ScrubbedBytes)
al = fromIntegral (BA.length aad) * 8 :: Word64
doGCM :: BlockCipher c => C c -> ByteString -> ByteString -> Maybe ba
doGCM c iv tag = do
cipher <- rightToMaybe (initCipher c cek)
aead <- maybeCryptoError (aeadInit AEAD_GCM cipher iv)
aeadSimpleDecrypt aead aad ct (AuthTag $ BA.convert tag)
doCBC :: (HashAlgorithm a, BlockCipher c) => C c -> ByteString -> ByteString -> a -> Int -> Maybe ba
doCBC c iv tag a tagLen = do
checkMac a tag iv tagLen
cipher <- rightToMaybe (initCipher c cbcEncKey)
iv' <- makeIV iv
unless (BA.length ct `mod` blockSize cipher == 0) Nothing
unpad $ cbcDecrypt cipher iv' ct
checkMac :: HashAlgorithm a => a -> ByteString -> ByteString -> Int -> Maybe ()
checkMac a tag iv l = do
let mac = BA.take l $ BA.convert $ doMac a iv :: BA.Bytes
unless (tag `BA.constEq` mac) Nothing
doMac :: HashAlgorithm a => a -> ByteString -> HMAC a
doMac _ iv = hmac cbcMacKey (BA.concat [BA.convert aad, iv, BA.convert ct, Serialize.encode al] :: ByteString)
encryptPayload :: forall ba iv. (ByteArray ba, ByteArray iv)
=> Enc
-> ScrubbedBytes
-> iv
-> ba
-> ba
-> Maybe (AuthTag, ba)
encryptPayload e cek iv aad msg = case e of
A128GCM -> doGCM (C :: C AES128)
A192GCM -> doGCM (C :: C AES192)
A256GCM -> doGCM (C :: C AES256)
A128CBC_HS256 -> doCBC (C :: C AES128) SHA256 16
A192CBC_HS384 -> doCBC (C :: C AES192) SHA384 24
A256CBC_HS512 -> doCBC (C :: C AES256) SHA512 32
where
(cbcMacKey, cbcEncKey) = BA.splitAt (BA.length cek `div` 2) cek :: (ScrubbedBytes, ScrubbedBytes)
al = fromIntegral (BA.length aad) * 8 :: Word64
doGCM c = do
cipher <- rightToMaybe (initCipher c cek)
aead <- maybeCryptoError (aeadInit AEAD_GCM cipher iv)
return $ aeadSimpleEncrypt aead aad msg 16
doCBC :: (HashAlgorithm a, BlockCipher c) => C c -> a -> Int -> Maybe (AuthTag, ba)
doCBC c a tagLen = do
cipher <- rightToMaybe (initCipher c cbcEncKey)
iv' <- makeIV iv
let ct = cbcEncrypt cipher iv' (pad msg)
mac = doMac a ct
tag = BA.take tagLen (BA.convert mac)
return (AuthTag tag, ct)
doMac :: HashAlgorithm a => a -> ba -> HMAC a
doMac _ ct = hmac cbcMacKey (BA.concat [BA.convert aad, BA.convert iv, BA.convert ct, Serialize.encode al] :: ByteString)
unpad :: (ByteArray ba) => ba -> Maybe ba
unpad bs
| padLen > 16 || padLen /= BA.length padding = Nothing
| BA.any (/= padByte) padding = Nothing
| otherwise = return pt
where
len = BA.length bs
padByte = BA.index bs (len-1)
padLen = fromIntegral padByte
(pt, padding) = BA.splitAt (len - padLen) bs
pad :: (ByteArray ba) => ba -> ba
pad bs = BA.append bs padding
where
lastBlockSize = BA.length bs `mod` 16
padByte = fromIntegral $ 16 - lastBlockSize :: Word8
padding = BA.replicate (fromIntegral padByte) padByte
keyWrap :: ByteArray ba => JweAlg -> ScrubbedBytes -> ScrubbedBytes -> Either JwtError ba
keyWrap alg kek cek = case alg of
A128KW -> doKeyWrap (C :: C AES128)
A192KW -> doKeyWrap (C :: C AES192)
A256KW -> doKeyWrap (C :: C AES256)
_ -> Left (BadAlgorithm "Not a keywrap algorithm")
where
l = BA.length cek
n = l `div` 8
iv = BA.replicate 8 166 :: ByteString
doKeyWrap c = do
when (l < 16 || l `mod` 8 /= 0) (Left (KeyError "Invalid content key"))
cipher <- initCipher c kek
let p = toBlocks cek
(r0, r) = foldl (doRound (ecbEncrypt cipher) 1) (BA.convert iv, p) [0..5]
Right $ BA.concat (r0 : r)
doRound _ _ (a, []) _ = (a, [])
doRound enc i (a, r:rs) j =
let b = enc $ BA.concat [a, r]
t = fromIntegral ((n*j) + i) :: Word8
a' = txor t (BA.take 8 b)
r' = BA.drop 8 b
next = doRound enc (i+1) (a', rs) j
in (fst next, r' : snd next)
txor :: ByteArray ba => Word8 -> ba -> ba
txor t b =
let n = BA.length b
lastByte = BA.index b (n-1)
initBytes = BA.take (n-1) b
in BA.snoc initBytes (lastByte `xor` t)
toBlocks :: ByteArray ba => ba -> [ba]
toBlocks bytes
| BA.null bytes = []
| otherwise = let (b, bs') = BA.splitAt 8 bytes
in b : toBlocks bs'
keyUnwrap :: ByteArray ba => ScrubbedBytes -> JweAlg -> ba -> Either JwtError ScrubbedBytes
keyUnwrap kek alg encK = case alg of
A128KW -> doUnWrap (C :: C AES128)
A192KW -> doUnWrap (C :: C AES192)
A256KW -> doUnWrap (C :: C AES256)
_ -> Left (BadAlgorithm "Not a keywrap algorithm")
where
l = BA.length encK
n = (l `div` 8) - 1
iv = BA.replicate 8 166
doUnWrap c = do
when (l < 24 || l `mod` 8 /= 0) (Left BadCrypto)
cipher <- initCipher c kek
let r = toBlocks encK
(p0, p) = foldl (doRound (ecbDecrypt cipher) n) (head r, reverse (tail r)) (reverse [0..5])
unless (p0 == iv) (Left BadCrypto)
Right $ BA.concat (reverse p)
doRound _ _ (a, []) _ = (a, [])
doRound dec i (a, r:rs) j =
let b = dec $ BA.concat [txor t a, r]
t = fromIntegral ((n*j) + i) :: Word8
a' = BA.take 8 b
r' = BA.drop 8 b
next = doRound dec (i-1) (a', rs) j
in (fst next, r' : snd next)