{-# LANGUAGE OverloadedStrings, FlexibleContexts #-}
{-# OPTIONS_HADDOCK prune #-}
module Jose.Jwt
( module Jose.Types
, encode
, decode
, decodeClaims
)
where
import Control.Monad (msum, when, unless)
import Control.Monad.Trans (lift)
import Control.Monad.Trans.Except
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import Crypto.PubKey.RSA (PrivateKey(..))
import Crypto.Random (MonadRandom)
import Data.Aeson (decodeStrict',FromJSON)
import Data.ByteString (ByteString)
import Data.Maybe (isNothing)
import qualified Data.ByteString.Char8 as BC
import qualified Jose.Internal.Base64 as B64
import qualified Jose.Internal.Parser as P
import Jose.Types
import Jose.Jwk
import Jose.Jwa
import qualified Jose.Jws as Jws
import qualified Jose.Jwe as Jwe
-- | Create a JWT by signing or encrypting a payload.
--
-- For @JwsEncoding None@ no key is used. 'Claims' are emitted as an
-- unsecured token made of the base64-encoded header @{"alg":"none"}@ and
-- the base64-encoded claims, joined by a single @.@. A 'Nested' payload is
-- rejected with 'BadClaims'.
--
-- For any other JWS algorithm, or for a JWE encoding, the first key in the
-- list that is compatible with the requested algorithm is handed to the
-- JWS or JWE encoder. If no such key exists, the result is a 'KeyError'.
encode :: MonadRandom m
    => [Jwk]        -- ^ Candidate keys; the first one compatible with the encoding is used
    -> JwtEncoding  -- ^ Requested JWS or JWE algorithm(s)
    -> Payload      -- ^ The 'Claims' or 'Nested' token to encode
    -> m (Either JwtError Jwt)
encode jwks encoding msg = runExceptT $ case encoding of
    JwsEncoding None -> unsecured msg
    JwsEncoding a ->
        withFirstKey (canEncodeJws a) "No matching key found for JWS algorithm"
            (\k -> Jws.jwkEncode a k msg)
    JweEncoding a e ->
        withFirstKey (canEncodeJwe a) "No matching key found for JWE algorithm"
            (\k -> Jwe.jwkEncode a e k msg)
  where
    -- alg "none": only plain claims may be left unsigned.
    -- NOTE(review): RFC 7519 section 6.1 shows unsecured JWTs ending in '.'
    -- (an empty signature segment). This emits only "header.payload".
    -- Confirm what the JWT parser expects before changing it.
    unsecured (Claims p) = return (Jwt (BC.intercalate "." [unsecuredHdr, B64.encode p]))
    unsecured (Nested _) = throwE BadClaims

    -- Run the encoder with the first key accepted by the predicate,
    -- or fail with the given key error if there is none.
    withFirstKey usable noKeyMsg enc = case filter usable jwks of
        []      -> throwE (KeyError noKeyMsg)
        (k : _) -> ExceptT (enc k)

    unsecuredHdr = B64.encode (BC.pack "{\"alg\":\"none\"}")
-- | Decode a compact-serialized JWT using the supplied key set.
--
-- The token is parsed first; a parse failure is returned unchanged.
--
-- * An unsecured token is accepted only when the expected encoding is
--   exactly @Just (JwsEncoding None)@. Otherwise the result is 'BadAlgorithm'.
-- * For a JWS or JWE token, the expected encoding (when given) must equal
--   the encoding declared in the token header, or 'BadAlgorithm' is returned.
--
-- The key set is filtered down to keys compatible with the header. If none
-- remain, the result is a 'KeyError'. Every compatible key is then tried;
-- errors from individual keys are discarded. The content from the first key
-- in list order that succeeds is returned. If none succeed, the result is a
-- 'KeyError'.
decode :: MonadRandom m
    => [Jwk]              -- ^ Candidate keys for verification or decryption
    -> Maybe JwtEncoding  -- ^ Expected encoding, or 'Nothing' to accept what the header declares
    -> ByteString         -- ^ The encoded token
    -> m (Either JwtError JwtContent)
decode keySet encoding jwt = runExceptT $ do
    parsed   <- ExceptT (return (P.parseJwt jwt))
    outcomes <- tryKeys parsed
    case msum outcomes of
        Just content -> return content
        Nothing      -> throwE $ KeyError "None of the keys was able to decode the JWT"
  where
    -- One entry per attempted key; 'Nothing' marks a key that failed.
    tryKeys (P.Unsecured p) = case encoding of
        Just (JwsEncoding None) -> return [Just (Unsecured p)]
        _ -> throwE (BadAlgorithm "JWT is unsecured but expected 'alg' was not 'none'")
    tryKeys (P.DecodableJws hdr _ _ _) = do
        requireEncoding (JwsEncoding (jwsAlg hdr)) "Expected 'alg' doesn't match JWS header"
        usable <- nonEmptyKeys (filter (canDecodeJws hdr) keySet)
        mapM viaJws usable
    tryKeys (P.DecodableJwe hdr _ _ _ _ _) = do
        requireEncoding (JweEncoding (jweAlg hdr) (jweEnc hdr)) "Expected encoding doesn't match JWE header"
        usable <- nonEmptyKeys (filter (canDecodeJwe hdr) keySet)
        mapM viaJwe usable

    -- Fail when the caller pinned an encoding that differs from the header's.
    requireEncoding actual msg =
        unless (isNothing encoding || encoding == Just actual) $
            throwE (BadAlgorithm msg)

    nonEmptyKeys [] = throwE $ KeyError "No suitable key was found to decode the JWT"
    nonEmptyKeys ks = return ks

    -- Verify the JWS with one key. Any verification error becomes 'Nothing'.
    viaJws :: MonadRandom m => Jwk -> ExceptT JwtError m (Maybe JwtContent)
    viaJws k = case verify k of
        Left _    -> return Nothing
        Right jws -> return (Just (Jws jws))
      where
        verify (RsaPublicJwk pub _ _ _)    = Jws.rsaDecode pub jwt
        verify (RsaPrivateJwk priv _ _ _)  = Jws.rsaDecode (private_pub priv) jwt
        verify (EcPublicJwk pub _ _ _ _)   = Jws.ecDecode pub jwt
        verify (EcPrivateJwk priv _ _ _ _) = Jws.ecDecode (ECDSA.toPublicKey priv) jwt
        verify (SymmetricJwk secret _ _ _) = Jws.hmacDecode secret jwt

    -- Decrypt the JWE with one key. Any decryption error becomes 'Nothing'.
    viaJwe :: MonadRandom m => Jwk -> ExceptT JwtError m (Maybe JwtContent)
    viaJwe k = do
        outcome <- lift (Jwe.jwkDecode k jwt)
        return (either (const Nothing) Just outcome)
-- | Return the header and claims of a three-segment
-- (@header.claims.signature@) token without checking it.
--
-- This is useful when the token contents are needed to decide which key
-- should verify it. No signature verification or decryption is performed
-- here; the token is only split, base64-decoded and parsed. The result
-- must therefore not be trusted on its own.
--
-- Returns @'BadDots' 2@ unless the token has exactly three dot-separated
-- segments. Returns 'BadClaims' if the claims segment does not parse as the
-- requested type. Header and base64 errors are propagated unchanged.
decodeClaims :: (FromJSON a)
    => ByteString
    -> Either JwtError (JwtHeader, a)
decodeClaims jwt = case BC.split '.' jwt of
    [hdrPart, claimsPart, _signature] -> do
        hdr    <- B64.decode hdrPart >>= parseHeader
        claims <- B64.decode claimsPart >>= parseClaims
        return (hdr, claims)
    -- Matching the exact segment list avoids partial 'head'/'tail' calls.
    _ -> Left (BadDots 2)
  where
    parseClaims bs = maybe (Left BadClaims) Right $ decodeStrict' bs