{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_HADDOCK prune #-}
module Jose.Internal.Parser
( parseJwt
, DecodableJwt (..)
, EncryptedCEK (..)
, Payload (..)
, IV (..)
, Tag (..)
, AAD (..)
, Sig (..)
, SigTarget (..)
)
where
import Data.Bifunctor (first)
import Data.Aeson (eitherDecodeStrict')
import Data.Attoparsec.ByteString (Parser)
import qualified Data.Attoparsec.ByteString as P
import qualified Data.Attoparsec.ByteString.Char8 as PC
import Data.ByteArray.Encoding (convertFromBase, Base(..))
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Jose.Jwa
import Jose.Types (JwtError(..), JwtHeader(..), JwsHeader(..), JweHeader(..))
-- | The parts of a compact-serialized JWT after splitting on @.@ and
-- base64url-decoding, with one constructor per kind of parsed header.
data DecodableJwt
    = Unsecured ByteString
    -- ^ Decoded payload of a token whose header parsed as 'UnsecuredH'.
    | DecodableJws JwsHeader Payload Sig SigTarget
    -- ^ Signed token: header, decoded payload, decoded signature and the
    -- still-encoded bytes the signature is paired with.
    | DecodableJwe JweHeader EncryptedCEK IV Payload Tag AAD
    -- ^ Encrypted token: header, decoded encrypted key, IV, decoded
    -- ciphertext, authentication tag and the still-encoded header bytes.
-- | A JWE authentication tag, tagged by its length in bytes. The parser in
-- this module only builds each constructor after checking that length.
data Tag
    = Tag16 ByteString
    -- ^ Tag of exactly 16 bytes.
    | Tag24 ByteString
    -- ^ Tag of exactly 24 bytes.
    | Tag32 ByteString
    -- ^ Tag of exactly 32 bytes.
-- | A JWE initialization vector, tagged by its length in bytes. The parser
-- in this module only builds each constructor after checking that length.
data IV
    = IV12 ByteString
    -- ^ IV of exactly 12 bytes (used for the GCM encodings).
    | IV16 ByteString
    -- ^ IV of exactly 16 bytes (used for every other encoding).
-- | Base64url-decoded signature segment of a JWS.
newtype Sig = Sig ByteString
-- | The still-encoded @header.payload@ bytes of a JWS, kept exactly as they
-- appeared in the input.
newtype SigTarget = SigTarget ByteString
-- | The still-encoded header segment of a JWE, kept exactly as it appeared
-- in the input.
newtype AAD = AAD ByteString
-- | Base64url-decoded payload segment (JWS payload or JWE ciphertext).
newtype Payload = Payload ByteString
-- | Base64url-decoded encrypted-key segment of a JWE.
newtype EncryptedCEK = EncryptedCEK ByteString
-- | Parse a compact-serialized JWT into its decoded components.
--
-- Every parser failure (malformed structure, invalid base64url, an
-- undecodable header, a wrongly sized IV or tag) is reported as 'BadCrypto';
-- the underlying attoparsec error message is discarded.
parseJwt :: ByteString -> Either JwtError DecodableJwt
parseJwt = first toJwtError . P.parseOnly jwt
  where
    -- The parser's textual error is deliberately not surfaced.
    toJwtError :: String -> JwtError
    toJwtError _ = BadCrypto
-- | Parser for a whole compact-serialized JWT.
--
-- The header segment is parsed first and decides how the remaining
-- segments are read:
--
-- * 'UnsecuredH': one dot-terminated payload segment, after which the input
--   must end (the signature segment of an unsecured token is empty).
-- * 'JwsH': a dot-terminated payload segment followed by the signature,
--   which takes the rest of the input.
-- * 'JweH': encrypted key, IV, ciphertext (each dot-terminated) and the
--   authentication tag, which takes the rest of the input.
jwt :: Parser DecodableJwt
jwt = do
    (hdr, rawHdr) <- jwtHeader
    case hdr of
        -- Require end of input so that trailing bytes after the final '.'
        -- are rejected, matching the JWS/JWE branches which consume
        -- everything that remains.
        UnsecuredH -> Unsecured <$> base64Chunk <* P.endOfInput
        JwsH h -> do
            payloadB64 <- PC.takeWhile (/= '.') <* PC.char '.'
            payload <- b64Decode payloadB64
            s <- sig (jwsAlg h)
            -- The signature target is rebuilt from the original encoded
            -- segments rather than re-encoding the decoded values.
            let target = SigTarget (B.concat [rawHdr, ".", payloadB64])
            pure (DecodableJws h (Payload payload) s target)
        JweH h ->
            let enc = jweEnc h
             in DecodableJwe h
                    <$> encryptedCEK
                    <*> iv enc
                    <*> encryptedPayload
                    <*> authTag enc
                    <*> pure (AAD rawHdr)
-- | Parse the signature segment: take all remaining input and
-- base64url-decode it. The algorithm argument is currently ignored.
sig :: JwsAlg -> Parser Sig
sig _ = Sig <$> (P.takeByteString >>= b64Decode)
-- | Parse the authentication tag: take all remaining input,
-- base64url-decode it and check its length against the content encryption
-- algorithm. Fails with @"invalid auth tag"@ on a length mismatch.
authTag :: Enc -> Parser Tag
authTag e = (P.takeByteString >>= b64Decode) >>= checkFor e
  where
    -- Required tag length and constructor for each encoding.
    checkFor :: Enc -> ByteString -> Parser Tag
    checkFor A128GCM = sized 16 Tag16
    checkFor A192GCM = sized 16 Tag16
    checkFor A256GCM = sized 16 Tag16
    checkFor A128CBC_HS256 = sized 16 Tag16
    checkFor A192CBC_HS384 = sized 24 Tag24
    checkFor A256CBC_HS512 = sized 32 Tag32

    -- Wrap the decoded bytes only when they have exactly the given length.
    sized :: Int -> (ByteString -> Tag) -> ByteString -> Parser Tag
    sized n wrap raw
        | B.length raw == n = pure (wrap raw)
        | otherwise = fail "invalid auth tag"
-- | Parse the dot-terminated IV segment and check its decoded length
-- against the content encryption algorithm: 12 bytes for the GCM encodings,
-- 16 bytes for everything else. Fails with @"invalid iv"@ on a mismatch.
iv :: Enc -> Parser IV
iv e =
    base64Chunk >>= case e of
        A128GCM -> sized 12 IV12
        A192GCM -> sized 12 IV12
        A256GCM -> sized 12 IV12
        _ -> sized 16 IV16
  where
    -- Wrap the decoded bytes only when they have exactly the given length.
    sized :: Int -> (ByteString -> IV) -> ByteString -> Parser IV
    sized n wrap raw
        | B.length raw == n = pure (wrap raw)
        | otherwise = fail "invalid iv"
-- | Parse the dot-terminated, base64url-encoded encrypted-key segment.
encryptedCEK :: Parser EncryptedCEK
encryptedCEK = fmap EncryptedCEK base64Chunk
-- | Parse the dot-terminated, base64url-encoded ciphertext segment.
encryptedPayload :: Parser Payload
encryptedPayload = fmap Payload base64Chunk
-- | Parse the first (header) segment of a compact JWT.
--
-- Reads up to and including the first @.@, base64url-decodes the segment
-- and decodes the resulting bytes as JSON. Returns the decoded header
-- together with the segment's original encoded bytes (without the dot).
-- A JSON decoding error fails the parser with aeson's error message.
jwtHeader :: P.Parser (JwtHeader, ByteString)
jwtHeader = do
    rawB64 <- PC.takeWhile (/= '.') <* PC.char '.'
    decoded <- b64Decode rawB64
    case eitherDecodeStrict' decoded of
        Left err -> fail err
        Right hdr -> return (hdr, rawB64)
-- | Parse one dot-terminated segment: read everything before the next @.@,
-- consume the dot, and base64url-decode what was read. Fails if no dot
-- follows or the segment is not valid base64url.
base64Chunk :: P.Parser ByteString
base64Chunk = (PC.takeWhile (/= '.') <* PC.char '.') >>= b64Decode
-- | Decode unpadded URL-safe base64 inside the parser. Any decoding error
-- fails the parser with @"Invalid Base64"@; the decoder's own message is
-- discarded.
b64Decode :: ByteString -> P.Parser ByteString
b64Decode bs =
    case convertFromBase Base64URLUnpadded bs of
        Left _ -> fail "Invalid Base64"
        Right decoded -> return decoded