module Jose.Jwe
( jwkEncode
, jwkDecode
, rsaEncode
, rsaDecode
)
where
import Control.Monad (unless)
import Control.Monad.Trans (lift)
import Control.Monad.Trans.Either
import Crypto.Cipher.Types (AuthTag(..))
import Crypto.PubKey.RSA (PrivateKey(..), PublicKey(..), generateBlinder, private_pub)
import Crypto.Random (MonadRandom)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import Jose.Types
import qualified Jose.Internal.Base64 as B64
import Jose.Internal.Crypto
import Jose.Jwa
import Jose.Jwk
-- | Create a JWE from a payload, using the given JWK to protect the
-- content encryption key.
--
-- RSA keys encrypt the content key with 'rsaEncrypt' (a private key
-- contributes only its public half); symmetric keys wrap it with
-- 'keyWrap'. Any other kind of key is rejected with a 'KeyError'.
--
-- The key's id is copied into the header. A nested JWT payload is marked
-- with a @"JWT"@ content type; plain claims carry no content type.
jwkEncode :: MonadRandom m
    => JweAlg      -- ^ Algorithm used to protect the content key
    -> Enc         -- ^ Content encryption algorithm
    -> Jwk         -- ^ Key to encode with
    -> Payload     -- ^ Claims or a nested JWT
    -> m (Either JwtError Jwt)
jwkEncode a e jwk payload = runEitherT $ case jwk of
    RsaPublicJwk pubKey kid _ _   -> encodeWith kid (rsaWrap pubKey)
    RsaPrivateJwk privKey kid _ _ -> encodeWith kid (rsaWrap (private_pub privKey))
    SymmetricJwk kek kid _ _      -> encodeWith kid (hoistEither . keyWrap a kek)
    _                             -> left $ KeyError "JWK cannot encode a JWE"
  where
    -- Build the header for this key id and run the shared encoder.
    encodeWith kid wrapCek = doEncode (header kid) wrapCek body

    -- Lift the RSA encryption of the content key into EitherT.
    rsaWrap pubKey cek = EitherT (rsaEncrypt pubKey a cek)

    header kid = defJweHdr
        { jweAlg = a
        , jweEnc = e
        , jweKid = kid
        , jweCty = cty
        }

    -- Content type and raw bytes to encrypt, derived from the payload kind.
    (cty, body) = case payload of
        Claims c       -> (Nothing, c)
        Nested (Jwt b) -> (Just "JWT", b)
-- | Decode a JWE using the given JWK.
--
-- An RSA private key decrypts the content key with 'rsaDecrypt', using a
-- blinder freshly generated for the key's modulus. A symmetric key unwraps
-- it with 'keyUnwrap'. Any other kind of key is rejected with a 'KeyError'.
-- A successful result is wrapped in the 'Jwe' constructor.
jwkDecode :: MonadRandom m
    => Jwk          -- ^ Key to decode with
    -> ByteString   -- ^ The encoded JWE
    -> m (Either JwtError JwtContent)
jwkDecode jwk jwt = runEitherT $ case jwk of
    RsaPrivateJwk privKey _ _ _ -> do
        let modulus = public_n (private_pub privKey)
        blinder <- lift (generateBlinder modulus)
        decryptWith (rsaDecrypt (Just blinder) privKey)
    SymmetricJwk keyBytes _ _ _ -> decryptWith (keyUnwrap keyBytes)
    _ -> left (KeyError "JWK cannot decode a JWE")
  where
    -- Run the shared decoder with the given content-key recovery function
    -- and tag the result as JWE content.
    decryptWith unwrapCek = fmap Jwe (doDecode unwrapCek jwt)
-- | Parse and decrypt a JWE given in dot-separated compact form.
--
-- The first argument recovers the content encryption key (CEK) from the
-- decoded encrypted-key segment, given the algorithm named in the header.
--
-- Fails with @BadDots 4@ unless the input contains exactly four dots, with
-- a 'BadHeader' error when the header belongs to a JWS or an unsecured JWT,
-- and with 'BadCrypto' when the payload cannot be decrypted. Errors from
-- Base64 decoding and header parsing are passed through unchanged.
doDecode :: MonadRandom m
    => (JweAlg -> ByteString -> Either JwtError ByteString)
    -- ^ Recovers the CEK from the encrypted-key segment
    -> ByteString
    -- ^ The encoded JWE
    -> EitherT JwtError m Jwe
