-- | JSON Web Token (JWT) support: the claims set and its JSON
-- encoding, configurable claims validation, and creation and
-- verification of JWS-protected tokens.
module Crypto.JWT
  (
  -- * Tokens
    JWT(..)
  , JWTCrypto(..)
  -- * Errors
  , JWTError(..)
  , AsJWTError(..)
  -- * Validation settings
  , JWTValidationSettings
  , defaultJWTValidationSettings
  , HasJWTValidationSettings(..)
  , HasAllowedSkew(..)
  , HasAudiencePredicate(..)
  , HasIssuerPredicate(..)
  , HasCheckIssuedAt(..)
  -- * Creating and validating JWS-protected tokens
  , createJWSJWT
  , validateJWSJWT
  -- * Claims set
  , ClaimsSet(..)
  , claimAud
  , claimExp
  , claimIat
  , claimIss
  , claimJti
  , claimNbf
  , claimSub
  , unregisteredClaims
  , addClaim
  , emptyClaimsSet
  , validateClaimsSet
  -- * Claim value types
  , Audience(..)
  , StringOrURI
  -- Note: this 'fromString' is the 'T.Text'-based constructor defined in
  -- this module, not the 'Data.String.IsString' class method.
  , fromString
  , fromURI
  , getString
  , getURI
  , NumericDate(..)
  ) where
