module Crypto.JOSE.JWS.Internal where
import Control.Applicative ((<|>))
import Data.Foldable (toList)
import Data.Maybe (catMaybes, fromMaybe)
import Data.Monoid ((<>))
import Control.Lens hiding ((.=))
import Control.Monad.Except (MonadError(throwError))
import Data.Aeson
import Data.Byteable
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import qualified Data.HashMap.Strict as M
import Data.List.NonEmpty (NonEmpty)
import qualified Data.Set as S
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Crypto.JOSE.Compact
import Crypto.JOSE.Error
import qualified Crypto.JOSE.JWA.JWS as JWA.JWS
import Crypto.JOSE.JWK
import Crypto.JOSE.Header
import qualified Crypto.JOSE.Types as Types
import qualified Crypto.JOSE.Types.Internal as Types
-- | Header parameter names passed as the first argument to 'parseCrit' when
-- the @crit@ parameter of a 'JWSHeader' is parsed.  Judging by the binding
-- name these are the names @crit@ is not allowed to list (NOTE(review):
-- confirm against 'parseCrit').  They are exactly the parameters that
-- 'JWSHeader' models itself, including @crit@.
jwsCritInvalidNames :: [T.Text]
jwsCritInvalidNames =
  T.words "alg jku jwk x5u x5t x5t#S256 x5c kid typ cty crit"
-- | The header parameters of one JWS signature.
--
-- Each parameter other than @crit@ is wrapped in 'HeaderParam', which pairs
-- the value with its 'Protection' (protected or unprotected header).
-- @crit@ is a bare list of names and is always serialised in the protected
-- header.  Field order matters: it is the positional order of the
-- 'JWSHeader' constructor.
data JWSHeader = JWSHeader
  { _jwsHeaderAlg :: HeaderParam JWA.JWS.Alg
    -- ^ @alg@ (required)
  , _jwsHeaderJku :: Maybe (HeaderParam Types.URI)
    -- ^ @jku@
  , _jwsHeaderJwk :: Maybe (HeaderParam JWK)
    -- ^ @jwk@
  , _jwsHeaderKid :: Maybe (HeaderParam String)
    -- ^ @kid@
  , _jwsHeaderX5u :: Maybe (HeaderParam Types.URI)
    -- ^ @x5u@
  , _jwsHeaderX5c :: Maybe (HeaderParam (NonEmpty Types.Base64X509))
    -- ^ @x5c@
  , _jwsHeaderX5t :: Maybe (HeaderParam Types.Base64SHA1)
    -- ^ @x5t@
  , _jwsHeaderX5tS256 :: Maybe (HeaderParam Types.Base64SHA256)
    -- ^ @x5t#S256@
  , _jwsHeaderTyp :: Maybe (HeaderParam String)
    -- ^ @typ@
  , _jwsHeaderCty :: Maybe (HeaderParam String)
    -- ^ @cty@
  , _jwsHeaderCrit :: Maybe (NonEmpty T.Text)
    -- ^ @crit@ (always protected)
  }
  deriving (Eq, Show)

-- Generates the 'HasJWSHeader' class and field lenses such as 'jwsHeaderAlg'.
makeClassy ''JWSHeader
-- | A 'JWSHeader' that sets only @alg@ (with the given 'Protection'); every
-- optional parameter, including @crit@, is 'Nothing'.
newJWSHeader :: (Protection, JWA.JWS.Alg) -> JWSHeader
newJWSHeader protAlg =
  JWSHeader (uncurry HeaderParam protAlg)
    Nothing Nothing Nothing Nothing Nothing -- jku, jwk, kid, x5u, x5c
    Nothing Nothing Nothing Nothing Nothing -- x5t, x5t#S256, typ, cty, crit
-- | One signature over a JWS payload, with a header of type @a@.
data Signature a = Signature
  { _protectedRaw :: Maybe T.Text
    -- ^ the @protected@ member text exactly as parsed, if retained.
    -- When 'Nothing' (e.g. a freshly created signature), the protected
    -- header is re-encoded from '_header' wherever it is needed.
  , _header :: a
    -- ^ header parameters of this signature
  , _signature :: Types.Base64Octets
    -- ^ the signature value
  }
  deriving (Show)

