module Jose.Internal.Crypto
( hmacSign
, hmacVerify
, rsaSign
, rsaVerify
, rsaEncrypt
, rsaDecrypt
, ecVerify
, encryptPayload
, decryptPayload
, generateCmkAndIV
, keyWrap
, keyUnwrap
, pad
, unpad
)
where
import Control.Monad (when, unless)
import Crypto.Error
import Crypto.Cipher.AES
import Crypto.Cipher.Types
import Crypto.Hash.Algorithms
import Crypto.Number.Serialize (os2ip)
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import qualified Crypto.PubKey.RSA as RSA
import qualified Crypto.PubKey.RSA.PKCS15 as PKCS15
import qualified Crypto.PubKey.RSA.OAEP as OAEP
import Crypto.Random (MonadRandom, getRandomBytes)
import Crypto.MAC.HMAC (HMAC (..), hmac)
import Data.Bits (xor)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Either.Combinators
import qualified Data.Serialize as Serialize
import qualified Data.Text as T
import Data.Word (Word64, Word8)
import Jose.Jwa
import Jose.Types (JwtError(..))
-- | Compute the HMAC signature of a message for one of the HS* JWS
-- algorithms.
--
-- Any algorithm other than HS256, HS384 or HS512 yields 'BadAlgorithm'.
hmacSign :: JwsAlg     -- ^ Signing algorithm (HS256, HS384 or HS512)
         -> ByteString -- ^ HMAC key
         -> ByteString -- ^ Message to sign
         -> Either JwtError ByteString
hmacSign a k m = case a of
    HS256 -> Right . BA.convert $ (hmac k m :: HMAC SHA256)
    HS384 -> Right . BA.convert $ (hmac k m :: HMAC SHA384)
    HS512 -> Right . BA.convert $ (hmac k m :: HMAC SHA512)
    other -> Left (notHmac other)
  where
    notHmac alg = BadAlgorithm (T.pack ("Not an HMAC algorithm: " ++ show alg))
-- | Check an HMAC signature by recomputing it and comparing in constant time.
--
-- Returns 'False' when the algorithm is not an HS* variant or when the
-- recomputed MAC differs from the supplied signature.
hmacVerify :: JwsAlg     -- ^ Expected algorithm
           -> ByteString -- ^ HMAC key
           -> ByteString -- ^ Signed message
           -> ByteString -- ^ Signature to check
           -> Bool
hmacVerify a key msg sig = case hmacSign a key msg of
    Left _         -> False
    Right expected -> expected `BA.constEq` sig
-- | Produce an RSASSA-PKCS1-v1_5 signature for the RS* JWS algorithms.
--
-- Non-RS* algorithms yield 'BadAlgorithm'; any failure reported by the
-- underlying signing function is collapsed to 'BadCrypto'.
rsaSign :: Maybe RSA.Blinder -- ^ Optional blinder passed to the signer
        -> JwsAlg            -- ^ Signing algorithm (RS256, RS384 or RS512)
        -> RSA.PrivateKey    -- ^ Signing key
        -> ByteString        -- ^ Message to sign
        -> Either JwtError ByteString
rsaSign blinder a key msg = case a of
    RS256 -> signWith SHA256
    RS384 -> signWith SHA384
    RS512 -> signWith SHA512
    _     -> Left . BadAlgorithm . T.pack $ "Not an RSA algorithm: " ++ show a
  where
    signWith hashAlg = mapLeft (const BadCrypto) (PKCS15.sign blinder (Just hashAlg) key msg)
-- | Verify an RSASSA-PKCS1-v1_5 signature for the RS* JWS algorithms.
--
-- Any algorithm other than RS256, RS384 or RS512 yields 'False'.
rsaVerify :: JwsAlg        -- ^ Expected algorithm
          -> RSA.PublicKey -- ^ Verification key
          -> ByteString    -- ^ Signed message
          -> ByteString    -- ^ Signature to check
          -> Bool
rsaVerify a key msg sig
    | RS256 <- a = verifyWith SHA256
    | RS384 <- a = verifyWith SHA384
    | RS512 <- a = verifyWith SHA512
    | otherwise  = False
  where
    verifyWith hashAlg = PKCS15.verify (Just hashAlg) key msg sig