import Control.Applicative
import Control.Monad
import Control.Monad.Time (MonadTime(..))
#if ! MIN_VERSION_monad_time(0,2,0)
import Control.Monad.Time.Instances ()
#endif
import Data.Maybe
import qualified Data.String
import Control.Lens (
  makeClassy, makeClassyPrisms, makeLenses, makePrisms,
  Lens', _Just, over, preview, review, view)
import Control.Monad.Except (MonadError(throwError))
import Data.Aeson
import qualified Data.ByteString.Lazy as BSL
import qualified Data.HashMap.Strict as M
import qualified Data.Text as T
import Data.Time (NominalDiffTime, UTCTime, addUTCTime)
import Data.Time.Clock.POSIX (posixSecondsToUTCTime, utcTimeToPOSIXSeconds)
import Network.URI (parseURI, uriToString)
import Crypto.JOSE
import Crypto.JOSE.Types
-- | Errors that can be raised while processing a JWT.
data JWTError
  = JWSError Error
  -- ^ Wraps a JOSE 'Error' (this is the target of the 'AsError' instance).
  | JWTExpired
  -- ^ Raised when the @exp@ claim check fails.
  | JWTNotYetValid
  -- ^ Raised when the @nbf@ claim check fails.
  | JWTNotInIssuer
  -- ^ Raised when the issuer predicate rejects the @iss@ claim.
  | JWTNotInAudience
  -- ^ Raised when no member of the @aud@ claim satisfies the
  -- audience predicate.
  | JWTIssuedAtFuture
  -- ^ Raised when issued-at checking is enabled and the @iat@ claim
  -- lies in the future.
  deriving (Eq, Show)
-- Generates the 'AsJWTError' class and one prism per constructor
-- (e.g. '_JWTExpired', '_JWSError').
makeClassyPrisms ''JWTError
-- | JOSE errors are embedded in 'JWTError' via the 'JWSError'
-- constructor, so code that is polymorphic in 'AsError' can run with
-- 'JWTError' as its error type.
instance AsError JWTError where
  _Error = _JWSError
-- | A claim value that is either free-form text or a URI.
--
-- The constructors are not exported by this module; use 'fromString'
-- or 'fromURI' to build values and 'getString' / 'getURI' to inspect
-- them.
data StringOrURI
  = Arbitrary T.Text  -- ^ Free-form text.
  | OrURI URI         -- ^ A parsed URI.
  deriving (Eq, Show)
-- | Build a 'StringOrURI' from a string literal (for use with
-- @OverloadedStrings@).
--
-- A literal that parses as an absolute URI becomes the URI
-- alternative; anything else is kept as arbitrary text.  This is the
-- same classification performed by the 'T.Text'-based 'fromString'
-- constructor in this module.  Without it, a URI-shaped literal would
-- be stored as arbitrary text and could never compare equal (via
-- 'Eq') to a colon-containing claim decoded from JSON, which is
-- always stored as a URI.
instance Data.String.IsString StringOrURI where
  fromString s = maybe (Arbitrary (T.pack s)) OrURI (parseURI s)
-- | Classify a piece of text: if it parses as a URI the result is the
-- URI alternative, otherwise the text is kept verbatim.
fromString :: T.Text -> StringOrURI
fromString s =
  case parseURI (T.unpack s) of
    Just uri -> OrURI uri
    Nothing  -> Arbitrary s
-- | Wrap an already-parsed 'URI' as a 'StringOrURI'.
fromURI :: URI -> StringOrURI
fromURI uri = OrURI uri
-- | Extract the text of a 'StringOrURI', or 'Nothing' if it holds a
-- URI.
getString :: StringOrURI -> Maybe T.Text
getString v = case v of
  Arbitrary txt -> Just txt
  OrURI _       -> Nothing
-- | Extract the URI of a 'StringOrURI', or 'Nothing' if it holds
-- arbitrary text.
getURI :: StringOrURI -> Maybe URI
getURI v = case v of
  OrURI uri   -> Just uri
  Arbitrary _ -> Nothing
-- | Decode from a JSON string.  A string containing a @':'@ must
-- decode as a URI (decoding fails otherwise); a string without a
-- colon is kept as arbitrary text.
instance FromJSON StringOrURI where
  parseJSON = withText "StringOrURI" classify
    where
      classify txt
        | T.any (== ':') txt = OrURI <$> parseJSON (String txt)
        | otherwise          = pure (Arbitrary txt)
-- | Encode as a JSON string.
--
-- URIs are rendered with @'uriToString' id@ rather than 'show': the
-- 'Show' instance for 'URI' deliberately obscures any password in the
-- userinfo component, which would make the encoded claim differ from
-- the value held in memory and break round-tripping.
instance ToJSON StringOrURI where
  toJSON (Arbitrary s) = toJSON s
  toJSON (OrURI uri)   = toJSON (uriToString id uri "")
-- | A point in time carried by a claim.  It is a 'UTCTime' whose JSON
-- form is a number of POSIX seconds (see the 'FromJSON' / 'ToJSON'
-- instances).
newtype NumericDate = NumericDate UTCTime deriving (Eq, Ord, Show)
-- Generates '_NumericDate', used by the claim validators to reach the
-- wrapped 'UTCTime'.
makePrisms ''NumericDate
-- | Decode from a JSON number, interpreted as seconds since the POSIX
-- epoch.
--
-- NOTE(review): the number is converted through 'toRational'; a
-- number with a very large exponent may be expensive to convert --
-- confirm whether inputs are bounded before this point.
instance FromJSON NumericDate where
  parseJSON = withScientific "NumericDate" $ \secs ->
    pure (NumericDate (posixSecondsToUTCTime (fromRational (toRational secs))))
-- | Encode as a JSON number of seconds since the POSIX epoch.
instance ToJSON NumericDate where
  toJSON (NumericDate t) = Number (fromRational (toRational posixSecs))
    where
      posixSecs = utcTimeToPOSIXSeconds t
-- | The value of the @aud@ claim: a list of 'StringOrURI'.  In JSON it
-- may appear either as an array or as a single bare value (see the
-- 'FromJSON' / 'ToJSON' instances).
newtype Audience = Audience [StringOrURI] deriving (Eq, Show)
-- Generates '_Audience', used by the audience validator to reach the
-- wrapped list.
makePrisms ''Audience
-- | Decode either a JSON array of values or, failing that, a single
-- value which is wrapped in a one-element list.
instance FromJSON Audience where
  parseJSON v = Audience <$> (asList <|> asSingleton)
    where
      asList      = parseJSON v
      asSingleton = (: []) <$> parseJSON v
-- | Encode a one-element audience as a bare value; any other length
-- (including empty) is encoded as an array.
instance ToJSON Audience where
  toJSON (Audience auds) =
    case auds of
      [single] -> toJSON single
      _        -> toJSON auds
-- | The JWT claims set: seven optional registered claims plus a map of
-- all other claims.
data ClaimsSet = ClaimsSet
  { _claimIss :: Maybe StringOrURI
  -- ^ The @iss@ claim; checked against the issuer predicate.
  , _claimSub :: Maybe StringOrURI
  -- ^ The @sub@ claim.
  , _claimAud :: Maybe Audience
  -- ^ The @aud@ claim; checked against the audience predicate.
  , _claimExp :: Maybe NumericDate
  -- ^ The @exp@ claim; the token is rejected once this time (plus
  -- allowed skew) has been reached.
  , _claimNbf :: Maybe NumericDate
  -- ^ The @nbf@ claim; the token is rejected before this time (minus
  -- allowed skew).
  , _claimIat :: Maybe NumericDate
  -- ^ The @iat@ claim; only checked when issued-at checking is enabled.
  , _claimJti :: Maybe T.Text
  -- ^ The @jti@ claim.
  , _unregisteredClaims :: M.HashMap T.Text Value
  -- ^ Every other claim, keyed by name.  Entries whose key is one of
  -- the seven registered names are dropped when encoding to JSON.
  }
  deriving (Eq, Show)
-- Generates the lenses 'claimIss', 'claimSub', 'claimAud', 'claimExp',
-- 'claimNbf', 'claimIat', 'claimJti' and 'unregisteredClaims'.
makeLenses ''ClaimsSet
-- | A claims set with every registered claim absent and no
-- unregistered claims.
emptyClaimsSet :: ClaimsSet
emptyClaimsSet = ClaimsSet
  { _claimIss = Nothing
  , _claimSub = Nothing
  , _claimAud = Nothing
  , _claimExp = Nothing
  , _claimNbf = Nothing
  , _claimIat = Nothing
  , _claimJti = Nothing
  , _unregisteredClaims = M.empty
  }
-- | Insert (or overwrite) an entry in the unregistered-claims map.
--
-- The key is not inspected here; note that the JSON encoder for
-- 'ClaimsSet' drops unregistered entries whose key is a registered
-- claim name.
addClaim :: T.Text -> Value -> ClaimsSet -> ClaimsSet
addClaim name value claims =
  over unregisteredClaims (M.insert name value) claims
-- | Keep only the entries whose key is not one of the seven
-- registered claim names.
filterUnregistered :: M.HashMap T.Text Value -> M.HashMap T.Text Value
filterUnregistered = M.filterWithKey isUnregistered
  where
    isUnregistered key _ = key `notElem` registeredNames
    registeredNames = ["iss", "sub", "aud", "exp", "nbf", "iat", "jti"]
-- | Decode a claims set from a JSON object.  Each registered claim is
-- optional; all remaining members of the object are stored as
-- unregistered claims.
instance FromJSON ClaimsSet where
  parseJSON = withObject "JWT Claims Set" $ \obj -> do
    iss  <- obj .:? "iss"
    sub  <- obj .:? "sub"
    aud  <- obj .:? "aud"
    exp' <- obj .:? "exp"
    nbf  <- obj .:? "nbf"
    iat  <- obj .:? "iat"
    jti  <- obj .:? "jti"
    pure (ClaimsSet iss sub aud exp' nbf iat jti (filterUnregistered obj))
-- | Encode a claims set as a JSON object.  Registered claims that are
-- present are emitted under their standard names, followed by the
-- unregistered claims (minus any whose key is a registered name).
instance ToJSON ClaimsSet where
  toJSON claims = object (registeredPairs ++ unregisteredPairs)
    where
      registeredPairs = catMaybes
        [ ("iss" .=) <$> _claimIss claims
        , ("sub" .=) <$> _claimSub claims
        , ("aud" .=) <$> _claimAud claims
        , ("exp" .=) <$> _claimExp claims
        , ("nbf" .=) <$> _claimNbf claims
        , ("iat" .=) <$> _claimIat claims
        , ("jti" .=) <$> _claimJti claims
        ]
      unregisteredPairs =
        M.toList (filterUnregistered (_unregisteredClaims claims))
-- | Settings that control JWT validation.
data JWTValidationSettings = JWTValidationSettings
  { _jwtValidationSettingsValidationSettings :: ValidationSettings
  -- ^ The underlying JWS 'ValidationSettings' (exposed through the
  -- 'HasValidationSettings' instance).
  , _jwtValidationSettingsAllowedSkew :: NominalDiffTime
  -- ^ Clock-skew tolerance for the @exp@, @nbf@ and @iat@ checks.  The
  -- validators use its absolute value.
  , _jwtValidationSettingsCheckIssuedAt :: Bool
  -- ^ Whether to reject tokens whose @iat@ claim is in the future.
  , _jwtValidationSettingsAudiencePredicate :: StringOrURI -> Bool
  -- ^ Accepts an audience value; at least one member of the @aud@
  -- claim must satisfy it when the claim is present.
  , _jwtValidationSettingsIssuerPredicate :: StringOrURI -> Bool
  -- ^ Accepts an issuer value; the @iss@ claim must satisfy it when
  -- present.
  }
-- Generates the 'HasJWTValidationSettings' class with one lens per
-- field (e.g. 'jwtValidationSettingsAllowedSkew').
makeClassy ''JWTValidationSettings
-- | Exposes the embedded JWS 'ValidationSettings' so that a
-- 'JWTValidationSettings' value can be used wherever
-- 'HasValidationSettings' is required.
instance HasValidationSettings JWTValidationSettings where
  validationSettings = jwtValidationSettingsValidationSettings
-- | Configuration types that carry a clock-skew tolerance, as used by
-- the time-based claim checks.
class HasAllowedSkew s where
  allowedSkew :: Lens' s NominalDiffTime
-- | Configuration types that carry a predicate deciding which
-- audience values are acceptable.
class HasAudiencePredicate s where
  audiencePredicate :: Lens' s (StringOrURI -> Bool)
-- | Configuration types that carry a predicate deciding which issuer
-- values are acceptable.
class HasIssuerPredicate s where
  issuerPredicate :: Lens' s (StringOrURI -> Bool)
-- | Configuration types that carry the flag enabling the @iat@
-- (issued-at) check.
class HasCheckIssuedAt s where
  checkIssuedAt :: Lens' s Bool
-- | Any type that contains 'JWTValidationSettings' provides the
-- allowed skew stored there.
--
-- NOTE(review): the instance head matches every type, so a
-- hand-written 'HasAllowedSkew' instance for another type would
-- overlap with this one.
instance HasJWTValidationSettings a => HasAllowedSkew a where
  allowedSkew = jwtValidationSettingsAllowedSkew
-- | Any type that contains 'JWTValidationSettings' provides the
-- audience predicate stored there.
instance HasJWTValidationSettings a => HasAudiencePredicate a where
  audiencePredicate = jwtValidationSettingsAudiencePredicate
-- | Any type that contains 'JWTValidationSettings' provides the
-- issuer predicate stored there.
instance HasJWTValidationSettings a => HasIssuerPredicate a where
  issuerPredicate = jwtValidationSettingsIssuerPredicate
-- | Any type that contains 'JWTValidationSettings' provides the
-- issued-at flag stored there.
instance HasJWTValidationSettings a => HasCheckIssuedAt a where
  checkIssuedAt = jwtValidationSettingsCheckIssuedAt
-- | Default validation settings:
--
-- * the default JWS validation settings;
-- * zero allowed clock skew;
-- * the @iat@ claim is not checked;
-- * the audience predicate rejects every value, so any token that
--   carries an @aud@ claim fails validation until the predicate is
--   overridden;
-- * the issuer predicate accepts every value.
defaultJWTValidationSettings :: JWTValidationSettings
defaultJWTValidationSettings = JWTValidationSettings
  { _jwtValidationSettingsValidationSettings = defaultValidationSettings
  , _jwtValidationSettingsAllowedSkew = 0
  , _jwtValidationSettingsCheckIssuedAt = False
  , _jwtValidationSettingsAudiencePredicate = const False
  , _jwtValidationSettingsIssuerPredicate = const True
  }
-- | Validate the registered claims of a claims set against the given
-- settings.
--
-- The checks run in this order, stopping at the first failure:
-- @exp@, @iat@, @nbf@, @iss@, @aud@.  A claim that is absent is not
-- checked.
validateClaimsSet
  ::
    ( MonadTime m, HasAllowedSkew a, HasAudiencePredicate a
    , HasIssuerPredicate a
    , HasCheckIssuedAt a
    , AsJWTError e, MonadError e m
    )
  => a
  -> ClaimsSet
  -> m ()
validateClaimsSet conf claims = do
  validateExpClaim conf claims
  validateIatClaim conf claims
  validateNbfClaim conf claims
  validateIssClaim conf claims
  validateAudClaim conf claims
-- | Check the @exp@ claim, if present.
--
-- The token is accepted while the current time is strictly before
-- the expiry time plus the absolute value of the allowed skew;
-- otherwise 'JWTExpired' is thrown.
validateExpClaim
  :: (MonadTime m, HasAllowedSkew a, AsJWTError e, MonadError e m)
  => a
  -> ClaimsSet
  -> m ()
validateExpClaim conf claims =
  case _claimExp claims of
    Nothing -> pure ()
    Just expiry -> do
      now <- currentTime
      let tolerance = abs (view allowedSkew conf)
          deadline  = addUTCTime tolerance (view _NumericDate expiry)
      unless (now < deadline) $
        throwError (review _JWTExpired ())
-- | Check the @iat@ claim, if present.
--
-- When issued-at checking is enabled, 'JWTIssuedAtFuture' is thrown
-- if the issued-at time is later than the current time plus the
-- absolute value of the allowed skew.  The current time is read
-- whenever the claim is present, regardless of whether the check is
-- enabled.
validateIatClaim
  :: (MonadTime m, HasCheckIssuedAt a, HasAllowedSkew a, AsJWTError e, MonadError e m)
  => a
  -> ClaimsSet
  -> m ()
validateIatClaim conf claims =
  case _claimIat claims of
    Nothing -> pure ()
    Just issuedAt -> do
      now <- currentTime
      let latestAllowed = addUTCTime (abs (view allowedSkew conf)) now
          inFuture      = view _NumericDate issuedAt > latestAllowed
      when (view checkIssuedAt conf && inFuture) $
        throwError (review _JWTIssuedAtFuture ())
-- | Check the @nbf@ claim, if present.
--
-- The token is accepted once the current time has reached the
-- not-before time minus the absolute value of the allowed skew;
-- otherwise 'JWTNotYetValid' is thrown.
validateNbfClaim
  :: (MonadTime m, HasAllowedSkew a, AsJWTError e, MonadError e m)
  => a
  -> ClaimsSet
  -> m ()
validateNbfClaim conf claims =
  case _claimNbf claims of
    Nothing -> pure ()
    Just notBefore -> do
      now <- currentTime
      let tolerance = abs (view allowedSkew conf)
          earliest  = addUTCTime (negate tolerance) (view _NumericDate notBefore)
      unless (now >= earliest) $
        throwError (review _JWTNotYetValid ())
-- | Check the @aud@ claim, if present.
--
-- At least one audience value must satisfy the configured audience
-- predicate; otherwise (including when the audience list is empty)
-- 'JWTNotInAudience' is thrown.
validateAudClaim
  :: (HasAudiencePredicate s, AsJWTError e, MonadError e m)
  => s
  -> ClaimsSet
  -> m ()
validateAudClaim conf claims =
  case preview (claimAud . _Just . _Audience) claims of
    Nothing -> pure ()
    Just auds
      | any accepted auds -> pure ()
      | otherwise         -> throwError (review _JWTNotInAudience ())
  where
    accepted = view audiencePredicate conf
-- | Check the @iss@ claim, if present.
--
-- The issuer must satisfy the configured issuer predicate; otherwise
-- 'JWTNotInIssuer' is thrown.
validateIssClaim
  :: (HasIssuerPredicate s, AsJWTError e, MonadError e m)
  => s
  -> ClaimsSet
  -> m ()
validateIssClaim conf claims =
  case preview (claimIss . _Just) claims of
    Just iss
      | not (view issuerPredicate conf iss) ->
          throwError (review _JWTNotInIssuer ())
    _ -> pure ()
-- | The cryptographic container of a JWT.  The only form defined here
-- is a JWS.
data JWTCrypto
  = JWTJWS (JWS JWSHeader)  -- ^ A JWS-protected token.
  deriving (Eq, Show)
-- | Decode the compact serialisation as a JWS and wrap it.
instance FromCompact JWTCrypto where
  fromCompact parts = JWTJWS <$> fromCompact parts
-- | Produce the compact serialisation of the wrapped JWS.
instance ToCompact JWTCrypto where
  toCompact crypto =
    case crypto of
      JWTJWS jws -> toCompact jws
-- | A JWT: the cryptographic container together with the claims set.
--
-- When built by 'createJWSJWT' or decoded via 'FromCompact', the
-- claims set corresponds to the container's payload.  The record
-- constructor itself does not enforce that relationship.
data JWT = JWT
  { jwtCrypto :: JWTCrypto
  -- ^ The signed container.
  , jwtClaimsSet :: ClaimsSet
  -- ^ The claims carried by the token.
  } deriving (Eq, Show)
-- | Decode the compact serialisation into a 'JWTCrypto', then decode
-- its payload as a JSON claims set.  A payload that is not a valid
-- claims set is reported as a compact-decoding error.
instance FromCompact JWT where
  fromCompact = fromCompact >=> attachClaims
    where
      attachClaims crypto@(JWTJWS jws) =
        case eitherDecode (jwsPayload jws) of
          Left err     -> throwError (review _CompactDecodeError err)
          Right claims -> pure (JWT crypto claims)
-- | The compact serialisation of a JWT is that of its cryptographic
-- container; the separately stored claims set is not consulted.
instance ToCompact JWT where
  toCompact jwt = toCompact (jwtCrypto jwt)
-- | Validate a JWS-protected JWT: first verify the JWS with the given
-- key, then validate the claims set.  The first failure is thrown and
-- no further checks are run.
--
-- NOTE(review): the claims that are validated are those stored in the
-- 'JWT' record, not claims re-decoded from the verified payload.
-- Confirm that callers only obtain 'JWT' values from 'createJWSJWT'
-- or by decoding the compact form.
validateJWSJWT
  ::
    ( MonadTime m, HasAllowedSkew a, HasAudiencePredicate a
    , HasIssuerPredicate a
    , HasCheckIssuedAt a
    , HasValidationSettings a
    , AsError e, AsJWTError e, MonadError e m
    )
  => a
  -> JWK
  -> JWT
  -> m ()
validateJWSJWT conf k (JWT (JWTJWS jws) claims) =
  verifyJWS conf k jws >> validateClaimsSet conf claims
-- | Create a JWS-protected JWT.
--
-- The claims set is JSON-encoded and used as the payload of a JWS
-- with no existing signatures, which is then signed with the given
-- header and key.  The resulting 'JWT' pairs the signed JWS with the
-- claims set that was passed in.
createJWSJWT
  :: (MonadRandom m, MonadError e m, AsError e)
  => JWK
  -> JWSHeader
  -> ClaimsSet
  -> m JWT
createJWSJWT k h c = do
  signed <- signJWS unsigned h k
  pure (JWT (JWTJWS signed) c)
  where
    unsigned = JWS (Base64Octets (BSL.toStrict (encode c))) []