module Servant.Auth.Server.Internal.JWT where

import           Control.Lens
import           Control.Monad.Except
import           Control.Monad.Reader
import qualified Crypto.JOSE          as Jose
import qualified Crypto.JWT           as Jose
import           Data.Aeson           (FromJSON, Result (..), ToJSON, fromJSON,
                                       toJSON)
import           Data.ByteArray       (constEq)
import qualified Data.ByteString      as BS
import qualified Data.ByteString.Lazy as BSL
import qualified Data.HashMap.Strict  as HM
import           Data.Maybe           (fromMaybe)
import qualified Data.Text            as T
import           Data.Time            (UTCTime)
import           Network.Wai          (requestHeaders)

import Servant.Auth.JWT               (FromJWT(..), ToJWT(..))
import Servant.Auth.Server.Internal.ConfigTypes
import Servant.Auth.Server.Internal.Types


-- | A JWT @AuthCheck@. You likely won't need to use this directly unless you
-- are protecting a @Raw@ endpoint.
--
-- Reads the request's @Authorization@ header, expects it to have the form
-- @Bearer <token>@, and checks the token with 'verifyJWT'.
-- * A missing or malformed header short-circuits with 'mempty'.
-- * A token that fails verification or decoding yields 'mzero'.
jwtAuthCheck :: FromJWT usr => JWTSettings -> AuthCheck usr
jwtAuthCheck jwtSettings = do
  req <- ask
  token <- maybe mempty return $ do
    authHdr <- lookup "Authorization" $ requestHeaders req
    bearerToken authHdr
  verifiedJWT <- liftIO $ verifyJWT jwtSettings token
  maybe mzero return verifiedJWT
  where
    -- Strip the "Bearer" scheme from a header value.
    -- The auth-scheme is case-insensitive (RFC 7235 §2.1), and RFC 6750 §2.1
    -- allows one or more spaces between it and the token. Both are
    -- accepted here, rather than only the exact bytes "Bearer ".
    bearerToken :: BS.ByteString -> Maybe BS.ByteString
    bearerToken hdr = do
      let scheme = "bearer " :: BS.ByteString
          (prefix, rest) = BS.splitAt (BS.length scheme) hdr
      guard (BS.map asciiLower prefix `constEq` scheme)
      -- 32 is ASCII space; a bearer token never begins with one.
      return (BS.dropWhile (== 32) rest)

    -- ASCII-only lowercasing of a single byte ('A'..'Z' -> 'a'..'z').
    asciiLower w
      | w >= 65 && w <= 90 = w + 32
      | otherwise          = w

-- | Creates a JWT containing the specified data. The data is stored in the
-- @dat@ claim. The 'Maybe UTCTime' argument indicates the time at which the
-- token expires.
--
-- The header's algorithm is 'jwtAlg' from the settings when it is set, and
-- otherwise the one 'Jose.bestJWSAlg' chooses for the signing key.
-- 'Jose.bestJWSAlg' is always run first, so any error it raises for the key
-- is returned as 'Left' even when 'jwtAlg' is provided.
makeJWT :: ToJWT a
  => a -> JWTSettings -> Maybe UTCTime -> IO (Either Jose.Error BSL.ByteString)
makeJWT v cfg expiry = runExceptT $ do
  let key = signingKey cfg
  defaultAlg <- Jose.bestJWSAlg key
  let header = Jose.newJWSHeader ((), fromMaybe defaultAlg (jwtAlg cfg))
  signed <- Jose.signClaims key header (withExpiry (encodeJWT v))
  return (Jose.encodeCompact signed)
  where
    -- Set the @exp@ claim only when an expiry time was supplied;
    -- otherwise leave the claims untouched.
    withExpiry :: Jose.ClaimsSet -> Jose.ClaimsSet
    withExpiry = maybe id (\e -> Jose.claimExp ?~ Jose.NumericDate e) expiry


-- | Decode a compact-serialised JWT and verify it.
--
-- Verification uses the validation settings derived from the 'JWTSettings'
-- and its 'validationKeys'. The verified claims are then converted with
-- 'decodeJWT'. Any failure at either stage is reported as 'Nothing'; the
-- underlying error is discarded.
verifyJWT :: FromJWT a => JWTSettings -> BS.ByteString -> IO (Maybe a)
verifyJWT jwtCfg input = interpret <$> runExceptT checked
  where
    -- Parse the token, then check its signature and claims.
    checked = Jose.decodeCompact (BSL.fromStrict input)
      >>= Jose.verifyClaims (jwtSettingsToJwtValidationSettings jwtCfg)
                            (validationKeys jwtCfg)

    -- The pattern signature fixes the error type of the ExceptT above.
    interpret (Left (_ :: Jose.JWTError)) = Nothing
    interpret (Right claims) = either (const Nothing) Just (decodeJWT claims)