-- | Verify a JWS ECDSA signature (ES256, ES384 or ES512).
--
-- The signature is the concatenation @R || S@ of two fixed-width big-endian
-- octet strings (RFC 7518, section 3.4): 32 bytes each for ES256, 48 for
-- ES384 and 66 for ES512. A signature whose length is not exactly twice
-- the expected width is rejected. Without this check, zero-padded
-- re-encodings of a valid (R, S) pair would also verify, making signatures
-- malleable.
--
-- Any non-ES* algorithm yields 'False'. The curve of the supplied key is
-- not checked here.
ecVerify :: JwsAlg           -- ^ Expected algorithm
         -> ECDSA.PublicKey  -- ^ Verification key
         -> ByteString       -- ^ Signed message
         -> ByteString       -- ^ Raw @R || S@ signature
         -> Bool
ecVerify a key msg sig = case a of
    ES256 -> go SHA256 32
    ES384 -> go SHA384 48
    ES512 -> go SHA512 66
    _     -> False
  where
    go h partLen
        | B.length sig /= 2 * partLen = False
        | otherwise                   = ECDSA.verify h key ecSig msg
      where
        (r, s) = B.splitAt partLen sig
        ecSig  = ECDSA.Signature (os2ip r) (os2ip s)
-- | Generate a random content master key and initialisation vector sized
-- for the given content-encryption algorithm.
--
-- The key is drawn first, then the IV. GCM modes use a 12-byte IV and the
-- CBC-HMAC modes a 16-byte IV; CBC-HMAC keys are double length because
-- they hold both a MAC key and an encryption key.
generateCmkAndIV :: MonadRandom m
                 => Enc
                 -> m (B.ByteString, B.ByteString)
generateCmkAndIV e = (,) <$> getRandomBytes cmkLen <*> getRandomBytes ivLen
  where
    (cmkLen, ivLen) = lengthsFor e
    lengthsFor A128GCM       = (16, 12)
    lengthsFor A192GCM       = (24, 12)
    lengthsFor A256GCM       = (32, 12)
    lengthsFor A128CBC_HS256 = (32, 16)
    lengthsFor A192CBC_HS384 = (48, 16)
    lengthsFor A256CBC_HS512 = (64, 16)
-- | Encrypt a content key with RSA for the RSA1_5 or RSA-OAEP (SHA-1) JWE
-- algorithms.
--
-- Any other algorithm yields 'BadAlgorithm'; encryption failures reported
-- by the RSA functions are mapped to 'BadCrypto'.
rsaEncrypt :: MonadRandom m
           => RSA.PublicKey -- ^ Recipient public key
           -> JweAlg        -- ^ Key management algorithm
           -> B.ByteString  -- ^ Plaintext (content key)
           -> m (Either JwtError B.ByteString)
rsaEncrypt k a bs = case a of
    RSA1_5   -> toJwtError <$> PKCS15.encrypt k bs
    RSA_OAEP -> toJwtError <$> OAEP.encrypt (OAEP.defaultOAEPParams SHA1) k bs
    _        -> pure (Left (BadAlgorithm "Not an RSA algorithm"))
  where
    toJwtError = either (const (Left BadCrypto)) Right
-- | Decrypt an RSA-encrypted content key for the RSA1_5 or RSA-OAEP (SHA-1)
-- JWE algorithms.
--
-- Every decryption failure is reported uniformly as 'BadCrypto'; other
-- algorithms yield 'BadAlgorithm'.
rsaDecrypt :: Maybe RSA.Blinder -- ^ Optional blinder passed to the decryptor
           -> RSA.PrivateKey    -- ^ Recipient private key
           -> JweAlg            -- ^ Key management algorithm
           -> B.ByteString      -- ^ Encrypted key from the JWE
           -> Either JwtError B.ByteString
rsaDecrypt blinder rsaKey a jweKey
    | RSA1_5   <- a = collapse (PKCS15.decrypt blinder rsaKey jweKey)
    | RSA_OAEP <- a = collapse (OAEP.decrypt blinder (OAEP.defaultOAEPParams SHA1) rsaKey jweKey)
    | otherwise     = Left (BadAlgorithm "Not an RSA algorithm")
  where
    collapse = either (const (Left BadCrypto)) Right
