module Servant.Auth.Server.Internal.JWT where
import Control.Lens
import Control.Monad.Except
import Control.Monad.Reader
import qualified Crypto.JOSE as Jose
import qualified Crypto.JWT as Jose
import Data.Aeson (FromJSON, Result (..), ToJSON, fromJSON,
toJSON)
import Data.ByteArray (constEq)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import qualified Data.HashMap.Strict as HM
import Data.Maybe (fromMaybe)
import qualified Data.Text as T
import Data.Time (UTCTime)
import Network.Wai (requestHeaders)
import Servant.Auth.JWT (FromJWT(..), ToJWT(..))
import Servant.Auth.Server.Internal.ConfigTypes
import Servant.Auth.Server.Internal.Types
-- | An 'AuthCheck' that authenticates a request using a JWT carried in the
-- @Authorization@ header with the @Bearer@ scheme.
--
-- * No @Authorization@ header, or a header whose scheme is not @Bearer@:
--   the check yields 'mempty' (presumably "indefinite", so other checks may
--   still run — confirm against 'AuthCheck''s Monoid instance).
-- * A @Bearer@ header whose token fails 'verifyJWT': the check fails with
--   'mzero'.
--
-- The scheme name is matched case-insensitively, and one or more spaces may
-- separate it from the token, as RFC 7235 §2.1 specifies
-- (@credentials = auth-scheme 1*SP token68@).
jwtAuthCheck :: FromJWT usr => JWTSettings -> AuthCheck usr
jwtAuthCheck jwtSettings = do
  req <- ask
  token <- maybe mempty return $ do
    authHdr <- lookup "Authorization" $ requestHeaders req
    let (scheme, rest) = BS.splitAt (BS.length bearer) authHdr
    guard (asciiLower scheme `constEq` bearer)
    -- A token68 never contains spaces, so extra separator spaces are dropped.
    return (BS.dropWhile (== 0x20) rest)
  verifiedJWT <- liftIO $ verifyJWT jwtSettings token
  maybe mzero return verifiedJWT
  where
    -- Lower-cased scheme name plus the first mandatory separating space.
    bearer :: BS.ByteString
    bearer = "bearer "

    -- ASCII-only lower-casing: bytes 'A'..'Z' are mapped to 'a'..'z',
    -- every other byte is left untouched.
    asciiLower :: BS.ByteString -> BS.ByteString
    asciiLower = BS.map (\w -> if w >= 0x41 && w <= 0x5A then w + 0x20 else w)
-- | Create a signed JWT in compact serialisation. Its claims are produced by
-- 'encodeJWT' from the given value.
--
-- The signing algorithm is 'jwtAlg' from the settings when it is set.
-- Otherwise it is the algorithm that 'Jose.bestJWSAlg' chooses for the
-- signing key. When an expiry time is supplied, it is written to the
-- 'Jose.claimExp' claim.
--
-- Any 'Jose.Error' raised while choosing the algorithm or signing is returned
-- as 'Left'.
makeJWT :: ToJWT a
  => a -> JWTSettings -> Maybe UTCTime -> IO (Either Jose.Error BSL.ByteString)
makeJWT v cfg expiry = runExceptT $ do
  let key = signingKey cfg
  -- Ask jose for a default algorithm only when none is configured.
  -- 'Jose.bestJWSAlg' can fail for some keys, and that failure must not
  -- override an algorithm the caller chose explicitly.
  alg <- maybe (Jose.bestJWSAlg key) return (jwtAlg cfg)
  signed <- Jose.signClaims key
                            (Jose.newJWSHeader ((), alg))
                            (addExp $ encodeJWT v)
  return $ Jose.encodeCompact signed
  where
    -- Set the expiry claim only when an expiry time was requested.
    addExp claims = case expiry of
      Nothing -> claims
      Just e  -> claims & Jose.claimExp ?~ Jose.NumericDate e
-- | Decode a JWT in compact serialisation and verify it, then turn its claims
-- into a value with 'decodeJWT'.
--
-- Verification uses the validation settings derived from the 'JWTSettings'
-- ('jwtSettingsToJwtValidationSettings') and the keys from 'validationKeys'.
-- Every failure — a malformed or unverifiable token (a 'Jose.JWTError') or a
-- 'decodeJWT' error — is collapsed to 'Nothing'.
verifyJWT :: FromJWT a => JWTSettings -> BS.ByteString -> IO (Maybe a)
verifyJWT jwtCfg input = do
  outcome <- runExceptT $
    Jose.decodeCompact (BSL.fromStrict input)
      >>= Jose.verifyClaims (jwtSettingsToJwtValidationSettings jwtCfg)
                            (validationKeys jwtCfg)
  pure $ case outcome of
    -- The annotation pins the error type used by the ExceptT pipeline above.
    Left (_ :: Jose.JWTError) -> Nothing
    Right claims -> either (const Nothing) Just (decodeJWT claims)