{-# LANGUAGE OverloadedStrings #-}
module Jose.Jwe
( jwkEncode
, jwkDecode
, rsaEncode
, rsaDecode
)
where
import Control.Monad.Trans (lift)
import Control.Monad.Trans.Except
import Crypto.Cipher.Types (AuthTag(..))
import Crypto.PubKey.RSA (PrivateKey(..), PublicKey(..), generateBlinder, private_pub)
import Crypto.Random (MonadRandom)
import Data.ByteArray (ByteArray, ScrubbedBytes)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Jose.Types
import qualified Jose.Internal.Base64 as B64
import Jose.Internal.Crypto
import Jose.Jwa
import Jose.Jwk
import qualified Jose.Internal.Parser as P
-- | Encrypt a payload as a compact-serialized JWE, using a JWK to protect the
-- randomly generated content encryption key (CEK).
--
-- * 'RsaPublicJwk' encrypts the CEK with RSA under that public key.
-- * 'RsaPrivateJwk' does the same using the public half of the private key.
-- * 'SymmetricJwk' wraps the CEK with 'keyWrap' under the symmetric key.
-- * Any other JWK fails with 'KeyError'.
--
-- The JWK's key id is copied into the header's @kid@. A 'Nested' payload sets
-- the header's @cty@ to @"JWT"@; 'Claims' leaves it unset.
jwkEncode :: MonadRandom m
    => JweAlg   -- ^ Algorithm used to encrypt the CEK
    -> Enc      -- ^ Content encryption algorithm
    -> Jwk      -- ^ Key that protects the CEK
    -> Payload  -- ^ Claims or nested JWT to encrypt
    -> m (Either JwtError Jwt)
jwkEncode a e jwk payload = runExceptT $
    case jwk of
        RsaPublicJwk pub kid _ _ -> encodeFor kid (viaRsa pub)
        RsaPrivateJwk priv kid _ _ -> encodeFor kid (viaRsa (private_pub priv))
        SymmetricJwk kek kid _ _ -> encodeFor kid (viaKeyWrap kek)
        _ -> throwE (KeyError "JWK cannot encode a JWE")
  where
    -- Run the shared JWE encoder with a header carrying this JWK's key id.
    encodeFor kid encryptCek = doEncode (jweHeader kid) encryptCek content

    -- RSA-encrypt the CEK under a public key.
    viaRsa pub cek = ExceptT (rsaEncrypt pub a cek)

    -- Wrap the CEK under a symmetric key-encryption key.
    viaKeyWrap kek cek = ExceptT (return (keyWrap a (BA.convert kek) cek))

    jweHeader kid = defJweHdr
        { jweAlg = a
        , jweEnc = e
        , jweKid = kid
        , jweCty = cty
        }

    -- Content type header and raw bytes to encrypt.
    (cty, content) = case payload of
        Claims claims -> (Nothing, claims)
        Nested (Jwt nested) -> (Just "JWT", nested)
-- | Decrypt a compact-serialized JWE with a JWK.
--
-- * 'RsaPrivateJwk' decrypts the CEK with RSA. A fresh blinder is generated
--   from the key's public modulus and used for the RSA operation.
-- * 'SymmetricJwk' unwraps the CEK with 'keyUnwrap'.
-- * Any other JWK fails with 'KeyError' without touching the input.
--
-- On success the decrypted header and payload are wrapped in 'Jwe'.
jwkDecode :: MonadRandom m
    => Jwk         -- ^ Key used to recover the content encryption key
    -> ByteString  -- ^ Compact-serialized JWE
    -> m (Either JwtError JwtContent)
jwkDecode jwk jwt = runExceptT (Jwe <$> decodeWith jwk)
  where
    decodeWith (RsaPrivateJwk kPr _ _ _) = do
        blinder <- lift (generateBlinder (public_n (private_pub kPr)))
        doDecode (rsaDecrypt (Just blinder) kPr) jwt
    decodeWith (SymmetricJwk kb _ _ _) =
        doDecode (keyUnwrap (BA.convert kb)) jwt
    decodeWith _ =
        throwE (KeyError "JWK cannot decode a JWE")
-- | Parse a compact-serialized JWE and decrypt its payload.
--
-- The first argument decrypts the encrypted content encryption key (CEK).
-- It receives the @alg@ taken from the JWE header and the encrypted key bytes.
--
-- Errors:
--
-- * Whatever 'P.parseJwt' reports if the input does not parse.
-- * 'BadHeader' if the input parses as something other than a JWE.
-- * 'BadCrypto' if 'decryptPayload' rejects the payload.
--
-- An error from the CEK decryption function is never returned directly
-- (see the comments in the body). On success the parsed header and the
-- decrypted payload bytes are returned.
doDecode :: MonadRandom m
    => (JweAlg -> ByteString -> Either JwtError ScrubbedBytes)
    -- ^ Decrypts the encrypted CEK for the given @alg@
    -> ByteString
    -- ^ Compact-serialized JWE
    -> ExceptT JwtError m Jwe
doDecode decodeCek jwt = do
    encodedJwt <- ExceptT (return (P.parseJwt jwt))
    case encodedJwt of
        P.DecodableJwe hdr (P.EncryptedCEK ek) iv (P.Payload payload) tag (P.AAD aad) -> do
            let alg = jweAlg hdr
                enc = jweEnc hdr
            -- Always generate a random CEK for @enc@, whether or not it ends
            -- up being used. The IV generated alongside it is discarded.
            (dummyCek, _) <- lift $ generateCmkAndIV enc
            -- If CEK decryption fails, or yields a key whose length differs
            -- from the random one, silently substitute the random CEK.
            -- The decodeCek error is discarded on purpose. Presumably
            -- decryptPayload then rejects the payload, so a bad key surfaces
            -- only as the same 'BadCrypto' a corrupted payload produces.
            -- This appears to be the key-decryption-oracle countermeasure
            -- from RFC 7516 section 11.5 (e.g. for RSA1_5 padding errors).
            -- Do not turn it into an early return on CEK failure.
            let decryptedCek = either (const dummyCek) id $ decodeCek alg ek
                cek = if BA.length decryptedCek == BA.length dummyCek
                        then decryptedCek
                        else dummyCek
            claims <- maybe (throwE BadCrypto) return $ decryptPayload enc cek iv aad tag payload
            return (hdr, claims)
        _ -> throwE (BadHeader "Content is not a JWE")
-- | Build a compact-serialized JWE from a header and payload.
--
-- Steps:
--
-- 1. Generate a fresh content encryption key (CEK) and IV for the header's
--    @enc@.
-- 2. Encrypt the payload, using the base64url-encoded header as additional
--    authenticated data.
-- 3. Encrypt the CEK with @encryptKey@.
-- 4. Join the five base64url-encoded parts (header, encrypted key, IV,
--    ciphertext, tag) with dots.
--
-- Fails with 'BadCrypto' if 'encryptPayload' returns 'Nothing'. The
-- original code matched the result with an irrefutable @Just@ pattern,
-- which crashed at runtime in that case instead of returning an error.
-- Any error from @encryptKey@ is propagated unchanged.
doEncode :: (MonadRandom m, ByteArray ba)
    => JweHeader
    -- ^ Header to encode; its @enc@ selects the content cipher
    -> (ScrubbedBytes -> ExceptT JwtError m ByteString)
    -- ^ Encrypts the generated CEK
    -> ba
    -- ^ Payload bytes to encrypt
    -> ExceptT JwtError m Jwt
doEncode h encryptKey claims = do
    (cmk, iv) <- lift (generateCmkAndIV e)
    -- Tuple and newtype patterns cannot fail, so no MonadFail is required.
    (AuthTag sig, ct) <- maybe (throwE BadCrypto) return
        (encryptPayload e cmk iv aad claims)
    jweKey <- encryptKey cmk
    let jwe = B.intercalate "." $ map B64.encode [hdr, jweKey, BA.convert iv, BA.convert ct, BA.convert sig]
    return (Jwt jwe)
  where
    e = jweEnc h
    hdr = encodeHeader h
    -- The encoded header is authenticated as AAD by the content cipher.
    aad = B64.encode hdr
-- | Encrypt raw claims as a compact-serialized JWE.
--
-- The content encryption key is protected by RSA under the given public key.
-- The header carries only @alg@ and @enc@: no @kid@ and no @cty@.
rsaEncode :: MonadRandom m
    => JweAlg      -- ^ RSA algorithm used to encrypt the CEK
    -> Enc         -- ^ Content encryption algorithm
    -> PublicKey   -- ^ Recipient's RSA public key
    -> ByteString  -- ^ Claims to encrypt
    -> m (Either JwtError Jwt)
rsaEncode a e kPub claims =
    runExceptT (doEncode header encryptCek claims)
  where
    header = defJweHdr { jweAlg = a, jweEnc = e }
    encryptCek cek = ExceptT (rsaEncrypt kPub a cek)
-- | Decrypt a compact-serialized JWE whose content encryption key was
-- RSA-encrypted for the given private key.
--
-- A fresh blinder is generated from the key's public modulus before
-- decryption and used for the RSA operation.
rsaDecode :: MonadRandom m
    => PrivateKey  -- ^ Recipient's RSA private key
    -> ByteString  -- ^ Compact-serialized JWE
    -> m (Either JwtError Jwe)
rsaDecode pk jwt = runExceptT (decryptWith =<< lift (generateBlinder modulus))
  where
    modulus = public_n (private_pub pk)
    decryptWith blinder = doDecode (rsaDecrypt (Just blinder) pk) jwt