module PostgresWebsockets.Claims
( ConnectionInfo,validateClaims
) where
import Protolude hiding (toS)
import Protolude.Conv
import Control.Lens
import Crypto.JWT
import Data.List
import Data.Time.Clock (UTCTime)
import qualified Crypto.JOSE.Types as JOSE.Types
import qualified Data.Aeson as JSON
import qualified Data.Aeson.KeyMap as JSON
import qualified Data.Aeson.Key as Key
-- | Verified JWT claims as a JSON object, keyed by claim name.
type Claims = JSON.KeyMap JSON.Value

-- | What a successfully validated token grants a connection.
type ConnectionInfo =
  ( [Text] -- channels the connection may use
  , Text   -- value of the @mode@ claim
  , Claims -- the complete verified claim set
  )
-- | Verify a JWT and derive the connection parameters from its claims.
--
-- The token is verified against the key obtained from @secret@ via
-- 'parseJWK', at the instant @time@. On success the result holds, in
-- order: the channels the connection may use, the @mode@ claim, and the full
-- claim set.
--
-- Channel resolution:
--
-- * Granted channels come from the @channel@ claim (a string) and the
--   @channels@ claim (an array of strings). When both are present they are
--   merged, with duplicates removed.
-- * If a channel was requested and the token grants none, the requested
--   channel is allowed.
-- * If a channel was requested and the token grants some, the request is
--   kept only if it is one of them.
-- * If no channel was requested, every granted channel is returned.
--
-- Failures are returned as 'Left':
--
-- * @"Token expired"@;
-- * @"Error: ..."@ (the 'show'n 'JWTError') for any other verification
--   failure;
-- * @"Missing mode"@ when the @mode@ claim is absent or not a string;
-- * @"No allowed channels"@ when the resolved channel list is empty.
validateClaims
  :: Maybe Text  -- ^ channel requested by the client, if any
  -> ByteString  -- ^ secret: a JSON-encoded JWK or raw HS256 key bytes
  -> LByteString -- ^ the JWT in compact serialisation
  -> UTCTime     -- ^ instant at which the token is validated
  -> IO (Either Text ConnectionInfo)
validateClaims requestChannel secret jwtToken time = runExceptT $ do
  attempt <- liftIO $ jwtClaims time (parseJWK secret) jwtToken
  cl' <- case attempt of
    JWTClaims c -> pure c
    JWTInvalid JWTExpired -> throwError "Token expired"
    JWTInvalid err -> throwError $ "Error: " <> show err
  let channels = grantedChannels cl'
  mode <- maybe (throwError "Missing mode") pure (claimAsJSON "mode" cl')
  -- Match on the list itself: the original `length channels` compared
  -- against 0 walked the whole list for an O(1) emptiness test.
  --
  -- NOTE(review): a channel/channels claim of the wrong JSON type is treated
  -- as absent by the helpers below. With a requested channel, that makes the
  -- token unrestricted; consider failing closed instead.
  let requestedAllowedChannels = case (requestChannel, channels) of
        (Just rc, []) -> [rc]
        (Just rc, _) -> filter (== rc) channels
        (Nothing, _) -> channels
  when (null requestedAllowedChannels) $ throwError "No allowed channels"
  pure (requestedAllowedChannels, mode, cl')
  where
    -- Channels granted by the token. A present @channel@ claim is merged
    -- (deduplicated) into @channels@; an unusable @channels@ claim counts as
    -- an empty grant.
    grantedChannels :: Claims -> [Text]
    grantedChannels cl =
      case (claimAsJSON "channel" cl, claimAsJSONList "channels" cl) of
        (Just c, Just cs) -> nub (c : cs)
        (Just c, Nothing) -> [c]
        (Nothing, cs) -> fromMaybe [] cs

    -- A claim's value if it is present and a JSON string, otherwise Nothing.
    claimAsJSON :: Text -> Claims -> Maybe Text
    claimAsJSON name cl = case JSON.lookup (Key.fromText name) cl of
      Just (JSON.String s) -> Just s
      _ -> Nothing

    -- A claim's value if it is present and decodes as a list of strings,
    -- otherwise Nothing.
    claimAsJSONList :: Text -> Claims -> Maybe [Text]
    claimAsJSONList name cl =
      JSON.lookup (Key.fromText name) cl >>= \channelsJson ->
        case JSON.fromJSON channelsJson of
          JSON.Success channelsList -> Just channelsList
          _ -> Nothing
-- | Outcome of attempting to decode and verify a JWT.
data JWTAttempt
  = JWTInvalid JWTError
    -- ^ decoding or verification failed with this error
  | JWTClaims Claims
    -- ^ the token was accepted; these are its claims as a JSON object
  deriving (Eq)
-- | Decode a compact JWT and verify it with the given key at the given time.
--
-- An empty payload short-circuits to an empty claim set: no decoding or
-- verification is performed. Otherwise any decode or verification error is
-- returned as 'JWTInvalid', and success yields the claims via 'claims2map'.
--
-- The audience predicate is @const True@, so any audience is accepted; the
-- remaining settings are jose's defaults.
jwtClaims :: UTCTime -> JWK -> LByteString -> IO JWTAttempt
jwtClaims _ _ "" = pure (JWTClaims JSON.empty)
jwtClaims now key token =
  either JWTInvalid (JWTClaims . claims2map)
    <$> runExceptT (decodeCompact token >>= verifyClaimsAt settings key now)
  where
    settings = defaultJWTValidationSettings (const True)
-- | Convert a verified 'ClaimsSet' to a JSON object map via its 'JSON.ToJSON'
-- encoding. If that encoding is not a JSON object, the result is an empty map.
claims2map :: ClaimsSet -> JSON.KeyMap JSON.Value
claims2map claimsSet =
  case JSON.toJSON claimsSet of
    JSON.Object obj -> obj
    _ -> JSON.empty
-- | Wrap raw secret bytes as a symmetric (octet) JWK. The key is marked for
-- signature use ('Sig') with the HS256 algorithm.
hs256jwk :: ByteString -> JWK
hs256jwk rawKey =
  fromKeyMaterial octMaterial
    & jwkUse .~ Just Sig
    & jwkAlg .~ Just (JWSAlg HS256)
  where
    octMaterial = OctKeyMaterial . OctKeyParameters $ JOSE.Types.Base64Octets rawKey
-- | Interpret the configured secret as a verification key.
--
-- If the bytes decode as a JSON 'JWK', that key is used. Otherwise the bytes
-- are treated as a raw HS256 secret (see 'hs256jwk'). A secret that is valid
-- JSON but not a valid JWK therefore falls back silently to raw-key use.
parseJWK :: ByteString -> JWK
parseJWK raw =
  case JSON.decode (toS raw) of
    Just jwk -> jwk
    Nothing -> hs256jwk raw