-- | Value-less tag used to choose a concrete block cipher type at a call
-- site (e.g. @C :: C AES128@) and pass that choice to 'initCipher'.
data C c = C
-- | Initialise a block cipher of the type selected by the 'C' tag.
--
-- Returns 'Nothing' when @cipherInit@ reports a failure (for instance a key
-- whose size the cipher does not support).
initCipher :: BlockCipher c => C c -> B.ByteString -> Maybe c
initCipher _ = maybeCryptoError . cipherInit
-- | Authenticate and decrypt a JWE payload.
--
-- Returns 'Nothing' in these cases:
--
-- * the key or IV cannot be used with the selected cipher;
-- * the authentication tag has the wrong length or does not verify;
-- * for the CBC-HMAC modes, the ciphertext is not a whole number of
--   blocks or the padding is invalid.
decryptPayload :: Enc        -- ^ Content encryption algorithm
               -> ByteString -- ^ Content encryption key
               -> ByteString -- ^ Initialisation vector
               -> ByteString -- ^ Additional authenticated data
               -> AuthTag    -- ^ Authentication tag taken from the JWE
               -> ByteString -- ^ Ciphertext
               -> Maybe ByteString
decryptPayload enc cek iv aad sig ct = case enc of
    A128GCM       -> doGCM (C :: C AES128)
    A192GCM       -> doGCM (C :: C AES192)
    A256GCM       -> doGCM (C :: C AES256)
    A128CBC_HS256 -> doCBC (C :: C AES128) SHA256 16
    A192CBC_HS384 -> doCBC (C :: C AES192) SHA384 24
    A256CBC_HS512 -> doCBC (C :: C AES256) SHA512 32
  where
    -- CBC-HMAC modes: first half of the CEK is the MAC key, second half the
    -- encryption key.
    (cbcMacKey, cbcEncKey) = B.splitAt (B.length cek `div` 2) cek
    -- Bit length of the AAD, serialised into the MAC input.
    al = fromIntegral (B.length aad) * 8 :: Word64

    doGCM :: BlockCipher c => C c -> Maybe ByteString
    doGCM c = do
        -- aeadSimpleDecrypt computes the expected tag at the length of the
        -- tag it is given. A short or empty tag would therefore weaken or
        -- skip authentication. Only accept the full 16-byte tag that
        -- encryptPayload produces.
        unless (BA.length sig == 16) Nothing
        cipher <- initCipher c cek
        aead   <- maybeCryptoError (aeadInit AEAD_GCM cipher iv)
        aeadSimpleDecrypt aead aad ct (AuthTag $ BA.convert sig)

    doCBC :: (HashAlgorithm a, BlockCipher c) => C c -> a -> Int -> Maybe ByteString
    doCBC c a tagLen = do
        -- Verify the MAC before decrypting anything.
        checkMac a tagLen
        cipher <- initCipher c cbcEncKey
        iv'    <- makeIV iv
        unless (B.length ct `mod` blockSize cipher == 0) Nothing
        unpad $ cbcDecrypt cipher iv' ct

    checkMac :: HashAlgorithm a => a -> Int -> Maybe ()
    checkMac a l = do
        let mac = BA.take l $ BA.convert $ doMac a :: BA.Bytes
        -- constEq is False for differing lengths, so a truncated tag fails.
        unless (sig `BA.constEq` mac) Nothing

    -- MAC input: AAD || IV || ciphertext || AL.
    doMac :: HashAlgorithm a => a -> HMAC a
    doMac _ = hmac cbcMacKey $ B.concat [aad, iv, ct, Serialize.encode al]
-- | Encrypt a JWE payload, returning the authentication tag and ciphertext.
--
-- GCM modes produce a 16-byte tag. The CBC-HMAC modes pad the plaintext,
-- encrypt it with the second half of the CEK and MAC
-- @AAD || IV || ciphertext || AL@ with the first half. The MAC is truncated
-- to the mode's tag length.
--
-- Returns 'Nothing' if the key or IV cannot be used with the selected
-- cipher.
encryptPayload :: Enc        -- ^ Content encryption algorithm
               -> ByteString -- ^ Content encryption key
               -> ByteString -- ^ Initialisation vector
               -> ByteString -- ^ Additional authenticated data
               -> ByteString -- ^ Plaintext
               -> Maybe (AuthTag, ByteString)
