{-|
Module      : PostgREST.Auth
Description : PostgREST authorization functions.

This module provides functions to deal with JWT authorization (http://jwt.io).
It can also be used to define other authorization functions;
in the future, OAuth, LDAP and similar integrations can be coded here.

Authentication should always be implemented in an external service.
In the test suite there is an example of a simple login function that can be used for a
very simple authentication system inside the PostgreSQL database.
-}
{-# LANGUAGE RecordWildCards #-}
module PostgREST.Auth
  ( containsRole
  , jwtClaims
  , JWTClaims
  ) where

import qualified Crypto.JWT          as JWT
import qualified Data.Aeson          as JSON
import qualified Data.HashMap.Strict as M
import qualified Data.Vector         as V

import Control.Lens            (set)
import Control.Monad.Except    (liftEither)
import Data.Either.Combinators (mapLeft)
import Data.Time.Clock         (UTCTime)

import PostgREST.Config (AppConfig (..), JSPath, JSPathExp (..))
import PostgREST.Error  (Error (..))

import Protolude


-- | JWT claims keyed by claim name, as returned by 'jwtClaims'. When present,
-- the @"role"@ entry holds the value found at the configured role claim path
-- rather than any top-level @"role"@ claim from the token.
type JWTClaims = M.HashMap Text JSON.Value

-- | Verify a JWT with the configured keys and return its claims as a map.
--
-- An empty token is treated as "no token": it yields an empty claims map and
-- nothing is verified. Otherwise the token is decoded from its compact
-- serialization and verified against 'configJWKS' at the supplied time.
-- Verification tolerates one second of clock skew and checks the audience
-- against 'configJwtAudience' (any audience is accepted when that is unset).
-- The verified claims are turned into a map by 'claimsMap', using
-- 'configJwtRoleClaimKey' to locate the role.
--
-- Errors:
--
-- * 'JwtTokenMissing' when a non-empty token is given but 'configJWKS' is
--   'Nothing';
-- * 'JwtTokenInvalid' @"JWT expired"@ when the token has expired;
-- * 'JwtTokenInvalid' carrying the shown 'JWT.JWTError' for any other
--   decode or verification failure.
jwtClaims :: Monad m =>
  AppConfig -> LByteString -> UTCTime -> ExceptT Error m JWTClaims
jwtClaims _ "" _ = pure M.empty
jwtClaims conf token now = do
  -- Without a key set there is nothing to verify the token against.
  keys <- liftEither $ maybe (Left JwtTokenMissing) Right (configJWKS conf)
  -- Run jose in its own ExceptT so its JWTError can be translated into a
  -- PostgREST 'Error' below instead of leaking out.
  verified <- lift . runExceptT $ do
    signed <- JWT.decodeCompact token
    JWT.verifyClaimsAt settings keys now signed
  liftEither $ case verified of
    Left err     -> Left (toAuthError err)
    Right claims -> Right (claimsMap (configJwtRoleClaimKey conf) claims)
  where
    -- Audience predicate plus a 1 second allowance for clock drift.
    settings =
      set JWT.allowedSkew 1 (JWT.defaultJWTValidationSettings acceptsAudience)

    -- With no configured audience every audience is accepted; otherwise
    -- only an exact match is.
    acceptsAudience :: JWT.StringOrURI -> Bool
    acceptsAudience aud = maybe True (== aud) (configJwtAudience conf)

    toAuthError :: JWT.JWTError -> Error
    toAuthError err = case err of
      JWT.JWTExpired -> JwtTokenInvalid "JWT expired"
      _              -> JwtTokenInvalid (show err)

-- | Convert a 'JWT.ClaimsSet' into a plain 'JWTClaims' map.
--
-- The claims are serialized with 'JSON.toJSON'. Any top-level @"role"@ claim
-- is dropped and replaced by the value reached by following @jspath@ through
-- the claims: 'JSPKey' steps look up object keys and 'JSPIdx' steps index
-- arrays. If the path does not resolve, the result has no @"role"@ key. If
-- the serialization is not a JSON object, the result is empty.
claimsMap :: JSPath -> JWT.ClaimsSet -> JWTClaims
claimsMap jspath claims = case JSON.toJSON claims of
  whole@(JSON.Object obj) -> M.union (M.delete "role" obj) (roleEntry whole)
  _                       -> M.empty
  where
    -- A singleton "role" map when the path resolves, otherwise empty.
    roleEntry :: JSON.Value -> JWTClaims
    roleEntry v = case follow jspath v of
      Nothing    -> M.empty
      Just found -> M.singleton "role" found

    -- Walk the path one step at a time; a step that does not match the
    -- shape of the current value (or is missing) yields Nothing.
    follow :: JSPath -> JSON.Value -> Maybe JSON.Value
    follow []                v               = Just v
    follow (JSPKey k : more) (JSON.Object o) = M.lookup k o >>= follow more
    follow (JSPIdx i : more) (JSON.Array a)  = a V.!? i >>= follow more
    follow _                 _               = Nothing

-- | True when a claims map (such as the result of 'jwtClaims') has a
-- @"role"@ entry.
containsRole :: JWTClaims -> Bool
containsRole claims = isJust (M.lookup "role" claims)