{-|
Module      : PostgresWebsockets.Claims
Description : Parse and validate JWT to open postgres-websockets channels.

This module provides the JWT claims validation. Since websockets and
listening connections in the database tend to be resource intensive
(not to mention stateful) we need claims authorizing a specific channel and
mode of operation.
-}
module PostgresWebsockets.Claims
  ( ConnectionInfo,validateClaims
  ) where

import Protolude hiding (toS)
import Protolude.Conv
import Control.Lens
import Crypto.JWT
import Data.List
import Data.Time.Clock (UTCTime)
import qualified Crypto.JOSE.Types as JOSE.Types
import qualified Data.Aeson as JSON
import qualified Data.Aeson.KeyMap as JSON
import qualified Data.Aeson.Key as Key

-- | Claims of a JWT, keyed by claim name.
type Claims = JSON.KeyMap JSON.Value

-- | What 'validateClaims' hands back on success: the channels the client
--   may use, the value of the @mode@ claim, and the full claims map.
type ConnectionInfo =
  ( [Text]
  , Text
  , Claims
  )

{-| Validate a client JWT and work out which channels it may open.

    Arguments, in order:

    * the channel the client asked for, if any;
    * the JWT secret from configuration (a JSON Web Key, or a raw
      HMAC-SHA256 secret, see 'parseJWK');
    * the compact-serialised JWT sent by the client;
    * the current time, used when verifying the token (e.g. expiry).

    Returns either an error message or a triple of the allowed channels,
    the value of the string @mode@ claim, and the full claims map.

    Channel resolution:

    * the token grants channels through a string @channel@ claim, a
      string-array @channels@ claim, or both (duplicates removed when both
      are present);
    * when a channel is requested it must be one of the granted channels,
      unless the token grants no channels at all, in which case the
      requested channel is allowed;
    * when no channel is requested, every granted channel is returned;
    * an empty result is an error ("No allowed channels").
-}
validateClaims
  :: Maybe Text
  -> ByteString
  -> LByteString
  -> UTCTime
  -> IO (Either Text ConnectionInfo)
validateClaims requestChannel secret jwtToken time = runExceptT $ do
  attempt <- liftIO $ jwtClaims time (parseJWK secret) jwtToken
  cl' <- case attempt of
    JWTClaims c -> pure c
    JWTInvalid JWTExpired -> throwError "Token expired"
    JWTInvalid err -> throwError $ "Error: " <> show err
  -- Pure computation: no error can arise while collecting granted channels.
  let grantedChannels =
        case (claimAsJSON "channel" cl', claimAsJSONList "channels" cl') of
          (Just c, Just cs) -> nub (c : cs)
          (Just c, Nothing) -> [c]
          (Nothing, chs) -> fromMaybe [] chs
  -- The mode check happens before the channel check, so a token lacking
  -- both reports "Missing mode".
  mode <- maybe (throwError "Missing mode") pure (claimAsJSON "mode" cl')
  let requestedAllowedChannels = case requestChannel of
        Just rc
          -- NOTE(review): a token that grants no channels lets the client
          -- open whichever channel it requests; confirm this is intended.
          | null grantedChannels -> [rc]
          | otherwise -> filter (== rc) grantedChannels
        Nothing -> grantedChannels
  when (null requestedAllowedChannels) $ throwError "No allowed channels"
  pure (requestedAllowedChannels, mode, cl')
 where
  -- Look up a claim whose value must be a JSON string; any other shape
  -- (or absence) yields Nothing.
  claimAsJSON :: Text -> Claims -> Maybe Text
  claimAsJSON name obj = case JSON.lookup (Key.fromText name) obj of
    Just (JSON.String s) -> Just s
    _ -> Nothing

  -- Look up a claim whose value must be a JSON array of strings; any other
  -- shape (or absence) yields Nothing.
  claimAsJSONList :: Text -> Claims -> Maybe [Text]
  claimAsJSONList name obj = do
    v <- JSON.lookup (Key.fromText name) obj
    case JSON.fromJSON v :: JSON.Result [Text] of
      JSON.Success strs -> Just strs
      _ -> Nothing

{-|
  Outcome of checking a client JWT: either the error reported while
  decoding/verifying it, or the claims of a token that passed.
-}
data JWTAttempt = JWTInvalid JWTError
                | JWTClaims (JSON.KeyMap JSON.Value)
                deriving Eq

{-|
  Decode a compact-serialised JWT, verify it with the given key at the
  given time, and return its claims as a JSON object.

  An empty token short-circuits to an empty claims map without any
  verification; callers must reject tokens that lack the claims they
  require.

  The validation settings accept any audience (@const True@).
-}
jwtClaims :: UTCTime -> JWK -> LByteString -> IO JWTAttempt
jwtClaims _ _ "" = pure $ JWTClaims JSON.empty
jwtClaims time jwk' payload = do
  verified <- runExceptT $ do
    jwt <- decodeCompact payload
    verifyClaimsAt config jwk' time jwt
  pure $ either JWTInvalid (JWTClaims . claims2map) verified
 where
  -- The audience predicate accepts every audience.
  config = defaultJWTValidationSettings (const True)

{-|
  Internal helper that turns a JWT 'ClaimsSet' into something easier to
  work with: the key/value map of its JSON encoding. If that encoding is
  not a JSON object, an empty map is returned.
-}
claims2map :: ClaimsSet -> JSON.KeyMap JSON.Value
claims2map cs = case JSON.toJSON cs of
  JSON.Object o -> o
  _ -> JSON.empty

{-|
  Internal helper that builds an HMAC-SHA256 key. When the jwt key in the
  config file is a simple string rather than a JWK object, we apply this
  function to it: the raw bytes become octet key material, marked for
  signature use with the HS256 algorithm.
-}
hs256jwk :: ByteString -> JWK
hs256jwk key =
  fromKeyMaterial km
    & jwkUse ?~ Sig
    & jwkAlg ?~ JWSAlg HS256
 where
  -- The secret bytes are wrapped as-is (no decoding) as the symmetric key.
  km = OctKeyMaterial . OctKeyParameters $ JOSE.Types.Base64Octets key

{-|
  Interpret the configured JWT secret. If it decodes as a JSON Web Key it
  is used as-is; otherwise the raw bytes are treated as an HMAC-SHA256
  secret (see 'hs256jwk').
-}
parseJWK :: ByteString -> JWK
parseJWK str =
  fromMaybe (hs256jwk str) (JSON.decode (toS str) :: Maybe JWK)