encryptPayload e cek iv aad msg = case e of
    A128GCM       -> gcm (C :: C AES128)
    A192GCM       -> gcm (C :: C AES192)
    A256GCM       -> gcm (C :: C AES256)
    A128CBC_HS256 -> cbcHmac (C :: C AES128) SHA256 16
    A192CBC_HS384 -> cbcHmac (C :: C AES192) SHA384 24
    A256CBC_HS512 -> cbcHmac (C :: C AES256) SHA512 32
  where
    (macKey, encKey) = B.splitAt (B.length cek `div` 2) cek
    -- 64-bit serialised bit length of the AAD ("AL").
    aadBits = Serialize.encode (fromIntegral (B.length aad) * 8 :: Word64)

    gcm :: BlockCipher c => C c -> Maybe (AuthTag, ByteString)
    gcm c = seal <$> (initCipher c cek >>= startGcm)
      where
        startGcm cipher = maybeCryptoError (aeadInit AEAD_GCM cipher iv)
        seal aead       = aeadSimpleEncrypt aead aad msg 16

    cbcHmac :: (HashAlgorithm a, BlockCipher c) => C c -> a -> Int -> Maybe (AuthTag, ByteString)
    cbcHmac c h tagLen = build <$> initCipher c encKey <*> makeIV iv
      where
        build cipher iv' =
            let ct  = cbcEncrypt cipher iv' (pad msg)
                tag = BA.take tagLen (BA.convert (macOf h ct))
            in (AuthTag tag, ct)

    macOf :: HashAlgorithm a => a -> ByteString -> HMAC a
    macOf _ ct = hmac macKey (B.concat [aad, iv, ct, aadBits])
-- | Remove PKCS#7 padding (16-byte block size) from decrypted plaintext.
--
-- Succeeds only when all of the following hold:
--
-- * the input is non-empty;
-- * its final byte @n@ satisfies @1 <= n <= 16@;
-- * the input is at least @n@ bytes long;
-- * the last @n@ bytes all equal @n@.
--
-- Otherwise returns 'Nothing'. The input is never passed to a partial
-- function, so malformed data cannot crash the caller.
unpad :: ByteString -> Maybe ByteString
unpad bs
    | B.null bs                                   = Nothing
    | padLen == 0 || padLen > 16 || padLen > len  = Nothing
    | B.any (/= padByte) padding                  = Nothing
    | otherwise                                   = Just pt
  where
    len           = B.length bs
    padByte       = B.last bs
    padLen        = fromIntegral padByte :: Int
    (pt, padding) = B.splitAt (len - padLen) bs
-- | Apply PKCS#7 padding for a 16-byte block size.
--
-- Always appends between 1 and 16 bytes, each equal to the number of bytes
-- appended. The result length is therefore a positive multiple of 16; an
-- input that is already block-aligned gains a whole extra block.
pad :: ByteString -> ByteString
pad bs = B.append bs (B.replicate padLen padByte)
  where
    padLen  = 16 - B.length bs `mod` 16
    padByte = fromIntegral padLen :: Word8
-- | Wrap a content encryption key with AES Key Wrap (RFC 3394) for the JWE
-- @A128KW@, @A192KW@ and @A256KW@ algorithms.
--
-- The content key must be at least 16 bytes and a multiple of 8 bytes. If
-- it is not, or the key-encryption key cannot initialise the selected AES
-- variant, a 'KeyError' is returned. Other algorithms yield 'BadAlgorithm'.
keyWrap :: JweAlg     -- ^ Key wrapping algorithm
        -> ByteString -- ^ Key-encryption key (KEK)
        -> ByteString -- ^ Content encryption key to wrap
        -> Either JwtError ByteString
