{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Crypto.JWT
(
signClaims
, SignedJWT
, defaultJWTValidationSettings
, verifyClaims
, verifyClaimsAt
, HasAllowedSkew(..)
, HasAudiencePredicate(..)
, HasIssuerPredicate(..)
, HasCheckIssuedAt(..)
, JWTValidationSettings
, HasJWTValidationSettings(..)
, ClaimsSet
, claimAud
, claimExp
, claimIat
, claimIss
, claimJti
, claimNbf
, claimSub
, unregisteredClaims
, addClaim
, emptyClaimsSet
, validateClaimsSet
, JWTError(..)
, AsJWTError(..)
, Audience(..)
, StringOrURI
, stringOrUri
, string
, uri
, NumericDate(..)
, module Crypto.JOSE
) where
import Control.Applicative
import Control.Monad
import Control.Monad.Time (MonadTime(..))
#if ! MIN_VERSION_monad_time(0,2,0)
import Control.Monad.Time.Instances ()
#endif
import Data.Foldable (traverse_)
import Data.Functor.Identity
import Data.Maybe
import qualified Data.String
import Control.Lens (
makeClassy, makeClassyPrisms, makePrisms,
Lens', _Just, over, preview, view,
Prism', prism', Cons, iso, AsEmpty)
import Control.Lens.Cons.Extras (recons)
import Control.Monad.Error.Lens (throwing, throwing_)
import Control.Monad.Except (MonadError)
import Control.Monad.Reader (ReaderT, ask, runReaderT)
import Data.Aeson
import qualified Data.HashMap.Strict as M
import qualified Data.Text as T
import Data.Time (NominalDiffTime, UTCTime, addUTCTime)
import Data.Time.Clock.POSIX (posixSecondsToUTCTime, utcTimeToPOSIXSeconds)
import Network.URI (parseURI, uriToString)
import Crypto.JOSE
import Crypto.JOSE.Types
-- | Errors raised while decoding or validating a JWT.
data JWTError
  = JWSError Error
  -- ^ An 'Error' from the underlying JOSE (JWS) layer
  | JWTClaimsSetDecodeError String
  -- ^ The payload could not be decoded as a JSON 'ClaimsSet'; carries the decoder error message
  | JWTExpired
  -- ^ The current time is not before the @exp@ claim (plus allowed skew)
  | JWTNotYetValid
  -- ^ The current time is before the @nbf@ claim (minus allowed skew)
  | JWTNotInIssuer
  -- ^ The @iss@ claim was rejected by the issuer predicate
  | JWTNotInAudience
  -- ^ No value of the @aud@ claim satisfied the audience predicate
  | JWTIssuedAtFuture
  -- ^ The @iat@ claim is later than the current time (plus allowed skew)
  deriving (Eq, Show)

-- Generates the AsJWTError class, with one prism per constructor
-- (for example _JWTExpired, used by the validators below).
makeClassyPrisms ''JWTError
-- | A 'JWTError' can carry a JOSE 'Error': the '_Error' prism is simply the
-- prism for the 'JWSError' constructor.
instance AsError JWTError where
  _Error = _JWSError
-- | A claim value that is either an arbitrary string or a URI (the
-- StringOrURI type of RFC 7519). The constructors are not exported; use
-- 'stringOrUri', 'string' and 'uri' to build and inspect values.
data StringOrURI = Arbitrary T.Text | OrURI URI deriving (Eq, Show)
-- | Build a 'StringOrURI' from a string literal (e.g. with OverloadedStrings).
--
-- A literal containing a colon must parse as a URI; any other literal is an
-- arbitrary string. This is still partial: an invalid literal calls 'error'.
-- Unlike the previous bare 'fromJust', the error message names the offending
-- input.
instance Data.String.IsString StringOrURI where
  fromString s =
    fromMaybe
      (error ("Crypto.JWT.fromString: invalid StringOrURI literal: " ++ show s))
      (preview stringOrUri s)
-- | Prism between a character sequence and a 'StringOrURI'.
--
-- Parsing: input containing a colon must parse as a URI ('parseURI'),
-- otherwise the prism does not match. Input without a colon becomes an
-- arbitrary string.
--
-- Rendering uses @uriToString id@ rather than 'show'. The 'Show' instance of
-- 'URI' masks any password in the user-info part (rendering @user:...\@host@).
-- Such a URI would therefore not round-trip through this prism.
stringOrUri :: (Cons s s Char Char, AsEmpty s) => Prism' s StringOrURI
stringOrUri = iso (view recons) (view recons) . prism' rev fwd
  where
    rev (Arbitrary s) = s
    rev (OrURI x) = T.pack (uriToString id x "")
    fwd s
      | T.any (== ':') s = OrURI <$> parseURI (T.unpack s)
      | otherwise = pure (Arbitrary s)
-- | Prism onto the arbitrary-string case of a 'StringOrURI'; it does not
-- match URI values.
string :: Prism' StringOrURI T.Text
string = prism' Arbitrary $ \v -> case v of
  Arbitrary txt -> Just txt
  OrURI _ -> Nothing
-- | Prism onto the URI case of a 'StringOrURI'; it does not match
-- arbitrary-string values.
uri :: Prism' StringOrURI URI
uri = prism' OrURI $ \v -> case v of
  OrURI u -> Just u
  Arbitrary _ -> Nothing
-- | Parse a JSON string through 'stringOrUri'. A string that contains a
-- colon but is not a valid URI is rejected.
instance FromJSON StringOrURI where
  parseJSON = withText "StringOrURI" $ \txt ->
    case preview stringOrUri txt of
      Just parsed -> pure parsed
      Nothing -> fail "failed to parse StringOrURI"
-- | Serialise a 'StringOrURI' as a JSON string.
--
-- URIs are rendered with @uriToString id@ instead of 'show'. The 'Show'
-- instance of 'URI' replaces a user-info password with @...@, which silently
-- corrupted the emitted claim value.
instance ToJSON StringOrURI where
  toJSON (Arbitrary s) = toJSON s
  toJSON (OrURI x) = toJSON (uriToString id x "")
-- | A point in time as used by the @exp@, @nbf@ and @iat@ claims. Its JSON
-- form is a number of seconds since the POSIX epoch.
newtype NumericDate = NumericDate UTCTime deriving (Eq, Ord, Show)

-- Generates _NumericDate, which the validators use with 'view' to get at the
-- wrapped UTCTime.
makePrisms ''NumericDate
-- | Parse a JSON number of seconds since the POSIX epoch (fractions allowed).
--
-- NOTE(review): 'toRational' on a Scientific materialises a power of ten of
-- the exponent's size. An untrusted value with a huge positive or negative
-- exponent (e.g. 1e1000000000) can therefore exhaust memory. Consider a
-- bounded conversion (e.g. toBoundedRealFloat from Data.Scientific) if
-- tokens come from untrusted sources.
instance FromJSON NumericDate where
  parseJSON = withScientific "NumericDate" $ \secs ->
    let posix = fromRational (toRational secs)
    in pure (NumericDate (posixSecondsToUTCTime posix))
-- | Serialise as a JSON number of seconds since the POSIX epoch.
instance ToJSON NumericDate where
  toJSON (NumericDate t) =
    let posix = utcTimeToPOSIXSeconds t
    in Number (fromRational (toRational posix))
-- | The value of the @aud@ claim: the list of intended recipients.
newtype Audience = Audience [StringOrURI] deriving (Eq, Show)

-- Generates _Audience, used by the audience validator to reach the list.
makePrisms ''Audience
-- | Accept either a JSON array of 'StringOrURI' values or a single
-- 'StringOrURI', which is wrapped in a one-element list.
instance FromJSON Audience where
  parseJSON v = fmap Audience (asList <|> asSingle)
    where
      asList = parseJSON v
      asSingle = (: []) <$> parseJSON v
-- | A single-element audience is written as a bare value; any other list
-- (including an empty one) is written as a JSON array.
instance ToJSON Audience where
  toJSON (Audience recipients) = case recipients of
    [only] -> toJSON only
    _ -> toJSON recipients
-- | The claims carried by a JWT: seven optional registered claims plus any
-- other claims. The JSON key of each registered field is noted below.
data ClaimsSet = ClaimsSet
  { _claimIss :: Maybe StringOrURI
  -- ^ @iss@ (issuer)
  , _claimSub :: Maybe StringOrURI
  -- ^ @sub@ (subject)
  , _claimAud :: Maybe Audience
  -- ^ @aud@ (audience)
  , _claimExp :: Maybe NumericDate
  -- ^ @exp@ (expiration time)
  , _claimNbf :: Maybe NumericDate
  -- ^ @nbf@ (not before)
  , _claimIat :: Maybe NumericDate
  -- ^ @iat@ (issued at)
  , _claimJti :: Maybe T.Text
  -- ^ @jti@ (JWT ID)
  , _unregisteredClaims :: M.HashMap T.Text Value
  -- ^ All other claims keyed by name. Registered claim names are filtered
  -- out of this map by both the JSON parser and the JSON serialiser.
  }
  deriving (Eq, Show)
-- | Lens onto the optional @iss@ (issuer) claim.
claimIss :: Lens' ClaimsSet (Maybe StringOrURI)
claimIss f cs@ClaimsSet{} = (\v -> cs { _claimIss = v }) <$> f (_claimIss cs)
-- | Lens onto the optional @sub@ (subject) claim.
claimSub :: Lens' ClaimsSet (Maybe StringOrURI)
claimSub f cs@ClaimsSet{} = (\v -> cs { _claimSub = v }) <$> f (_claimSub cs)
-- | Lens onto the optional @aud@ (audience) claim.
claimAud :: Lens' ClaimsSet (Maybe Audience)
claimAud f cs@ClaimsSet{} = (\v -> cs { _claimAud = v }) <$> f (_claimAud cs)
-- | Lens onto the optional @exp@ (expiration time) claim.
claimExp :: Lens' ClaimsSet (Maybe NumericDate)
claimExp f cs@ClaimsSet{} = (\v -> cs { _claimExp = v }) <$> f (_claimExp cs)
-- | Lens onto the optional @nbf@ (not before) claim.
claimNbf :: Lens' ClaimsSet (Maybe NumericDate)
claimNbf f cs@ClaimsSet{} = (\v -> cs { _claimNbf = v }) <$> f (_claimNbf cs)
-- | Lens onto the optional @iat@ (issued at) claim.
claimIat :: Lens' ClaimsSet (Maybe NumericDate)
claimIat f cs@ClaimsSet{} = (\v -> cs { _claimIat = v }) <$> f (_claimIat cs)
-- | Lens onto the optional @jti@ (JWT ID) claim.
claimJti :: Lens' ClaimsSet (Maybe T.Text)
claimJti f cs@ClaimsSet{} = (\v -> cs { _claimJti = v }) <$> f (_claimJti cs)
-- | Lens onto the map of claims that are not registered claims.
unregisteredClaims :: Lens' ClaimsSet (M.HashMap T.Text Value)
unregisteredClaims f cs@ClaimsSet{} =
  (\v -> cs { _unregisteredClaims = v }) <$> f (_unregisteredClaims cs)
-- | A claims set with no registered claims and no other claims.
emptyClaimsSet :: ClaimsSet
emptyClaimsSet = ClaimsSet
  { _claimIss = Nothing
  , _claimSub = Nothing
  , _claimAud = Nothing
  , _claimExp = Nothing
  , _claimNbf = Nothing
  , _claimIat = Nothing
  , _claimJti = Nothing
  , _unregisteredClaims = M.empty
  }
-- | Insert a claim into the unregistered-claims map, replacing any existing
-- value under the same key.
--
-- Registered claim names (iss, sub, aud, exp, nbf, iat, jti) are accepted
-- here but dropped when the set is serialised to JSON. Set those through
-- their dedicated lenses (e.g. 'claimIss') instead.
addClaim :: T.Text -> Value -> ClaimsSet -> ClaimsSet
addClaim key val cs =
  cs { _unregisteredClaims = M.insert key val (_unregisteredClaims cs) }
-- | Remove every entry whose key is one of the seven registered claim names.
filterUnregistered :: M.HashMap T.Text Value -> M.HashMap T.Text Value
filterUnregistered claims = M.filterWithKey keep claims
  where
    keep key _ = not (key `elem` registeredNames)
    registeredNames = ["iss", "sub", "aud", "exp", "nbf", "iat", "jti"]
-- | Parse a JSON object. Each registered claim is optional, and every other
-- key is kept in the unregistered-claims map.
instance FromJSON ClaimsSet where
  parseJSON = withObject "JWT Claims Set" $ \o -> do
    iss <- o .:? "iss"
    sub <- o .:? "sub"
    aud <- o .:? "aud"
    expiry <- o .:? "exp"
    nbf <- o .:? "nbf"
    iat <- o .:? "iat"
    jti <- o .:? "jti"
    pure (ClaimsSet iss sub aud expiry nbf iat jti (filterUnregistered o))
-- | Serialise as a JSON object. Absent registered claims are omitted, and
-- registered names stored in the unregistered map are dropped so they
-- cannot duplicate the typed fields.
instance ToJSON ClaimsSet where
  toJSON (ClaimsSet iss sub aud expiry nbf iat jti extra) =
    object (catMaybes registered ++ M.toList (filterUnregistered extra))
    where
      registered =
        [ ("iss" .=) <$> iss
        , ("sub" .=) <$> sub
        , ("aud" .=) <$> aud
        , ("exp" .=) <$> expiry
        , ("nbf" .=) <$> nbf
        , ("iat" .=) <$> iat
        , ("jti" .=) <$> jti
        ]
-- | Settings for validating a JWT: the JWS-level 'ValidationSettings' plus
-- the JWT claim checks.
data JWTValidationSettings = JWTValidationSettings
  { _jwtValidationSettingsValidationSettings :: ValidationSettings
  -- ^ Settings passed on to the JWS verification layer
  , _jwtValidationSettingsAllowedSkew :: NominalDiffTime
  -- ^ Clock skew tolerated by the @exp@, @nbf@ and @iat@ checks (its absolute value is used)
  , _jwtValidationSettingsCheckIssuedAt :: Bool
  -- ^ Whether to reject an @iat@ claim that lies in the future
  , _jwtValidationSettingsAudiencePredicate :: StringOrURI -> Bool
  -- ^ Accepts the @aud@ values this recipient identifies with
  , _jwtValidationSettingsIssuerPredicate :: StringOrURI -> Bool
  -- ^ Accepts the @iss@ values that are trusted
  }

-- Generates the HasJWTValidationSettings class with a lens per field
-- (e.g. jwtValidationSettingsAllowedSkew).
makeClassy ''JWTValidationSettings
-- Any type with JWT validation settings also provides the JWS-level
-- ValidationSettings by focusing on the embedded field.
--
-- NOTE(review): the instance head is a bare type variable, so it overlaps
-- with any other HasValidationSettings instance. It needs
-- UndecidableInstances and is an orphan (hence -fno-warn-orphans at the top).
instance HasJWTValidationSettings a => HasValidationSettings a where
  validationSettings = jwtValidationSettingsValidationSettings
-- | Types that carry the clock skew tolerated by the time-based claim checks.
class HasAllowedSkew s where
  allowedSkew :: Lens' s NominalDiffTime

-- | Types that carry a predicate deciding which @aud@ values are accepted.
class HasAudiencePredicate s where
  audiencePredicate :: Lens' s (StringOrURI -> Bool)

-- | Types that carry a predicate deciding which @iss@ values are accepted.
class HasIssuerPredicate s where
  issuerPredicate :: Lens' s (StringOrURI -> Bool)

-- | Types that carry the flag enabling the @iat@ (issued at) check.
class HasCheckIssuedAt s where
  checkIssuedAt :: Lens' s Bool
-- Every HasJWTValidationSettings type gets each individual setting class by
-- focusing on the corresponding field.
--
-- NOTE(review): like the HasValidationSettings instance, these heads match
-- any type, so a user-defined instance of these classes would overlap.
-- The context is no smaller than the head, which is why
-- UndecidableInstances is enabled.
instance HasJWTValidationSettings a => HasAllowedSkew a where
  allowedSkew = jwtValidationSettingsAllowedSkew
instance HasJWTValidationSettings a => HasAudiencePredicate a where
  audiencePredicate = jwtValidationSettingsAudiencePredicate
instance HasJWTValidationSettings a => HasIssuerPredicate a where
  issuerPredicate = jwtValidationSettingsIssuerPredicate
instance HasJWTValidationSettings a => HasCheckIssuedAt a where
  checkIssuedAt = jwtValidationSettingsCheckIssuedAt
-- | Validation settings built from the given audience predicate. Everything
-- else uses defaults: 'defaultValidationSettings' for the JWS layer, zero
-- allowed skew, the @iat@ check enabled, and every issuer accepted.
defaultJWTValidationSettings :: (StringOrURI -> Bool) -> JWTValidationSettings
defaultJWTValidationSettings audPred = JWTValidationSettings
  { _jwtValidationSettingsValidationSettings = defaultValidationSettings
  , _jwtValidationSettingsAllowedSkew = 0
  , _jwtValidationSettingsCheckIssuedAt = True
  , _jwtValidationSettingsAudiencePredicate = audPred
  , _jwtValidationSettingsIssuerPredicate = const True
  }
-- | Run every claim check against a 'ClaimsSet' and return it unchanged if
-- all pass.
--
-- The checks run in the order exp, iat, nbf, iss, aud. The first failing
-- check throws its 'JWTError'. A claim that is absent is not checked.
validateClaimsSet
  ::
  ( MonadTime m, HasAllowedSkew a, HasAudiencePredicate a
  , HasIssuerPredicate a
  , HasCheckIssuedAt a
  , AsJWTError e, MonadError e m
  )
  => a
  -> ClaimsSet
  -> m ClaimsSet
validateClaimsSet conf claims = do
  validateExpClaim conf claims
  validateIatClaim conf claims
  validateNbfClaim conf claims
  validateIssClaim conf claims
  validateAudClaim conf claims
  pure claims
-- | Throw 'JWTExpired' unless the current time is strictly before the @exp@
-- claim plus the (absolute) allowed skew. Does nothing when @exp@ is absent.
validateExpClaim
  :: (MonadTime m, HasAllowedSkew a, AsJWTError e, MonadError e m)
  => a
  -> ClaimsSet
  -> m ()
validateExpClaim conf claims =
  case view claimExp claims of
    Nothing -> pure ()
    Just (NumericDate expiry) -> do
      now <- currentTime
      let deadline = addUTCTime (abs (view allowedSkew conf)) expiry
      when (now >= deadline) $
        throwing_ _JWTExpired
-- | When the issued-at check is enabled, throw 'JWTIssuedAtFuture' if the
-- @iat@ claim is later than the current time plus the (absolute) allowed
-- skew. Does nothing when @iat@ is absent.
validateIatClaim
  :: (MonadTime m, HasCheckIssuedAt a, HasAllowedSkew a, AsJWTError e, MonadError e m)
  => a
  -> ClaimsSet
  -> m ()
validateIatClaim conf claims =
  case view claimIat claims of
    Nothing -> pure ()
    Just (NumericDate issuedAt) -> do
      now <- currentTime
      let latestAllowed = addUTCTime (abs (view allowedSkew conf)) now
      when (view checkIssuedAt conf && issuedAt > latestAllowed) $
        throwing_ _JWTIssuedAtFuture
-- | Throw 'JWTNotYetValid' if the current time is before the @nbf@ claim
-- minus the (absolute) allowed skew. Does nothing when @nbf@ is absent.
validateNbfClaim
  :: (MonadTime m, HasAllowedSkew a, AsJWTError e, MonadError e m)
  => a
  -> ClaimsSet
  -> m ()
validateNbfClaim conf claims =
  case view claimNbf claims of
    Nothing -> pure ()
    Just (NumericDate notBefore) -> do
      now <- currentTime
      let earliestAllowed =
            addUTCTime (negate (abs (view allowedSkew conf))) notBefore
      when (now < earliestAllowed) $
        throwing_ _JWTNotYetValid
-- | Throw 'JWTNotInAudience' unless at least one @aud@ value satisfies the
-- audience predicate. An empty audience list is therefore rejected, and an
-- absent @aud@ claim is not checked.
validateAudClaim
  :: (HasAudiencePredicate s, AsJWTError e, MonadError e m)
  => s
  -> ClaimsSet
  -> m ()
validateAudClaim conf claims =
  case view claimAud claims of
    Nothing -> pure ()
    Just (Audience recipients) ->
      unless (any (view audiencePredicate conf) recipients) $
        throwing_ _JWTNotInAudience
-- | Throw 'JWTNotInIssuer' if the @iss@ claim is rejected by the issuer
-- predicate. Does nothing when @iss@ is absent.
validateIssClaim
  :: (HasIssuerPredicate s, AsJWTError e, MonadError e m)
  => s
  -> ClaimsSet
  -> m ()
validateIssClaim conf claims =
  case view claimIss claims of
    Nothing -> pure ()
    Just issuer
      | view issuerPredicate conf issuer -> pure ()
      | otherwise -> throwing_ _JWTNotInIssuer
-- | A JWS with 'JWSHeader' headers, in the 'CompactJWS' form. This is the
-- type produced by 'signClaims' and consumed by 'verifyClaims'.
type SignedJWT = CompactJWS JWSHeader
-- | A fixed timestamp used as the reader environment by 'verifyClaimsAt'.
newtype WrappedUTCTime = WrappedUTCTime { getUTCTime :: UTCTime }

-- Inside a reader over a fixed timestamp, "now" is that timestamp. This lets
-- the time-based claim checks run against a caller-chosen time.
-- NOTE(review): newer monad-time releases may add further MonadTime methods
-- that this instance does not define; confirm against the pinned version.
instance Monad m => MonadTime (ReaderT WrappedUTCTime m) where
  currentTime = ask >>= \(WrappedUTCTime fixedNow) -> pure fixedNow
-- | Verify a 'SignedJWT' with 'verifyJWSWithPayload' using the key store @k@,
-- decode its payload as a JSON 'ClaimsSet', then check the claims with
-- 'validateClaimsSet' against the current time of the monad.
--
-- A payload that does not decode raises 'JWTClaimsSetDecodeError' carrying
-- the decoder error message.
verifyClaims
  ::
  ( MonadTime m, HasAllowedSkew a, HasAudiencePredicate a
  , HasIssuerPredicate a
  , HasCheckIssuedAt a
  , HasValidationSettings a
  , AsError e, AsJWTError e, MonadError e m
  , VerificationKeyStore m (JWSHeader ()) ClaimsSet k
  )
  => a
  -> k
  -> SignedJWT
  -> m ClaimsSet
verifyClaims conf k jws = do
  claims <- verifyJWSWithPayload decodeClaims conf k jws
  validateClaimsSet conf claims
  where
    decodeClaims payload = case eitherDecode payload of
      Left err -> throwing _JWTClaimsSetDecodeError err
      Right decoded -> pure decoded
-- | Like 'verifyClaims', but the time-based claim checks use the given
-- 'UTCTime' instead of the clock of the underlying monad.
verifyClaimsAt
  ::
  ( HasAllowedSkew a, HasAudiencePredicate a
  , HasIssuerPredicate a
  , HasCheckIssuedAt a
  , HasValidationSettings a
  , AsError e, AsJWTError e, MonadError e m
  , VerificationKeyStore (ReaderT WrappedUTCTime m) (JWSHeader ()) ClaimsSet k
  )
  => a
  -> k
  -> UTCTime
  -> SignedJWT
  -> m ClaimsSet
verifyClaimsAt a k t jwt = runReaderT verification fixedClock
  where
    verification = verifyClaims a k jwt
    fixedClock = WrappedUTCTime t
-- | Sign a 'ClaimsSet' with a single key and header. The JSON encoding of the
-- claims is used as the JWS payload.
signClaims
  :: (MonadRandom m, MonadError e m, AsError e)
  => JWK
  -> JWSHeader ()
  -> ClaimsSet
  -> m SignedJWT
signClaims key header claims = signJWS payload signer
  where
    payload = encode claims
    signer = Identity (header, key)