{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Web.OIDC.Client.Tokens
( Tokens(..)
, IdTokenClaims(..)
, validateIdToken
)
where
import Control.Applicative ((<|>))
import Control.Exception (throwIO)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Aeson (FromJSON (parseJSON),
FromJSON, Value (Object),
eitherDecode, withObject,
(.:), (.:?))
import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy.Char8 as BL
import Data.Either (partitionEithers)
import Data.Monoid ((<>))
import Data.Text (Text, pack)
import Data.Text.Encoding (encodeUtf8)
import GHC.Generics (Generic)
import Jose.Jwt (IntDate, Jwt, JwtContent (Jwe, Jws, Unsecured))
import qualified Jose.Jwt as Jwt
import Prelude hiding (exp)
import qualified Web.OIDC.Client.Discovery.Provider as P
import Web.OIDC.Client.Settings (OIDC (..))
import Web.OIDC.Client.Types (OpenIdException (..))
-- | The tokens held by an OIDC client, with the ID token available both
-- decoded ('idToken') and in encoded form ('idTokenJwt').
--
-- The type parameter @a@ is the type of the additional (non-standard) claims
-- stored in the decoded ID token's 'otherClaims' field.
data Tokens a = Tokens
  { accessToken  :: Text
    -- ^ The access token.
  , tokenType    :: Text
    -- ^ The token type string.
  , idToken      :: IdTokenClaims a
    -- ^ The decoded ID token claims.
  , idTokenJwt   :: Jwt
    -- ^ The encoded ID token; presumably the JWT that 'idToken' was decoded
    -- from (TODO confirm at the construction site).
  , expiresIn    :: Maybe Integer
    -- ^ Optional lifetime value; presumably seconds, as in the OAuth 2.0
    -- @expires_in@ field (TODO confirm against the response parser).
  , refreshToken :: Maybe Text
    -- ^ Optional refresh token.
  }
  deriving (Eq, Show)
-- | The claims of a decoded ID token.
--
-- The registered claims each get a dedicated strict field, filled from the
-- JSON key of the same name. Everything else is exposed through
-- 'otherClaims', whose type @a@ is chosen by the caller.
data IdTokenClaims a = IdTokenClaims
  { iss         :: !Text
    -- ^ The @"iss"@ (issuer) claim.
  , sub         :: !Text
    -- ^ The @"sub"@ (subject) claim.
  , aud         :: ![Text]
    -- ^ The @"aud"@ (audience) claim, always held as a list.
  , exp         :: !IntDate
    -- ^ The @"exp"@ (expiration time) claim. The field name clashes with
    -- Prelude's @exp@, which this module hides on import.
  , iat         :: !IntDate
    -- ^ The @"iat"@ (issued-at) claim.
  , nonce       :: !(Maybe ByteString)
    -- ^ The optional @"nonce"@ claim, stored as UTF-8 bytes.
  , otherClaims :: !a
    -- ^ Caller-defined view of the claims set.
  }
  deriving (Eq, Generic, Show)
-- | Parse ID token claims from a JSON object.
--
-- * @"iss"@, @"sub"@, @"aud"@, @"exp"@ and @"iat"@ are required.
-- * @"aud"@ may be either an array of strings or a single string; a single
--   string becomes a one-element list.
-- * @"nonce"@ is optional and is UTF-8 encoded into a 'ByteString'.
-- * The /whole/ object (standard keys included) is handed to @a@'s parser
--   to build 'otherClaims'.
--
-- Any value that is not a JSON object is rejected by 'withObject'.
instance FromJSON a => FromJSON (IdTokenClaims a) where
  parseJSON v = withObject "IdTokenClaims" fromObject v
    where
      fromObject o = do
        issuer   <- o .: "iss"
        subject  <- o .: "sub"
        -- Accept an audience array first; fall back to a lone string.
        audience <- (o .: "aud") <|> (pure <$> (o .: "aud"))
        expiry   <- o .: "exp"
        issuedAt <- o .: "iat"
        rawNonce <- o .:? "nonce"
        -- 'v' is the very object bound to 'o', so this sees every claim.
        extra    <- parseJSON v
        pure $ IdTokenClaims
          { iss         = issuer
          , sub         = subject
          , aud         = audience
          , exp         = expiry
          , iat         = issuedAt
          , nonce       = encodeUtf8 <$> rawNonce
          , otherClaims = extra
          }
-- | Verify an ID token's signature and decode its claims.
--
-- The token is decoded once per algorithm listed in the provider
-- configuration's @idTokenSigningAlgValuesSupported@, each time against the
-- provider's JWK set. The first successful decode (in list order) is used.
-- Only the payload is decoded here. Claim values such as issuer, audience,
-- expiry and nonce are NOT checked by this function.
--
-- Throws, via 'throwIO':
--
-- * 'UnsecuredJwt' if the successful decode yields an unsecured token.
-- * 'JwtException' when no algorithm succeeds. The wrapped error is the
--   first decoding error, or a 'Jwt.KeyError' if the provider advertises no
--   signing algorithms at all.
-- * 'JsonException' if the payload cannot be parsed as 'IdTokenClaims'.
validateIdToken :: (MonadIO m, FromJSON a) => OIDC -> Jwt -> m (IdTokenClaims a)
validateIdToken oidc jwt' = do
    results <- traverse decodeWith algs
    case firstSuccess results of
      Right (Unsecured payload)      -> liftIO . throwIO $ UnsecuredJwt payload
      Right (Jws (_header, payload)) -> parsePayload payload
      Right (Jwe (_header, payload)) -> parsePayload payload
      Left err                       -> liftIO . throwIO $ JwtException err
  where
    provider = oidcProvider oidc
    jwks     = P.jwkSet provider
    token    = Jwt.unJwt jwt'
    algs     = P.idTokenSigningAlgValuesSupported (P.configuration provider)

    -- Attempt a JWS decode pinned to one advertised algorithm. Algorithms
    -- that could not be recognised become an error rather than a decode.
    decodeWith (P.JwsAlgJson alg) =
      liftIO $ Jwt.decode jwks (Just $ Jwt.JwsEncoding alg) token
    decodeWith (P.Unsupported alg) =
      return . Left . Jwt.BadAlgorithm $ "Unsupported algorithm: " <> alg

    -- Prefer any success; otherwise report the first failure. The empty case
    -- can only arise when 'algs' is empty: an empty JWK set still yields
    -- one error per algorithm from 'Jwt.decode'.
    firstSuccess results = case partitionEithers results of
      (_, content : _) -> Right content
      (err : _, [])    -> Left err
      ([], [])         ->
        Left $ Jwt.KeyError "No ID token signing algorithms available for decoding"

    -- Decode the JSON payload, surfacing parse failures as 'JsonException'.
    parsePayload payload = case eitherDecode (BL.fromStrict payload) of
      Right claims -> return claims
      Left err     -> liftIO . throwIO . JsonException $ pack err