-- Generates the 'protectedRaw', 'header' and 'signature' lenses.
makeLenses ''Signature
-- | Two signatures are equal when their headers and signature values are
-- equal and their protected headers agree.  If both retain raw protected
-- text, the texts are compared.  If only one does, its text is compared
-- with the encoded protected parameters of the other side's header.  If
-- neither does, they agree.
instance (Eq a, HasParams a) => Eq (Signature a) where
  Signature raw1 hdr1 sig1 == Signature raw2 hdr2 sig2 =
    hdr1 == hdr2 && sig1 == sig2 && sameProtected raw1 raw2
    where
      encoded hdr = BSL.toStrict (protectedParamsEncoded hdr)
      sameProtected (Just t1) (Just t2) = t1 == t2
      sameProtected (Just t1) Nothing = T.encodeUtf8 t1 == encoded hdr2
      sameProtected Nothing (Just t2) = encoded hdr1 == T.encodeUtf8 t2
      sameProtected Nothing Nothing = True
-- | Parse one signature object with optional @protected@ and @header@
-- members and a required @signature@ member.
--
-- The raw @protected@ text is always retained (@\"\"@ when the member is
-- absent), so the signing input can later be rebuilt from exactly the bytes
-- that were signed.  When present, @protected@ must be a string.  It is
-- decoded with 'Types.parseB64Url' and then parsed as JSON; if that JSON is
-- invalid the parse fails with "protected header contains invalid JSON".
-- The protected and unprotected header objects are then given to
-- 'parseParams'.
instance HasParams a => FromJSON (Signature a) where
  parseJSON = withObject "signature" $ \o -> do
    raw <- o .: "protected" <|> pure ""
    protectedB64 <- o .:? "protected"
    protectedObj <- case protectedB64 of
      Nothing -> pure Nothing
      Just v ->
        withText "base64url-encoded header params"
          (Types.parseB64Url
            (maybe (fail "protected header contains invalid JSON") pure
              . decode . BSL.fromStrict))
          v
    unprotectedObj <- o .:? "header"
    hdr <- parseParams protectedObj unprotectedObj
    sig <- o .: "signature"
    pure (Signature (Just raw) hdr sig)
-- | Serialise a signature object.  The members are emitted in the order
-- @protected@ (omitted when the encoded protected parameters are empty),
-- @header@ (omitted when there are no unprotected parameters), then
-- @signature@.
--
-- '_protectedRaw' is ignored: @protected@ is regenerated from the header
-- with 'protectedParamsEncoded'.  NOTE(review): if that encoding differs
-- from the raw text that was originally signed, the output may not verify;
-- confirm this is intended.
instance HasParams a => ToJSON (Signature a) where
  toJSON (Signature _ h sig) =
    object (withProtected (withHeader ["signature" .= sig]))
    where
      withProtected pairs = case protectedParamsEncoded h of
        "" -> pairs
        encoded ->
          ("protected" .= String (T.decodeUtf8 (BSL.toStrict encoded))) : pairs
      withHeader pairs =
        maybe pairs (\o -> ("header" .= o) : pairs) (unprotectedParams h)
