{-# LANGUAGE RecordWildCards #-}
module PostgREST.Auth
( containsRole
, jwtClaims
, JWTClaims
) where
import qualified Crypto.JWT as JWT
import qualified Data.Aeson as JSON
import qualified Data.HashMap.Strict as M
import qualified Data.Vector as V
import Control.Lens (set)
import Control.Monad.Except (liftEither)
import Data.Either.Combinators (mapLeft)
import Data.Time.Clock (UTCTime)
import PostgREST.Config (AppConfig (..), JSPath, JSPathExp (..))
import PostgREST.Error (Error (..))
import Protolude
-- | Claims of a JWT, keyed by claim name (the top-level keys of the claims
-- JSON object, as produced by 'claimsMap').
type JWTClaims = M.HashMap Text JSON.Value
-- | Decode a compact-serialised JWT, verify it against the configured JWK
-- set, and return its claims.
--
-- * An empty token is not an error: it yields an empty claims map.
-- * If 'configJWKS' is 'Nothing', the result is 'JwtTokenMissing'.
-- * Any decoding or verification failure becomes 'JwtTokenInvalid'. An
--   expired token gets the fixed message @JWT expired@; every other error
--   is rendered with 'show'.
-- * On success the claims are converted by 'claimsMap', which places the
--   value found at 'configJwtRoleClaimKey' under the @role@ key.
--
-- The audience is checked only when 'configJwtAudience' is set, and an
-- allowed clock skew of 1 ('NominalDiffTime', i.e. one second) is applied
-- relative to @time@.
jwtClaims :: Monad m =>
  AppConfig -> LByteString -> UTCTime -> ExceptT Error m JWTClaims
jwtClaims _ "" _ = pure M.empty
jwtClaims AppConfig{..} payload time = do
  jwks <- liftEither (maybeToRight JwtTokenMissing configJWKS)
  -- Run the jose pipeline in its own ExceptT so its JWTError is captured as
  -- a value here and translated by 'toError' rather than escaping directly.
  verified <- lift . runExceptT $ do
    signed <- JWT.decodeCompact payload
    JWT.verifyClaimsAt settings jwks time signed
  liftEither (mapLeft toError (claimsMap configJwtRoleClaimKey <$> verified))
  where
    -- Default jose validation with our audience predicate and a 1s skew.
    settings =
      set JWT.allowedSkew 1 (JWT.defaultJWTValidationSettings audienceOk)

    -- With no configured audience, every @aud@ value is accepted.
    audienceOk :: JWT.StringOrURI -> Bool
    audienceOk aud = case configJwtAudience of
      Nothing       -> True
      Just expected -> expected == aud

    -- Map jose verification errors onto PostgREST's error type.
    toError :: JWT.JWTError -> Error
    toError err = case err of
      JWT.JWTExpired -> JwtTokenInvalid "JWT expired"
      _              -> JwtTokenInvalid (show err)
-- | Flatten a verified 'JWT.ClaimsSet' into a 'JWTClaims' map.
--
-- Every top-level claim is kept except @role@. The @role@ entry is instead
-- taken from the value reached by following @jspath@ from the root of the
-- claims object, and is omitted when that path does not resolve. If the
-- claims do not serialise to a JSON object, the result is empty.
claimsMap :: JSPath -> JWT.ClaimsSet -> JWTClaims
claimsMap jspath claims =
  case asJson of
    JSON.Object obj -> M.delete "role" obj `M.union` roleEntry
    _               -> M.empty
  where
    asJson :: JSON.Value
    asJson = JSON.toJSON claims

    -- A singleton @role@ map when the path resolves, otherwise empty.
    roleEntry :: JWTClaims
    roleEntry = case follow jspath asJson of
      Just found -> M.singleton "role" found
      Nothing    -> M.empty

    -- Walk the path one segment at a time: keys index into objects and
    -- indices into arrays. Any shape mismatch or missing entry yields
    -- 'Nothing'; an empty path returns the current value unchanged.
    follow :: JSPath -> JSON.Value -> Maybe JSON.Value
    follow []                  current         = Just current
    follow (JSPKey key : rest) (JSON.Object o) = M.lookup key o >>= follow rest
    follow (JSPIdx idx : rest) (JSON.Array ar) = (ar V.!? idx) >>= follow rest
    follow _                   _               = Nothing
-- | Whether a claims map (as returned by 'jwtClaims') has a @role@ entry.
containsRole :: JWTClaims -> Bool
containsRole claims = "role" `M.member` claims