keyWrap alg kek cek = case alg of
    A128KW -> doKeyWrap (C :: C AES128)
    A192KW -> doKeyWrap (C :: C AES192)
    A256KW -> doKeyWrap (C :: C AES256)
    _      -> Left (BadAlgorithm "Not a keywrap algorithm")
  where
    l  = B.length cek
    -- number of 64-bit plaintext blocks
    n  = l `div` 8
    -- default initial value 0xA6A6A6A6A6A6A6A6
    iv = BA.replicate 8 166 :: ByteString

    doKeyWrap c = do
        when (l < 16 || l `mod` 8 /= 0) (Left (KeyError "Invalid content key"))
        cipher <- maybe (Left (KeyError "cipher initialization failed")) return $ initCipher c kek
        let p       = toBlocks cek
            -- six passes, j = 0..5, each threading A through R[1..n]
            (r0, r) = foldl (doRound (ecbEncrypt cipher) 1) (iv, p) [0 .. 5]
        Right $ B.concat (r0 : r)

    -- One pass with fixed j over the blocks R[i], i counting up from 1.
    doRound _ _ (a, []) _ = (a, [])
    doRound enc i (a, r:rs) j =
        let b    = enc $ B.concat [a, r]
            t    = fromIntegral ((n * j) + i) :: Word64
            a'   = txor t (B.take 8 b)
            r'   = B.drop 8 b
            next = doRound enc (i + 1) (a', rs) j
        in (fst next, r' : snd next)

    -- XOR the full 64-bit big-endian counter t into the 8-byte register.
    -- Truncating t to one byte would give wrong results once n*6 exceeds 255.
    txor t b = B.pack (B.zipWith xor b (Serialize.encode t))
-- | Split a byte string into consecutive 8-byte chunks. The final chunk is
-- shorter if the length is not a multiple of 8; empty input gives @[]@.
toBlocks :: ByteString -> [ByteString]
toBlocks bytes
    | B.null bytes = []
    | otherwise    = chunk : toBlocks rest
  where
    (chunk, rest) = B.splitAt 8 bytes
-- | Unwrap a key wrapped with AES Key Wrap (RFC 3394) for the JWE
-- @A128KW@, @A192KW@ and @A256KW@ algorithms.
--
-- Fails with 'BadCrypto' in these cases:
--
-- * the wrapped key is shorter than 24 bytes or not a multiple of 8 bytes;
-- * the KEK cannot initialise the selected AES variant;
-- * the recovered integrity register does not equal the default IV.
--
-- Other algorithms yield 'BadAlgorithm'.
keyUnwrap :: ByteString -- ^ Key-encryption key (KEK)
          -> JweAlg     -- ^ Key wrapping algorithm
          -> ByteString -- ^ Wrapped key
          -> Either JwtError ByteString
keyUnwrap kek alg encK = case alg of
    A128KW -> doUnWrap (C :: C AES128)
    A192KW -> doUnWrap (C :: C AES192)
    A256KW -> doUnWrap (C :: C AES256)
    _      -> Left (BadAlgorithm "Not a keywrap algorithm")
  where
    l  = B.length encK
    -- number of 64-bit key blocks; the wrapped form carries one extra block (A)
    n  = (l `div` 8) - 1
    -- default initial value 0xA6A6A6A6A6A6A6A6
    iv = BA.replicate 8 166 :: ByteString

    doUnWrap c = do
        when (l < 24 || l `mod` 8 /= 0) (Left BadCrypto)
        cipher   <- maybe (Left BadCrypto) return $ initCipher c kek
        (c0, cs) <- case toBlocks encK of
            b : bs -> Right (b, bs)
            []     -> Left BadCrypto
        -- R[i] are kept in reverse order so each pass walks i from n down to 1;
        -- passes run for j = 5 down to 0.
        let (p0, p) = foldl (doRound (ecbDecrypt cipher) n) (c0, reverse cs) (reverse [0 .. 5])
        unless (p0 `BA.constEq` iv) (Left BadCrypto)
        Right $ B.concat (reverse p)

    -- One pass with fixed j over R[i], i counting down from n.
    doRound _ _ (a, []) _ = (a, [])
    doRound dec i (a, r:rs) j =
        let t    = fromIntegral ((n * j) + i) :: Word64
            b    = dec $ B.concat [txor t a, r]
            a'   = B.take 8 b
            r'   = B.drop 8 b
            next = doRound dec (i - 1) (a', rs) j
        in (fst next, r' : snd next)

    -- XOR the full 64-bit big-endian counter t into the 8-byte register.
    txor t b = B.pack (B.zipWith xor b (Serialize.encode t))