doDecode decodeCek jwt = do
    unless (BC.count '.' jwt == 4) $ left (BadDots 4)
    -- Four dots always split into five segments, so the first alternative
    -- is the one taken; the fallback only keeps the match total instead of
    -- relying on 'head' and a refutable pattern bind.
    case BC.split '.' jwt of
        [aad, rawEk, rawIv, rawPayload, rawTag] -> do
            -- The still-encoded first segment doubles as the additional
            -- authenticated data passed to payload decryption.
            h          <- B64.decode aad
            ek         <- B64.decode rawEk
            providedIv <- B64.decode rawIv
            payload    <- B64.decode rawPayload
            sig        <- B64.decode rawTag
            hdr <- case parseHeader h of
                Right (JweH jweHdr) -> return jweHdr
                Right (JwsH _)      -> left (BadHeader "Header is for a JWS")
                Right UnsecuredH    -> left (BadHeader "Header is for an unsecured JWT")
                Left err            -> left err
            let alg = jweAlg hdr
                enc = jweEnc hdr
            -- A random key and IV of the sizes produced for this encryption
            -- method. They replace a CEK that could not be recovered or has
            -- the wrong length, and an IV of the wrong length, so that such
            -- input still reaches payload decryption instead of failing
            -- early with its own error.
            -- NOTE(review): presumably the timing/oracle mitigation
            -- described in RFC 7516 section 11.5 - confirm.
            (dummyCek, dummyIv) <- lift $ generateCmkAndIV enc
            let decryptedCek = either (const dummyCek) id $ decodeCek alg ek
                cek | B.length decryptedCek == B.length dummyCek = decryptedCek
                    | otherwise                                  = dummyCek
                iv  | B.length providedIv == B.length dummyIv = providedIv
                    | otherwise                               = dummyIv
                authTag = AuthTag $ BA.convert sig
            claims <- maybe (left BadCrypto) return $
                decryptPayload enc cek iv aad authTag payload
            return (hdr, claims)
        _ -> left (BadDots 4)
-- | Encrypt a payload and assemble the dot-separated JWE.
--
-- A fresh content key and IV are generated for the header's encryption
-- method. The content key is protected with the supplied function and the
-- payload is encrypted with the Base64-encoded header as additional
-- authenticated data. The result joins the Base64 encodings of the header,
-- protected key, IV, ciphertext and authentication tag with dots.
--
-- Errors from the key-protection function are passed through. If payload
-- encryption yields no result, 'BadCrypto' is returned rather than raising
-- a pattern-match exception.
doEncode :: MonadRandom m
    => JweHeader
    -- ^ Header to encode; its 'jweEnc' selects the content encryption
    -> (ByteString -> EitherT JwtError m ByteString)
    -- ^ Protects (encrypts or wraps) the generated content key
    -> ByteString
    -- ^ Bytes to encrypt
    -> EitherT JwtError m Jwt
doEncode h encryptKey claims = do
    (cmk, iv) <- lift (generateCmkAndIV e)
    -- Key protection runs first so that its errors take precedence over a
    -- payload-encryption failure.
    jweKey <- encryptKey cmk
    (AuthTag sig, ct) <- maybe (left BadCrypto) return $
        encryptPayload e cmk iv aad claims
    let jwe = B.intercalate "." $
            map B64.encode [hdr, jweKey, iv, ct, BA.convert sig]
    return (Jwt jwe)
  where
    e   = jweEnc h
    hdr = encodeHeader h
    aad = B64.encode hdr
-- | Create a JWE by encrypting the content key with an RSA public key.
--
-- The header is the default JWE header with only the algorithm and
-- encryption method set.
rsaEncode :: MonadRandom m
    => JweAlg       -- ^ RSA algorithm used to encrypt the content key
    -> Enc          -- ^ Content encryption algorithm
    -> PublicKey    -- ^ RSA key to encrypt with
    -> ByteString   -- ^ The claims (content) to encrypt
    -> m (Either JwtError Jwt)
rsaEncode a e kPub claims = runEitherT (doEncode header wrapCek claims)
  where
    header = defJweHdr { jweAlg = a, jweEnc = e }

    -- Lift the RSA encryption of the content key into EitherT.
    wrapCek cek = EitherT (rsaEncrypt kPub a cek)
-- | Decrypt a JWE whose content key was encrypted with RSA.
--
-- A blinder is generated for the key's modulus and passed to 'rsaDecrypt'
-- for the content-key decryption.
rsaDecode :: MonadRandom m
    => PrivateKey   -- ^ RSA key used to decrypt the content key
    -> ByteString   -- ^ The encoded JWE
    -> m (Either JwtError Jwe)
rsaDecode pk jwt = runEitherT $
    lift (generateBlinder modulus) >>= \blinder ->
        doDecode (rsaDecrypt (Just blinder) pk) jwt
  where
    modulus = public_n (private_pub pk)