-- | Parsing and serialisation of 'JWSHeader' parameters.
instance HasParams JWSHeader where
  -- Parse from the optional protected (@hp@) and unprotected (@hu@) header
  -- objects.
  --
  -- The applicative chain MUST follow the field order of the 'JWSHeader'
  -- constructor (…, x5u, x5c, x5t, x5t#S256, …), because the constructor is
  -- applied positionally.  If the parser is polymorphic in its result type,
  -- a misordered chain can still type-check but decodes fields from the
  -- wrong keys.  The previous order x5t/x5t#S256/x5c did exactly that.
  parseParamsFor proxy hp hu = JWSHeader
    <$> headerRequired "alg" hp hu
    <*> headerOptional "jku" hp hu
    <*> headerOptional "jwk" hp hu
    <*> headerOptional "kid" hp hu
    <*> headerOptional "x5u" hp hu
    <*> headerOptional "x5c" hp hu
    <*> headerOptional "x5t" hp hu
    <*> headerOptional "x5t#S256" hp hu
    <*> headerOptional "typ" hp hu
    <*> headerOptional "cty" hp hu
    -- @crit@ is read via 'headerOptionalProtected' and then checked by
    -- 'parseCrit' against the names this module handles, the extensions
    -- declared by @proxy@, and the merged contents of both header objects.
    <*> (headerOptionalProtected "crit" hp hu
          >>= parseCrit jwsCritInvalidNames (extensions proxy)
                (fromMaybe mempty hp <> fromMaybe mempty hu))

  -- Emit every present parameter together with its 'Protection'.
  -- @crit@ has no 'HeaderParam' wrapper and is always 'Protected'.
  params (JWSHeader alg jku jwk kid x5u x5c x5t x5tS256 typ cty crit) =
    catMaybes
      [ Just (protection alg, "alg" .= param alg)
      , fmap (\p -> (protection p, "jku" .= param p)) jku
      , fmap (\p -> (protection p, "jwk" .= param p)) jwk
      , fmap (\p -> (protection p, "kid" .= param p)) kid
      , fmap (\p -> (protection p, "x5u" .= param p)) x5u
      , fmap (\p -> (protection p, "x5c" .= param p)) x5c
      , fmap (\p -> (protection p, "x5t" .= param p)) x5t
      , fmap (\p -> (protection p, "x5t#S256" .= param p)) x5tS256
      , fmap (\p -> (protection p, "typ" .= param p)) typ
      , fmap (\p -> (protection p, "cty" .= param p)) cty
      , fmap (\p -> (Protected, "crit" .= p)) crit
      ]
-- | A JWS: the payload followed by its signatures, each carrying a header of
-- type @a@.  'signJWS' adds new signatures at the front of the list.
data JWS a
  = JWS Types.Base64Octets [Signature a]
  deriving (Eq, Show)
-- | Accept either the general JSON serialisation (@payload@ plus a
-- @signatures@ array) or, if that fails, the flattened serialisation (the
-- members of one signature sit next to @payload@).  A flattened object must
-- not also contain @signatures@.
instance HasParams a => FromJSON (JWS a) where
  parseJSON v = general v <|> flattened v
    where
      general = withObject "JWS JSON serialization" $ \o ->
        JWS <$> o .: "payload" <*> o .: "signatures"
      -- The whole object is re-parsed as a single 'Signature'.
      flattened = withObject "Flattened JWS JSON serialization" $ \o ->
        if M.member "signatures" o
          then fail "\"signatures\" member MUST NOT be present"
          else do
            payload <- o .: "payload"
            sig <- parseJSON v
            pure (JWS payload [sig])
-- | Always emit the general JSON serialisation (a @signatures@ array), even
-- when there is exactly one signature.
instance HasParams a => ToJSON (JWS a) where
  toJSON (JWS payload sigs) =
    object
      [ "payload" .= payload
      , "signatures" .= sigs
      ]
-- | An unsigned JWS over the given payload bytes; add signatures with
-- 'signJWS'.
newJWS :: BS.ByteString -> JWS a
newJWS msg = JWS payload []
  where
    payload = Types.Base64Octets msg
-- | The payload bytes held in the JWS (the same bytes given to 'newJWS'),
-- as a lazy 'BSL.ByteString'.
jwsPayload :: JWS a -> BSL.ByteString
jwsPayload (JWS payload _) = case payload of
  Types.Base64Octets bytes -> BSL.fromStrict bytes
-- | The bytes covered by a signature: the protected-header part, a @.@,
-- then the payload as rendered by 'toBytes'.  The payload rendering is
-- presumably its base64url encoding; NOTE(review): confirm against the
-- 'Byteable' instance of 'Types.Base64Octets'.
--
-- A 'Left' supplies the protected-header text verbatim (for example as it
-- was received).  A 'Right' header is encoded with 'protectedParamsEncoded'.
signingInput
  :: HasParams a
  => Either T.Text a
  -> Types.Base64Octets
  -> BS.ByteString
signingInput h p = protectedPart <> "." <> toBytes p
  where
    protectedPart =
      either T.encodeUtf8 (BSL.toStrict . protectedParamsEncoded) h
-- | Compact serialisation as three parts: the signing input (protected
-- header and payload) followed by the signature.
--
-- This is only possible for a JWS with exactly one signature whose header
-- has no unprotected parameters.  Anything else is reported through
-- '_CompactEncodeError'.  Retained raw protected text is reused so the
-- output matches the bytes that were signed.
instance HasParams a => ToCompact (JWS a) where
  toCompact (JWS p sigs) = case sigs of
    [Signature raw h sig]
      | Nothing <- unprotectedParams h ->
          pure
            [ BSL.fromStrict (signingInput (maybe (Right h) Left raw) p)
            , BSL.fromStrict (toBytes sig)
            ]
      | otherwise ->
          encodeErr "cannot encode a compact JWS with unprotected headers"
    _ ->
      encodeErr
        ("cannot compact serialize JWS with " ++ show (length sigs) ++ " sigs")
    where
      encodeErr msg = throwError (review _CompactEncodeError msg)
-- | Parse a compact JWS from its three parts (protected header, payload,
-- signature).
--
-- Each part must be valid UTF-8.  The parts are placed in a flattened JSON
-- object and decoded with the 'FromJSON' instance.  Any failure, including
-- a wrong number of parts, is reported through '_CompactDecodeError'.
instance HasParams a => FromCompact (JWS a) where
  fromCompact parts = case parts of
    [h, p, s] -> do
      protectedV <- asString h
      payloadV <- asString p
      signatureV <- asString s
      let flattened = object
            [ ("payload", payloadV)
            , ("protected", protectedV)
            , ("signature", signatureV)
            ]
      case fromJSON flattened of
        Success jws -> pure jws
        Error e -> throwError (decodeErr e)
    _ -> throwError . decodeErr $
      "expected 3 parts, got " ++ show (length parts)
    where
      decodeErr = review _CompactDecodeError
      asString part = case T.decodeUtf8' (BSL.toStrict part) of
        Left err -> throwError (decodeErr (show err))
        Right txt -> pure (String txt)
-- | Sign the JWS payload using the given header and key.  The new
-- 'Signature' is added at the FRONT of the existing signature list.
--
-- The algorithm comes from the header's @alg@ parameter.  The signing input
-- is built from the header itself, so the new signature keeps no raw
-- protected text ('_protectedRaw' is 'Nothing').
signJWS
  :: (HasJWSHeader a, HasParams a, MonadRandom m, AsError e, MonadError e m)
  => JWS a      -- ^ JWS to sign
  -> a          -- ^ header for the new signature
  -> JWK        -- ^ signing key
  -> m (JWS a)  -- ^ JWS with the new signature prepended
signJWS (JWS p sigs) h k =
  addSignature <$> sign alg (k ^. jwkMaterial) (signingInput (Right h) p)
  where
    alg = param (view jwsHeaderAlg h)
    addSignature rawSig =
      JWS p (Signature Nothing h (Types.Base64Octets rawSig) : sigs)
-- | How 'verifyJWS' combines the outcomes of the signatures it checks
-- (only signatures whose @alg@ is acceptable are checked).
data ValidationPolicy
  = AnyValidated
    -- ^ succeed if at least one checked signature is valid
  | AllValidated
    -- ^ succeed if at least one signature was checked and every checked
    -- signature is valid
  deriving (Eq)
-- | Settings for 'verifyJWS': which algorithms to check, and how to combine
-- the results.
data ValidationSettings = ValidationSettings
  { _validationSettingsAlgorithms :: S.Set JWA.JWS.Alg
    -- ^ signatures whose @alg@ is not in this set are skipped
  , _validationSettingsValidationPolicy :: ValidationPolicy
    -- ^ how the checked signatures' results are combined
  }

-- Generates 'HasValidationSettings' and the field lenses
-- 'validationSettingsAlgorithms' / 'validationSettingsValidationPolicy'.
makeClassy ''ValidationSettings
-- | Types from which the set of acceptable algorithms can be read.
class HasAlgorithms s where
  algorithms :: Lens' s (S.Set JWA.JWS.Alg)

-- | Types from which a 'ValidationPolicy' can be read.
class HasValidationPolicy s where
  validationPolicy :: Lens' s ValidationPolicy

-- Anything carrying 'ValidationSettings' gets both lenses by delegating to
-- the settings' own field lenses.  The instance heads are bare type
-- variables, so they match every type; any other instance of these classes
-- would overlap with them.
instance HasValidationSettings a => HasAlgorithms a where
  algorithms = validationSettingsAlgorithms

instance HasValidationSettings a => HasValidationPolicy a where
  validationPolicy = validationSettingsValidationPolicy
-- | Default settings:
--
-- * Accept the HS, RS, ES and PS algorithm families in their 256/384/512
--   variants.
-- * Use 'AllValidated', so every checked signature must verify.
defaultValidationSettings :: ValidationSettings
defaultValidationSettings =
  ValidationSettings
    { _validationSettingsAlgorithms = S.fromList (concat [hs, rs, es, ps])
    , _validationSettingsValidationPolicy = AllValidated
    }
  where
    hs = [JWA.JWS.HS256, JWA.JWS.HS384, JWA.JWS.HS512]
    rs = [JWA.JWS.RS256, JWA.JWS.RS384, JWA.JWS.RS512]
    es = [JWA.JWS.ES256, JWA.JWS.ES384, JWA.JWS.ES512]
    ps = [JWA.JWS.PS256, JWA.JWS.PS384, JWA.JWS.PS512]
-- | Verify the signatures of a 'JWS' with a single key.
--
-- Only signatures whose @alg@ is in the configured algorithm set are
-- checked; the others are skipped.  A checked signature counts as valid only
-- when 'verifySig' returns @Right True@, so verification errors count as
-- invalid.  The results are combined by the configured 'ValidationPolicy':
--
-- * 'AnyValidated': succeed if any checked signature is valid, otherwise
--   throw via '_JWSNoValidSignatures' (also when nothing was checked).
-- * 'AllValidated': throw via '_JWSNoSignatures' if nothing was checked,
--   and via '_JWSInvalidSignature' if any checked signature is invalid.
verifyJWS
  :: ( HasAlgorithms a, HasValidationPolicy a, AsError e, MonadError e m
     , HasJWSHeader h, HasParams h)
  => a       -- ^ validation settings
  -> JWK     -- ^ key used for every checked signature
  -> JWS h   -- ^ JWS to verify
  -> m ()
verifyJWS conf k (JWS p sigs) =
  applyPolicy policy $ map validate $ filter shouldValidateSig $ toList sigs
  where
    algs :: S.Set JWA.JWS.Alg
    algs = conf ^. algorithms

    policy :: ValidationPolicy
    policy = conf ^. validationPolicy

    -- 'S.member' is the logarithmic Set lookup.  Foldable 'elem' on a Set
    -- walks every element.
    shouldValidateSig sig =
      param (view (header . jwsHeaderAlg) sig) `S.member` algs

    validate sig = verifySig k p sig == Right True

    applyPolicy AnyValidated xs
      | or xs = pure ()
      | otherwise = throwError (review _JWSNoValidSignatures ())
    applyPolicy AllValidated [] = throwError (review _JWSNoSignatures ())
    applyPolicy AllValidated xs
      | and xs = pure ()
      | otherwise = throwError (review _JWSInvalidSignature ())
-- | Check a single signature over the payload with the given key.
--
-- The algorithm comes from the signature's @alg@ header parameter.  The
-- signing input uses the retained raw protected text when there is one,
-- otherwise the header's encoded protected parameters.  The result is
-- whatever 'verify' returns.
verifySig
  :: (HasJWSHeader a, HasParams a)
  => JWK
  -> Types.Base64Octets
  -> Signature a
  -> Either Error Bool
verifySig k m sig = case sig of
  Signature raw h (Types.Base64Octets s) ->
    let alg = param (view jwsHeaderAlg h)
        tbs = signingInput (maybe (Right h) Left raw) m
    in verify alg (view jwkMaterial k) tbs s