{-# LANGUAGE OverloadedStrings, FlexibleContexts #-}
{-# OPTIONS_HADDOCK prune #-}
module Jose.Jwt
( module Jose.Types
, encode
, decode
, decodeClaims
)
where
import Control.Monad (msum, when, unless)
import Control.Monad.Trans (lift)
import Control.Monad.Trans.Except
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import Crypto.PubKey.RSA (PrivateKey(..))
import Crypto.Random (MonadRandom)
import Data.Aeson (decodeStrict',FromJSON)
import Data.ByteString (ByteString)
import Data.Maybe (isNothing)
import qualified Data.ByteString.Char8 as BC
import qualified Jose.Internal.Base64 as B64
import qualified Jose.Internal.Parser as P
import Jose.Types
import Jose.Jwk
import Jose.Jwa
import qualified Jose.Jws as Jws
import qualified Jose.Jwe as Jwe
-- | Encode a payload as a JWT using the supplied key set.
--
-- How the encoding is chosen:
--
-- * @JwsEncoding None@ builds an unsecured token and uses no key. Only a
--   'Claims' payload is allowed here; a 'Nested' payload yields 'BadClaims'.
-- * Any other 'JwsEncoding' signs with the first key for which
--   'canEncodeJws' holds for the requested algorithm.
-- * 'JweEncoding' encrypts with the first key for which 'canEncodeJwe'
--   holds for the requested key-management algorithm.
--
-- When no key in the list is usable, the result is a 'KeyError'.
encode :: MonadRandom m
    => [Jwk]                    -- ^ Candidate keys; the first usable one is used
    -> JwtEncoding              -- ^ The JWS or JWE algorithm(s) to apply
    -> Payload                  -- ^ The content to encode
    -> m (Either JwtError Jwt)  -- ^ The encoded token, or why encoding failed
encode jwks encoding msg =
    case encoding of
        JwsEncoding None ->
            return (unsecured msg)
        JwsEncoding sigAlg ->
            withFirstKey (canEncodeJws sigAlg) "No matching key found for JWS algorithm" $ \k ->
                Jws.jwkEncode sigAlg k msg
        JweEncoding keyAlg contentEnc ->
            withFirstKey (canEncodeJwe keyAlg) "No matching key found for JWE algorithm" $ \k ->
                Jwe.jwkEncode keyAlg contentEnc k msg
  where
    -- Serialise an unsecured token as "<header>.<payload>" (no signature part).
    -- NOTE(review): RFC 7519 section 6.1 shows unsecured JWTs ending in "."
    -- (an empty signature segment); this output omits it. Check how
    -- Jose.Internal.Parser handles both forms before changing the format.
    unsecured :: Payload -> Either JwtError Jwt
    unsecured (Claims p) = Right (Jwt (BC.intercalate "." [unsecuredHdr, B64.encode p]))
    unsecured (Nested _) = Left BadClaims

    -- Run the action with the first key accepted by the predicate. If there
    -- is none, fail with a KeyError carrying the given message.
    withFirstKey usable missingMsg act =
        case filter usable jwks of
            k : _ -> act k
            []    -> return (Left (KeyError missingMsg))

    unsecuredHdr :: ByteString
    unsecuredHdr = B64.encode (BC.pack "{\"alg\":\"none\"}")
-- | Decode a JWT, verifying a JWS signature or decrypting a JWE.
--
-- Rules applied, in order:
--
-- * The token is parsed with 'P.parseJwt'; a parse failure is returned as-is.
-- * An unsecured (@alg: none@) token is accepted only when the caller passes
--   @Just (JwsEncoding None)@. Otherwise it is rejected with 'BadAlgorithm'.
--   This means 'Nothing' never admits an unsecured token.
-- * For JWS/JWE tokens, an expected encoding of 'Nothing' accepts whatever
--   the header declares. A 'Just' value must equal the header's algorithm
--   (and, for JWE, its content encryption), else 'BadAlgorithm'.
-- * Keys are filtered with 'canDecodeJws' / 'canDecodeJwe'. If none remain,
--   the result is a 'KeyError'.
-- * The remaining keys are tried in list order. The first one that succeeds
--   wins; failures of individual keys are ignored. If every key fails, the
--   result is a 'KeyError'.
decode :: MonadRandom m
    => [Jwk]                           -- ^ Candidate keys to try
    -> Maybe JwtEncoding               -- ^ Expected encoding, or 'Nothing' to trust the header
    -> ByteString                      -- ^ The serialised JWT
    -> m (Either JwtError JwtContent)  -- ^ The decoded content, or why decoding failed
decode keySet encoding jwt = runExceptT $ do
    decodableJwt <- except (P.parseJwt jwt)
    case (decodableJwt, encoding) of
        (P.Unsecured p, Just (JwsEncoding None)) ->
            return (Unsecured p)
        (P.Unsecured _, _) ->
            throwE (BadAlgorithm "JWT is unsecured but expected 'alg' was not 'none'")
        (P.DecodableJws hdr _ _ _, e) -> do
            unless (isNothing e || e == Just (JwsEncoding (jwsAlg hdr))) $
                throwE (BadAlgorithm "Expected 'alg' doesn't match JWS header")
            ks <- checkKeys (filter (canDecodeJws hdr) keySet)
            firstSuccess decodeWithJws ks
        (P.DecodableJwe hdr _ _ _ _ _, e) -> do
            unless (isNothing e || e == Just (JweEncoding (jweAlg hdr) (jweEnc hdr))) $
                throwE (BadAlgorithm "Expected encoding doesn't match JWE header")
            ks <- checkKeys (filter (canDecodeJwe hdr) keySet)
            firstSuccess decodeWithJwe ks
  where
    -- Try each key in order and stop at the first that decodes the token.
    -- Stopping early avoids further signature checks or decryptions once a
    -- key has worked. The chosen result matches taking the first success
    -- from trying them all.
    firstSuccess :: Monad n
                 => (Jwk -> ExceptT JwtError n (Maybe JwtContent))
                 -> [Jwk]
                 -> ExceptT JwtError n JwtContent
    firstSuccess _ [] = throwE (KeyError "None of the keys was able to decode the JWT")
    firstSuccess attempt (k : rest) = attempt k >>= maybe (firstSuccess attempt rest) return

    -- Verify the JWS with one key. Any failure, including an unsupported
    -- key type, becomes Nothing.
    -- Private-key JWKs are verified with their public half:
    --   * RSA uses 'private_pub'.
    --   * EC uses 'ECDSA.toPublicKey'.
    --   * Ed25519/Ed448 use the stored public key.
    decodeWithJws :: Monad n => Jwk -> ExceptT JwtError n (Maybe JwtContent)
    decodeWithJws k = return . either (const Nothing) (Just . Jws) $ case k of
        Ed25519PublicJwk kPub _    -> Jws.ed25519Decode kPub jwt
        Ed25519PrivateJwk _ kPub _ -> Jws.ed25519Decode kPub jwt
        Ed448PublicJwk kPub _      -> Jws.ed448Decode kPub jwt
        Ed448PrivateJwk _ kPub _   -> Jws.ed448Decode kPub jwt
        RsaPublicJwk kPub _ _ _    -> Jws.rsaDecode kPub jwt
        RsaPrivateJwk kPr _ _ _    -> Jws.rsaDecode (private_pub kPr) jwt
        EcPublicJwk kPub _ _ _ _   -> Jws.ecDecode kPub jwt
        EcPrivateJwk kPr _ _ _ _   -> Jws.ecDecode (ECDSA.toPublicKey kPr) jwt
        SymmetricJwk kb _ _ _      -> Jws.hmacDecode kb jwt
        UnsupportedJwk _           -> Left (KeyError "Unsupported JWKs cannot be used")

    -- Decrypt the JWE with one key; a decryption failure becomes Nothing.
    decodeWithJwe :: MonadRandom n => Jwk -> ExceptT JwtError n (Maybe JwtContent)
    decodeWithJwe k = lift (fmap (either (const Nothing) Just) (Jwe.jwkDecode k jwt))

    -- Fail early, with a distinct message, when no key suits the header.
    checkKeys :: Monad n => [a] -> ExceptT JwtError n [a]
    checkKeys [] = throwE (KeyError "No suitable key was found to decode the JWT")
    checkKeys ks = return ks
-- | Read the header and JSON claims of a three-part JWT without checking it.
--
-- The token is split on @'.'@ and must have exactly three segments;
-- anything else is rejected with @BadDots 2@. In particular, a five-part JWE
-- is rejected.
--
-- * The first segment is decoded with "Jose.Internal.Base64" and parsed by
--   'parseHeader'.
-- * The second segment is decoded the same way and parsed as JSON into the
--   requested type. A JSON decode failure gives 'BadClaims'.
--
-- The third (signature) segment is ignored: __nothing is verified__. The
-- returned claims are therefore untrusted input.
--
-- NOTE(review): an unsecured token in the two-segment form written by
-- 'encode' for @JwsEncoding None@ has no third segment, so it is rejected
-- here with @BadDots 2@.
decodeClaims :: (FromJSON a)
    => ByteString                      -- ^ The serialised JWT
    -> Either JwtError (JwtHeader, a)  -- ^ The parsed header and claims
decodeClaims jwt =
    case BC.split '.' jwt of
        [hdrPart, claimsPart, _sig] -> do
            hdr    <- B64.decode hdrPart >>= parseHeader
            claims <- B64.decode claimsPart >>= parseClaims
            return (hdr, claims)
        _ -> Left (BadDots 2)
  where
    parseClaims :: FromJSON b => ByteString -> Either JwtError b
    parseClaims bs = maybe (Left BadClaims) Right (decodeStrict' bs)