{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Web.OIDC.Client.Tokens
( Tokens(..)
, IdTokenClaims(..)
, validateIdToken
)
where
import Control.Applicative ((<|>))
import Control.Exception (throwIO)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Aeson (FromJSON (parseJSON),
FromJSON, Value (Object),
eitherDecode, withObject,
(.:), (.:?))
import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy.Char8 as BL
import Data.Either (partitionEithers)
import Data.Monoid ((<>))
import Data.Text (Text, pack)
import Data.Text.Encoding (encodeUtf8)
import GHC.Generics (Generic)
import Jose.Jwt (IntDate, Jwt, JwtContent (Jwe, Jws, Unsecured))
import qualified Jose.Jwt as Jwt
import Prelude hiding (exp)
import qualified Web.OIDC.Client.Discovery.Provider as P
import Web.OIDC.Client.Settings (OIDC (..))
import Web.OIDC.Client.Types (OpenIdException (..))
-- | Tokens obtained from the OpenID Provider.
--
--   The ID token is carried twice: decoded into 'IdTokenClaims' and as the
--   raw 'Jwt' it was decoded from. The type parameter @a@ is the type used
--   for the ID token's additional claims (see 'otherClaims').
data Tokens a = Tokens
    { accessToken  :: Text
      -- ^ Access token, kept as an opaque string.
    , tokenType    :: Text
      -- ^ Token type reported alongside the access token.
    , idToken      :: IdTokenClaims a
      -- ^ Decoded claims of the ID token.
    , idTokenJwt   :: Jwt
      -- ^ The raw ID token the claims were decoded from.
    , expiresIn    :: Maybe Integer
      -- ^ Access-token lifetime, if one was reported.
      --   NOTE(review): presumably seconds, per OAuth 2.0 @expires_in@ — confirm.
    , refreshToken :: Maybe Text
      -- ^ Refresh token, if one was issued.
    }
    deriving (Show, Eq)
-- | Claims carried by an OpenID Connect ID token.
--
--   The standard claims get dedicated (strict) fields. The complete claims
--   object is additionally decoded as @a@ and stored in 'otherClaims', so
--   provider-specific claims can be read from there.
data IdTokenClaims a = IdTokenClaims
    { iss         :: !Text
      -- ^ Issuer identifier (@iss@).
    , sub         :: !Text
      -- ^ Subject identifier (@sub@).
    , aud         :: ![Text]
      -- ^ Audiences (@aud@); a single-string @aud@ becomes a one-element list.
    , exp         :: !IntDate
      -- ^ Expiration time (@exp@).
    , iat         :: !IntDate
      -- ^ Time at which the token was issued (@iat@).
    , nonce       :: !(Maybe ByteString)
      -- ^ The @nonce@ claim, UTF-8 encoded, when present.
    , otherClaims :: !a
      -- ^ The whole claims object decoded as @a@.
    }
    deriving (Show, Eq, Generic)
-- | Decode ID token claims from a JSON object.
--
--   @iss@, @sub@, @aud@, @exp@ and @iat@ are required. @aud@ may be an array
--   of strings or a single string. @nonce@ is optional and is stored UTF-8
--   encoded. Finally, the same object is decoded once more as @a@ to fill
--   'otherClaims'.
instance FromJSON a => FromJSON (IdTokenClaims a) where
    parseJSON = withObject "IdTokenClaims" $ \obj -> do
        issuer   <- obj .: "iss"
        subject  <- obj .: "sub"
        -- Try an array of audiences first, then fall back to a single string.
        audience <- (obj .: "aud") <|> ((: []) <$> (obj .: "aud"))
        expiry   <- obj .: "exp"
        issuedAt <- obj .: "iat"
        rawNonce <- obj .:? "nonce"
        -- Re-parse the full object so custom claims are available to callers.
        extra    <- parseJSON (Object obj)
        pure IdTokenClaims
            { iss         = issuer
            , sub         = subject
            , aud         = audience
            , exp         = expiry
            , iat         = issuedAt
            , nonce       = encodeUtf8 <$> rawNonce
            , otherClaims = extra
            }
-- | Verify an ID token's signature against the provider's JWK set and decode
--   its claims.
--
--   Decoding is attempted once for each algorithm in the provider
--   configuration's 'P.idTokenSigningAlgValuesSupported' list. The first
--   successful decode is used.
--
--   Throws (via 'throwIO'):
--
--   * 'UnsecuredJwt' if the token decodes as an unsecured JWT;
--   * 'JwtException' if no advertised algorithm decodes the token;
--   * 'JsonException' if the verified payload is not valid claims JSON.
--
--   NOTE(review): only the signature is checked here. No @iss@, @aud@,
--   @exp@ or @nonce@ validation happens in this function; presumably callers
--   do that separately — confirm.
validateIdToken :: (MonadIO m, FromJSON a) => OIDC -> Jwt -> m (IdTokenClaims a)
validateIdToken oidc jwt' = do
    let provider = oidcProvider oidc
        jwks     = P.jwkSet provider
        token    = Jwt.unJwt jwt'
        algs     = P.idTokenSigningAlgValuesSupported (P.configuration provider)
    decoded <- selectDecodedResult <$> traverse (tryDecode jwks token) algs
    case decoded of
        -- An unsigned token is never accepted as a valid ID token.
        Right (Unsecured payload)      -> liftIO . throwIO $ UnsecuredJwt payload
        Right (Jws (_header, payload)) -> parsePayload payload
        Right (Jwe (_header, payload)) -> parsePayload payload
        Left err                       -> liftIO . throwIO $ JwtException err
  where
    -- Attempt a decode with one advertised algorithm. Entries the provider
    -- metadata could not map to a 'JwsAlg' ('P.Unsupported') are reported as
    -- 'Jwt.BadAlgorithm' without attempting a decode.
    tryDecode jwks token = \case
        P.JwsAlgJson alg  -> liftIO $ Jwt.decode jwks (Just $ Jwt.JwsEncoding alg) token
        P.Unsupported alg -> return . Left . Jwt.BadAlgorithm $ "Unsupported algorithm: " <> alg

    -- Prefer the first success; otherwise report the first failure (in
    -- algorithm order). There is exactly one result per advertised
    -- algorithm, so an empty list means the provider advertised none.
    -- That case is not a missing-key problem, and the message says so.
    selectDecodedResult xs = case partitionEithers xs of
        (_, k : _) -> Right k
        (e : _, _) -> Left e
        ([], [])   ->
            Left $ Jwt.KeyError "No ID token signing algorithms available for decoding"

    -- Decode the verified payload as JSON claims; parse failures become
    -- 'JsonException'.
    parsePayload payload = case eitherDecode (BL.fromStrict payload) of
        Right x  -> return x
        Left err -> liftIO . throwIO . JsonException $ pack err