{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_HADDOCK prune #-}

-- | Internal functions for encrypting and signing / decrypting
-- and verifying JWT content.

module Jose.Internal.Crypto
    ( hmacSign
    , hmacVerify
    , ed25519Verify
    , ed448Verify
    , rsaSign
    , rsaVerify
    , rsaEncrypt
    , rsaDecrypt
    , ecVerify
    , encryptPayload
    , decryptPayload
    , generateCmkAndIV
    , keyWrap
    , keyUnwrap
    , pad
    , unpad
    )
where


import           Control.Applicative
import           Control.Monad (when, unless)
import           Crypto.Error
import           Crypto.Cipher.AES
import           Crypto.Cipher.Types hiding (IV)
import           Crypto.Hash.Algorithms
import           Crypto.Number.Serialize (os2ip)
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import qualified Crypto.PubKey.Ed25519 as Ed25519
import qualified Crypto.PubKey.Ed448 as Ed448
import qualified Crypto.PubKey.RSA as RSA
import qualified Crypto.PubKey.RSA.PKCS15 as PKCS15
import qualified Crypto.PubKey.RSA.OAEP as OAEP
import           Crypto.Random (MonadRandom, getRandomBytes)
import           Crypto.MAC.HMAC (HMAC (..), hmac)
import           Data.Bits (xor)
import           Data.Bifunctor (first)
import           Data.ByteArray (ByteArray, ScrubbedBytes)
import qualified Data.ByteArray as BA
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.Serialize as Serialize
import qualified Data.Text as T
import           Data.Word (Word64, Word8)

import           Jose.Jwa
import           Jose.Types (JwtError(..))
import           Jose.Internal.Parser (IV(..), Tag(..))

-- | Keep the successful result of an 'Either', discarding any error value.
--
-- @Right x@ becomes @Just x@; any @Left@ becomes @Nothing@.
rightToMaybe :: Either a b -> Maybe b
rightToMaybe = either (const Nothing) Just

-- | Sign a message with an HMAC key.
--
-- Only @HS256@, @HS384@ and @HS512@ are accepted; any other algorithm
-- yields a 'BadAlgorithm' error naming the offending value.
hmacSign :: JwsAlg      -- ^ HMAC algorithm to use
         -> ByteString  -- ^ Key
         -> ByteString  -- ^ The message/content
         -> Either JwtError ByteString -- ^ HMAC output
hmacSign a k m = case a of
    HS256 -> Right $ BA.convert (hmac k m :: HMAC SHA256)
    HS384 -> Right $ BA.convert (hmac k m :: HMAC SHA384)
    HS512 -> Right $ BA.convert (hmac k m :: HMAC SHA512)
    _     -> Left $ BadAlgorithm $ T.pack $ "Not an HMAC algorithm: " ++ show a

-- | Verify the HMAC for a given message.
--
-- The MAC is recomputed with 'hmacSign' and compared to the supplied
-- signature using 'BA.constEq'.
-- Returns false if the MAC is incorrect or the 'Alg' is not an HMAC.
hmacVerify :: JwsAlg      -- ^ HMAC Algorithm to use
           -> ByteString  -- ^ Key
           -> ByteString  -- ^ The message/content
           -> ByteString  -- ^ The signature to check
           -> Bool        -- ^ Whether the signature is correct
hmacVerify a key msg sig = case hmacSign a key msg of
    Right mac -> mac `BA.constEq` sig
    Left _    -> False  -- not an HMAC algorithm


-- | Verify an Ed25519 signed message.
--
-- Returns false if the algorithm is not @EdDSA@, if the signature bytes
-- cannot be parsed as an Ed25519 signature, or if verification fails.
ed25519Verify :: JwsAlg
              -> Ed25519.PublicKey
              -> ByteString
              -- ^ The message/content
              -> ByteString
              -- ^ The signature to check
              -> Bool
              -- ^ Whether the signature is correct
ed25519Verify EdDSA pubKey msg sig =
    case Ed25519.signature sig of
        CryptoPassed parsedSig -> Ed25519.verify pubKey msg parsedSig
        CryptoFailed _         -> False
ed25519Verify _ _ _ _ = False


-- | Verify an Ed448 signed message.
--
-- Returns false if the algorithm is not @EdDSA@, if the signature bytes
-- cannot be parsed as an Ed448 signature, or if verification fails.
ed448Verify :: JwsAlg
            -> Ed448.PublicKey
            -> ByteString
            -- ^ The message/content
            -> ByteString
            -- ^ The signature to check
            -> Bool
            -- ^ Whether the signature is correct
ed448Verify EdDSA pubKey msg sig =
    case Ed448.signature sig of
        CryptoPassed parsedSig -> Ed448.verify pubKey msg parsedSig
        CryptoFailed _         -> False
ed448Verify _ _ _ _ = False


-- | Sign a message using an RSA private key (PKCS#1 v1.5 signature).
--
-- The failure condition should only occur if the algorithm is not an RSA
-- algorithm ('BadAlgorithm'), or the RSA key is too small, causing the
-- padding of the signature to fail ('BadCrypto'). With real-world RSA keys
-- this shouldn't happen in practice.
rsaSign :: Maybe RSA.Blinder  -- ^ RSA blinder
        -> JwsAlg             -- ^ Algorithm to use. Must be one of @RS256@, @RS384@ or @RS512@
        -> RSA.PrivateKey     -- ^ Private key to sign with
        -> ByteString         -- ^ Message to sign
        -> Either JwtError ByteString    -- ^ The signature
rsaSign blinder a key msg = case a of
    RS256 -> go SHA256
    RS384 -> go SHA384
    RS512 -> go SHA512
    _     -> Left . BadAlgorithm . T.pack $ "Not an RSA algorithm: " ++ show a
  where
    -- Any error from the underlying signing primitive is reported as 'BadCrypto'.
    go h = first (const BadCrypto) $ PKCS15.sign blinder (Just h) key msg

-- | Verify the signature for a message using an RSA public key.
--
-- Returns false if the check fails or if the 'Alg' value is not
-- an RSA signature algorithm (@RS256@, @RS384@ or @RS512@).
rsaVerify :: JwsAlg        -- ^ The signature algorithm. Used to obtain the hash function.
          -> RSA.PublicKey -- ^ The key to check the signature with
          -> ByteString    -- ^ The message/content
          -> ByteString    -- ^ The signature to check
          -> Bool          -- ^ Whether the signature is correct
rsaVerify a key msg sig = case a of
    RS256 -> go SHA256
    RS384 -> go SHA384
    RS512 -> go SHA512
    _     -> False
  where
    go h = PKCS15.verify (Just h) key msg sig

-- | Verify the signature for a message using an EC public key.
--
-- The signature must be the fixed-length concatenation @R || S@ described
-- in RFC 7518 section 3.4: 64 octets for @ES256@, 96 for @ES384@ and 132
-- for @ES512@. A signature of any other length is rejected without being
-- parsed.
--
-- Returns false if the check fails or if the 'Alg' value is not
-- an EC signature algorithm.
ecVerify :: JwsAlg          -- ^ The signature algorithm. Used to obtain the hash function.
         -> ECDSA.PublicKey -- ^ The key to check the signature with
         -> ByteString      -- ^ The message/content
         -> ByteString      -- ^ The signature to check
         -> Bool            -- ^ Whether the signature is correct
ecVerify a key msg sig = case a of
    ES256 -> go SHA256 64
    ES384 -> go SHA384 96
    ES512 -> go SHA512 132
    _     -> False
  where
    -- R is the first half of the signature and S the second half.
    (r, s) = B.splitAt (B.length sig `div` 2) sig
    ecSig  = ECDSA.Signature (os2ip r) (os2ip s)

    -- Check the length first so that odd-length, empty or truncated
    -- signatures are never split and interpreted as integers.
    go :: HashAlgorithm hash => hash -> Int -> Bool
    go h sigLen = B.length sig == sigLen && ECDSA.verify h key ecSig msg

-- | Generates the symmetric key (content management key) and IV
--
-- Used to encrypt a message. Both values are drawn from the 'MonadRandom'
-- instance; their sizes depend on the encryption algorithm: the GCM
-- algorithms use a 16, 24 or 32 byte key with a 12 byte IV, while the
-- CBC-HMAC algorithms use a 32, 48 or 64 byte key with a 16 byte IV.
generateCmkAndIV :: MonadRandom m
    => Enc
    -- ^ The encryption algorithm to be used
    -> m (ScrubbedBytes, ScrubbedBytes)
    -- ^ The key, IV
generateCmkAndIV e = do
    cmk <- getRandomBytes (keySize e)
    iv  <- getRandomBytes (ivSize e)   -- iv for aes gcm or cbc
    return (cmk, iv)
  where
    -- Key length in bytes for each content encryption algorithm.
    keySize :: Enc -> Int
    keySize A128GCM       = 16
    keySize A192GCM       = 24
    keySize A256GCM       = 32
    keySize A128CBC_HS256 = 32
    keySize A192CBC_HS384 = 48
    keySize A256CBC_HS512 = 64

    -- IV length in bytes: 12 for the GCM algorithms, 16 for everything else.
    ivSize :: Enc -> Int
    ivSize A128GCM = 12
    ivSize A192GCM = 12
    ivSize A256GCM = 12
    ivSize _       = 16

-- | Encrypts a message (typically a symmetric key) using RSA.
--
-- @RSA_OAEP@ uses the default OAEP parameters with SHA-1 and
-- @RSA_OAEP_256@ the default OAEP parameters with SHA-256. Any failure of
-- the underlying primitive is reported as 'BadCrypto'; a non-RSA algorithm
-- yields 'BadAlgorithm'.
rsaEncrypt :: (MonadRandom m, ByteArray msg, ByteArray out)
    => RSA.PublicKey
    -- ^ The encryption key
    -> JweAlg
    -- ^ The algorithm (@RSA1_5@, @RSA_OAEP@, or @RSA_OAEP_256@)
    -> msg
    -- ^ The message to encrypt
    -> m (Either JwtError out)
    -- ^ The encrypted message
rsaEncrypt k a msg = fmap BA.convert <$> case a of
    RSA1_5       -> mapErr (PKCS15.encrypt k bs)
    RSA_OAEP     -> mapErr (OAEP.encrypt (OAEP.defaultOAEPParams SHA1) k bs)
    RSA_OAEP_256 -> mapErr (OAEP.encrypt (OAEP.defaultOAEPParams SHA256) k bs)
    _            -> return (Left (BadAlgorithm "Not an RSA algorithm"))
  where
    -- The RSA primitives work on 'ByteString', so convert the input once.
    bs :: ByteString
    bs = BA.convert msg
    -- Replace whatever error the primitive reported with 'BadCrypto'.
    mapErr = fmap (first (const BadCrypto))

-- | Decrypts an RSA encrypted message.
--
-- @RSA_OAEP@ uses the default OAEP parameters with SHA-1 and
-- @RSA_OAEP_256@ the default OAEP parameters with SHA-256. Any failure of
-- the underlying primitive is reported as 'BadCrypto'; a non-RSA algorithm
-- yields 'BadAlgorithm'.
rsaDecrypt :: ByteArray ct
    => Maybe RSA.Blinder
    -> RSA.PrivateKey
    -- ^ The decryption key
    -> JweAlg
    -- ^ The RSA algorithm to use
    -> ct
    -- ^ The encrypted content
    -> Either JwtError ScrubbedBytes
    -- ^ The decrypted key
rsaDecrypt blinder rsaKey a ct = BA.convert <$> case a of
    RSA1_5       -> mapErr (PKCS15.decrypt blinder rsaKey bs)
    RSA_OAEP     -> mapErr (OAEP.decrypt blinder (OAEP.defaultOAEPParams SHA1) rsaKey bs)
    RSA_OAEP_256 -> mapErr (OAEP.decrypt blinder (OAEP.defaultOAEPParams SHA256) rsaKey bs)
    _            -> Left (BadAlgorithm "Not an RSA algorithm")
  where
    -- The RSA primitives work on 'ByteString', so convert the input once.
    bs :: ByteString
    bs = BA.convert ct
    -- Replace whatever error the primitive reported with 'BadCrypto'.
    mapErr = first (const BadCrypto)

-- Dummy type to constrain Cipher type.
-- It carries no data: its type parameter alone tells 'initCipher' which
-- cipher to initialise, e.g. @C :: C AES128@.
data C c = C

-- | Initialise the block cipher selected by the 'C' proxy with the given key.
--
-- Initialisation failures are translated by 'mapFail'.
initCipher :: BlockCipher c => C c -> ScrubbedBytes -> Either JwtError c
initCipher _ k = mapFail (cipherInit k)

-- | Map 'CryptoFailable' to 'JwtError'.
--
-- An invalid key size is reported as a 'KeyError'; every other failure
-- collapses to 'BadCrypto'.
mapFail :: CryptoFailable a -> Either JwtError a
mapFail (CryptoPassed a)                          = Right a
mapFail (CryptoFailed CryptoError_KeySizeInvalid) = Left (KeyError "cipher key length is invalid")
mapFail (CryptoFailed _)                          = Left BadCrypto


-- | Decrypt an AES encrypted message.
--
-- The GCM algorithms require a 12 byte IV and a 16 byte tag. The CBC-HMAC
-- algorithms require a 16 byte IV and a 16, 24 or 32 byte tag for
-- @A128CBC_HS256@, @A192CBC_HS384@ and @A256CBC_HS512@ respectively.
--
-- Returns 'Nothing' if the algorithm, IV and tag sizes do not match up, if
-- the cipher cannot be initialised with the key, if the authentication
-- tag does not verify, or (CBC only) if the ciphertext length is not a
-- multiple of the block size or the padding cannot be removed.
decryptPayload :: forall ba. (ByteArray ba)
    => Enc
    -- ^ Encryption algorithm
    -> ScrubbedBytes
    -- ^ Content encryption key
    -> IV
    -- ^ IV
    -> ba
    -- ^ Additional authentication data
    -> Tag
    -- ^ The integrity protection value to be checked
    -> ba
    -- ^ The encrypted JWT payload
    -> Maybe ba
decryptPayload enc cek iv_ aad tag_ ct = case (enc, iv_, tag_) of
    (A128GCM, IV12 ivBytes, Tag16 tagBytes) -> doGCM (C :: C AES128) ivBytes tagBytes
    (A192GCM, IV12 ivBytes, Tag16 tagBytes) -> doGCM (C :: C AES192) ivBytes tagBytes
    (A256GCM, IV12 ivBytes, Tag16 tagBytes) -> doGCM (C :: C AES256) ivBytes tagBytes
    (A128CBC_HS256, IV16 ivBytes, Tag16 tagBytes) -> doCBC (C :: C AES128) ivBytes tagBytes SHA256 16
    (A192CBC_HS384, IV16 ivBytes, Tag24 tagBytes) -> doCBC (C :: C AES192) ivBytes tagBytes SHA384 24
    (A256CBC_HS512, IV16 ivBytes, Tag32 tagBytes) -> doCBC (C :: C AES256) ivBytes tagBytes SHA512 32
    _ -> Nothing -- This shouldn't be possible if the JWT was parsed first
  where
    -- For the CBC-HMAC algorithms the key is split in half: the first half
    -- is the MAC key and the second half the encryption key.
    (cbcMacKey, cbcEncKey) = BA.splitAt (BA.length cek `div` 2) cek :: (ScrubbedBytes, ScrubbedBytes)
    -- Length of the additional authenticated data in bits.
    al = fromIntegral (BA.length aad) * 8 :: Word64

    -- AES-GCM: authenticated decryption over the whole key.
    doGCM :: BlockCipher c => C c -> ByteString -> ByteString -> Maybe ba
    doGCM c iv tag = do
        cipher <- rightToMaybe (initCipher c cek)
        aead   <- maybeCryptoError (aeadInit AEAD_GCM cipher iv)
        aeadSimpleDecrypt aead aad ct (AuthTag $ BA.convert tag)

    -- AES-CBC with HMAC: the MAC is checked before anything is decrypted.
    doCBC :: (HashAlgorithm a, BlockCipher c) => C c -> ByteString -> ByteString -> a -> Int -> Maybe ba
    doCBC c iv tag a tagLen = do
        checkMac a tag iv tagLen
        cipher <- rightToMaybe (initCipher c cbcEncKey)
        iv'    <- makeIV iv
        -- Reject ciphertext that is not a whole number of blocks.
        unless (BA.length ct `mod` blockSize cipher == 0) Nothing
        unpad $ cbcDecrypt cipher iv' ct

    -- Compare the supplied tag with the first @l@ bytes of the computed HMAC.
    checkMac :: HashAlgorithm a => a -> ByteString -> ByteString -> Int -> Maybe ()
    checkMac a tag iv l = do
        let mac = BA.take l $ BA.convert $ doMac a iv :: BA.Bytes
        unless (tag `BA.constEq` mac) Nothing

    -- HMAC over AAD, IV, ciphertext and the encoded AAD bit length, in that
    -- order. The first argument only selects the hash algorithm.
    doMac :: HashAlgorithm a => a -> ByteString -> HMAC a
    doMac _ iv = hmac cbcMacKey (BA.concat [BA.convert aad, iv, BA.convert ct, Serialize.encode al] :: ByteString)

-- | Encrypt a message using AES.
encryptPayload :: forall ba iv. (ByteArray ba, ByteArray iv)
    => Enc
    -- ^ Encryption algorithm
    -> ScrubbedBytes
    -- ^ Content management key
    -> iv
    -- ^ IV
    -> ba
    -- ^ Additional authenticated data
    -> ba
    -- ^ The message/JWT claims
    -> Maybe (AuthTag, ba)
    -- ^ Ciphertext claims and signature tag
encryptPayload :: Enc -> ScrubbedBytes -> iv -> ba -> ba -> Maybe (AuthTag, ba)
encryptPayload Enc
e ScrubbedBytes
cek iv
iv ba
aad ba
msg = case Enc
e of
    Enc
A128GCM       -> C AES128 -> Maybe (AuthTag, ba)
forall a. BlockCipher a => C a -> Maybe (AuthTag, ba)
doGCM (C AES128
forall c. C c
C :: C AES128)
    Enc
A192GCM       -> C AES192 -> Maybe (AuthTag, ba)
forall a. BlockCipher a => C a -> Maybe (AuthTag, ba)
doGCM (C AES192
forall c. C c
C :: C AES192)
    Enc
A256GCM       -> C AES256 -> Maybe (AuthTag, ba)
forall a. BlockCipher a => C a -> Maybe (AuthTag, ba)
doGCM (C AES256
forall c. C c
C :: C AES256)
    Enc
A128CBC_HS256 -> C AES128 -> SHA256 -> Int -> Maybe (AuthTag, ba)
forall a c.
(HashAlgorithm a, BlockCipher c) =>
C c -> a -> Int -> Maybe (AuthTag, ba)
doCBC (C AES128
forall c. C c
C :: C AES128) SHA256
SHA256 Int
16
    Enc
A192CBC_HS384 -> C AES192 -> SHA384 -> Int -> Maybe (AuthTag, ba)
forall a c.
(HashAlgorithm a, BlockCipher c) =>
C c -> a -> Int -> Maybe (AuthTag, ba)
doCBC (C AES192
forall c. C c
C :: C AES192) SHA384
SHA384 Int
24
    Enc
A256CBC_HS512 -> C AES256 -> SHA512 -> Int -> Maybe (AuthTag, ba)
forall a c.
(HashAlgorithm a, BlockCipher c) =>
C c -> a -> Int -> Maybe (AuthTag, ba)
doCBC (C AES256
forall c. C c
C :: C AES256) SHA512
SHA512 Int
32
  where
    (ScrubbedBytes
cbcMacKey, ScrubbedBytes
cbcEncKey) = Int -> ScrubbedBytes -> (ScrubbedBytes, ScrubbedBytes)
forall bs. ByteArray bs => Int -> bs -> (bs, bs)
BA.splitAt (ScrubbedBytes -> Int
forall ba. ByteArrayAccess ba => ba -> Int
BA.length ScrubbedBytes
cek Int -> Int -> Int
forall a. Integral a => a -> a -> a
`div` Int
2) ScrubbedBytes
cek :: (ScrubbedBytes, ScrubbedBytes)
    al :: Word64
al = Int -> Word64
forall a b. (Integral a, Num b) => a -> b
fromIntegral (ba -> Int
forall ba. ByteArrayAccess ba => ba -> Int
BA.length ba
aad) Word64 -> Word64 -> Word64
forall a. Num a => a -> a -> a
* Word64
8 :: Word64

    doGCM :: C a -> Maybe (AuthTag, ba)
doGCM C a
c = do
        a
cipher <- Either JwtError a -> Maybe a
forall a b. Either a b -> Maybe b
rightToMaybe (C a -> ScrubbedBytes -> Either JwtError a
forall c.
BlockCipher c =>
C c -> ScrubbedBytes -> Either JwtError c
initCipher C a
c ScrubbedBytes
cek)
        AEAD a
aead <- CryptoFailable (AEAD a) -> Maybe (AEAD a)
forall a. CryptoFailable a -> Maybe a
maybeCryptoError (AEADMode -> a -> iv -> CryptoFailable (AEAD a)
forall cipher iv.
(BlockCipher cipher, ByteArrayAccess iv) =>
AEADMode -> cipher -> iv -> CryptoFailable (AEAD cipher)
aeadInit AEADMode
AEAD_GCM a
cipher iv
iv)
        (AuthTag, ba) -> Maybe (AuthTag, ba)
forall (m :: * -> *) a. Monad m => a -> m a
return ((AuthTag, ba) -> Maybe (AuthTag, ba))
-> (AuthTag, ba) -> Maybe (AuthTag, ba)
forall a b. (a -> b) -> a -> b
$ AEAD a -> ba -> ba -> Int -> (AuthTag, ba)
forall aad ba a.
(ByteArrayAccess aad, ByteArray ba) =>
AEAD a -> aad -> ba -> Int -> (AuthTag, ba)
aeadSimpleEncrypt AEAD a
aead ba
aad ba
msg Int
16

    doCBC :: (HashAlgorithm a, BlockCipher c) => C c -> a -> Int -> Maybe (AuthTag, ba)
    doCBC :: C c -> a -> Int -> Maybe (AuthTag, ba)
doCBC C c
c a
a Int
tagLen = do
        c
cipher <- Either JwtError c -> Maybe c
forall a b. Either a b -> Maybe b
rightToMaybe (C c -> ScrubbedBytes -> Either JwtError c
forall c.
BlockCipher c =>
C c -> ScrubbedBytes -> Either JwtError c
initCipher C c
c ScrubbedBytes
cbcEncKey)
        IV c
iv'    <- iv -> Maybe (IV c)
forall b c. (ByteArrayAccess b, BlockCipher c) => b -> Maybe (IV c)
makeIV iv
iv
        let ct :: ba
ct = c -> IV c -> ba -> ba
forall cipher ba.
(BlockCipher cipher, ByteArray ba) =>
cipher -> IV cipher -> ba -> ba
cbcEncrypt c
cipher IV c
iv' (ba -> ba
forall ba. ByteArray ba => ba -> ba
pad ba
msg)
            mac :: HMAC a
mac = a -> ba -> HMAC a
forall a. HashAlgorithm a => a -> ba -> HMAC a
doMac a
a ba
ct
            tag :: Bytes
tag = Int -> Bytes -> Bytes
forall bs. ByteArray bs => Int -> bs -> bs
BA.take Int
tagLen (HMAC a -> Bytes
forall bin bout.
(ByteArrayAccess bin, ByteArray bout) =>
bin -> bout
BA.convert HMAC a
mac)
        (AuthTag, ba) -> Maybe (AuthTag, ba)
forall (m :: * -> *) a. Monad m => a -> m a
return (Bytes -> AuthTag
AuthTag Bytes
tag, ba
ct)

    doMac :: HashAlgorithm a => a -> ba -> HMAC a
    doMac :: a -> ba -> HMAC a
doMac a
_ ba
ct = ScrubbedBytes -> ByteString -> HMAC a
forall key message a.
(ByteArrayAccess key, ByteArrayAccess message, HashAlgorithm a) =>
key -> message -> HMAC a
hmac ScrubbedBytes
cbcMacKey ([ByteString] -> ByteString
forall bin bout.
(ByteArrayAccess bin, ByteArray bout) =>
[bin] -> bout
BA.concat [ba -> ByteString
forall bin bout.
(ByteArrayAccess bin, ByteArray bout) =>
bin -> bout
BA.convert ba
aad, iv -> ByteString
forall bin bout.
(ByteArrayAccess bin, ByteArray bout) =>
bin -> bout
BA.convert iv
iv, ba -> ByteString
forall bin bout.
(ByteArrayAccess bin, ByteArray bout) =>
bin -> bout
BA.convert ba
ct, Word64 -> ByteString
forall a. Serialize a => a -> ByteString
Serialize.encode Word64
al] :: ByteString)

-- | Strip PKCS#7 padding (16-byte block size) from a decrypted message.
--
-- The last byte of the input gives the number of padding bytes. The
-- padding is accepted only when that count is between 1 and 16, the input
-- holds at least that many bytes, and every padding byte carries the same
-- value. Returns the message without its padding, or 'Nothing' when the
-- input is empty or the padding is malformed.
unpad :: (ByteArray ba) => ba -> Maybe ba
unpad bs
    -- Nothing to inspect: reading the "last" byte of an empty array would
    -- index position -1, and 'BA.index' performs no bounds check.
    | BA.null bs = Nothing
    -- A pad length of zero is never produced by 'pad' (it always appends
    -- 1..16 bytes), so it is rejected together with lengths above 16. When
    -- padLen exceeds the input length the split below does not yield
    -- exactly padLen bytes, which the length comparison catches.
    | padLen < 1 || padLen > 16 || padLen /= BA.length padding = Nothing
    | BA.any (/= padByte) padding = Nothing
    | otherwise = return pt
  where
    len     = BA.length bs
    padByte = BA.index bs (len - 1)
    padLen  = fromIntegral padByte :: Int
    (pt, padding) = BA.splitAt (len - padLen) bs

-- | Append PKCS#7-style padding for a 16-byte block size.
--
-- Between 1 and 16 bytes are always added (a full block when the input
-- length is already a multiple of 16), and each added byte holds the
-- number of bytes added.
pad :: (ByteArray ba) => ba -> ba
pad bs = bs `BA.append` BA.replicate padLen (fromIntegral padLen)
  where
    -- Distance to the next 16-byte boundary, in the range 1..16.
    padLen :: Int
    padLen = 16 - (BA.length bs `mod` 16)


-- Key wrapping and unwrapping functions

-- | AES key wrap, see <https://tools.ietf.org/html/rfc3394#section-2.2.1>.
--
-- Wraps the content key @cek@ under the key-encryption key @kek@ using
-- the AES variant selected by @alg@. Fails with 'BadAlgorithm' when @alg@
-- is not one of 'A128KW', 'A192KW' or 'A256KW', with 'KeyError' when the
-- content key is shorter than 16 bytes or not a multiple of 8 bytes, and
-- with whatever error 'initCipher' reports for the supplied @kek@.
--
-- The output is the final 8-byte register followed by the transformed
-- 8-byte blocks of the content key.
keyWrap :: ByteArray ba => JweAlg -> ScrubbedBytes -> ScrubbedBytes -> Either JwtError ba
keyWrap alg kek cek = case alg of
    A128KW -> wrapWith (C :: C AES128)
    A192KW -> wrapWith (C :: C AES192)
    A256KW -> wrapWith (C :: C AES256)
    _      -> Left (BadAlgorithm "Not a keywrap algorithm")
  where
    cekLen     = BA.length cek
    -- Number of 8-byte blocks in the content key.
    blockCount = cekLen `div` 8
    -- Initial register value: eight bytes of 166 (0xA6).
    initialValue = BA.replicate 8 166 :: ByteString

    -- Validate the content key, set up the cipher and run six passes
    -- (j = 0..5) over the blocks.
    wrapWith c
        | cekLen < 16 || cekLen `mod` 8 /= 0 = Left (KeyError "Invalid content key")
        | otherwise = do
            cipher <- initCipher c kek
            let (a, rs) = foldl (wrapPass (ecbEncrypt cipher) 1)
                                (BA.convert initialValue, toBlocks cek)
                                [0..5]
            Right $ BA.concat (a : rs)

    -- One pass over the remaining blocks: encrypt register ++ block, keep
    -- the low 8 bytes as the new block and the high 8 bytes, mixed with the
    -- counter, as the register for the next block.
    --
    -- NOTE(review): the counter is narrowed to 'Word8' and 'txor' only
    -- touches the final byte of the register, so this follows the RFC's
    -- wider counter only while 6 * blockCount stays below 256.
    wrapPass _   _ (a, [])     _ = (a, [])
    wrapPass enc i (a, r : rs) j = (aFinal, BA.drop 8 b : rest)
      where
        b = enc (BA.concat [a, r])
        t = fromIntegral (blockCount * j + i) :: Word8
        (aFinal, rest) = wrapPass enc (i + 1) (txor t (BA.take 8 b), rs) j

-- | XOR the byte @t@ into the last byte of @b@, leaving every other byte
-- untouched.
--
-- The input must be non-empty: the last byte is read with 'BA.index' at
-- position @length - 1@.
txor :: ByteArray ba => Word8 -> ba -> ba
txor t b = BA.snoc (BA.take lastIx b) (BA.index b lastIx `xor` t)
  where
    -- Position of the final byte, which is also the length of the prefix
    -- that is kept as-is.
    lastIx = BA.length b - 1

-- | Split a byte array into consecutive 8-byte chunks, in order.
--
-- An empty input gives an empty list. If the length is not a multiple of
-- 8, the last chunk holds whatever bytes remain.
toBlocks :: ByteArray ba => ba -> [ba]
toBlocks bytes
    | BA.null bytes = []
    | otherwise     = firstBlock : toBlocks remainder
  where
    (firstBlock, remainder) = BA.splitAt 8 bytes

-- | AES key unwrap, the inverse of 'keyWrap'.
--
-- Note the argument order: the key-encryption key @kek@ comes first, then
-- the algorithm, then the wrapped key @encK@. Fails with 'BadAlgorithm'
-- when @alg@ is not one of 'A128KW', 'A192KW' or 'A256KW', and with
-- 'BadCrypto' when the wrapped key is shorter than 24 bytes, is not a
-- multiple of 8 bytes, or the recovered register does not equal the
-- expected initial value. Errors from 'initCipher' are passed through.
keyUnwrap :: ByteArray ba => ScrubbedBytes -> JweAlg -> ba -> Either JwtError ScrubbedBytes
keyUnwrap kek alg encK = case alg of
    A128KW -> unwrapWith (C :: C AES128)
    A192KW -> unwrapWith (C :: C AES192)
    A256KW -> unwrapWith (C :: C AES256)
    _      -> Left (BadAlgorithm "Not a keywrap algorithm")
  where
    encLen     = BA.length encK
    -- Number of 8-byte key blocks; the first block of the input is the
    -- register, hence the minus one.
    blockCount = (encLen `div` 8) - 1
    -- Register value a correctly unwrapped key must end with: eight bytes
    -- of 166 (0xA6).
    expectedIV = BA.replicate 8 166

    -- Validate the input length, set up the cipher and run the six passes
    -- in reverse order (j = 5..0) over the blocks in reverse order.
    unwrapWith c
        | encLen < 24 || encLen `mod` 8 /= 0 = Left BadCrypto
        | otherwise = do
            cipher <- initCipher c kek
            case toBlocks encK of
                -- Not reachable: the length check above guarantees at
                -- least three blocks.
                []        -> Left BadCrypto
                (a0 : rs) -> do
                    let (a, ps) = foldl (unwrapPass (ecbDecrypt cipher) blockCount)
                                        (a0, reverse rs)
                                        (reverse [0..5])
                    if a == expectedIV
                        then Right $ BA.concat (reverse ps)
                        else Left BadCrypto

    -- One pass over the remaining blocks: mix the counter into the
    -- register, decrypt register ++ block, then keep the high 8 bytes as
    -- the new register and the low 8 bytes as the recovered block.
    --
    -- NOTE(review): the counter is narrowed to 'Word8' and 'txor' only
    -- touches the final byte of the register, mirroring 'keyWrap'.
    unwrapPass _   _ (a, [])     _ = (a, [])
    unwrapPass dec i (a, r : rs) j = (aFinal, BA.drop 8 b : rest)
      where
        t = fromIntegral (blockCount * j + i) :: Word8
        b = dec (BA.concat [txor t a, r])
        (aFinal, rest) = unwrapPass dec (i - 1) (BA.take 8 b, rs) j