{-# LANGUAGE DeriveGeneric     #-}
{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE OverloadedStrings #-}

{-|
    Module: Web.OIDC.Client.Tokens
    Maintainer: krdlab@gmail.com
    Stability: experimental
-}
module Web.OIDC.Client.Tokens
    ( Tokens(..)
    , IdTokenClaims(..)
    , validateIdToken
    )
where

import           Control.Applicative                ((<|>))
import           Control.Exception                  (throwIO)
import           Control.Monad.IO.Class             (MonadIO, liftIO)
import           Data.Aeson                         (FromJSON (parseJSON),
                                                     FromJSON, Value (Object),
                                                     eitherDecode, withObject,
                                                     (.:), (.:?))
import           Data.ByteString                    (ByteString)
import qualified Data.ByteString.Lazy.Char8         as BL
import           Data.Either                        (partitionEithers)
import           Data.Monoid                        ((<>))
import           Data.Text                          (Text, pack)
import           Data.Text.Encoding                 (encodeUtf8)
import           GHC.Generics                       (Generic)
import           Jose.Jwt                           (IntDate, Jwt, JwtContent (Jwe, Jws, Unsecured))
import qualified Jose.Jwt                           as Jwt
import           Prelude                            hiding (exp)
import qualified Web.OIDC.Client.Discovery.Provider as P
import           Web.OIDC.Client.Settings           (OIDC (..))
import           Web.OIDC.Client.Types              (OpenIdException (..))

-- | A bundle of tokens: an access token with its type, the parsed ID token
--   claims together with a 'Jwt' value, and optional expiry and refresh
--   token. The type parameter @a@ is the type of the additional claims
--   carried by 'IdTokenClaims'.
data Tokens a = Tokens
    { accessToken  :: Text
    -- ^ The access token.
    , tokenType    :: Text
    -- ^ The token type accompanying the access token.
    , idToken      :: IdTokenClaims a
    -- ^ Parsed claims of the ID token.
    , idTokenJwt   :: Jwt
    -- ^ The ID token as a 'Jwt' (presumably the raw token the claims in
    --   'idToken' were decoded from -- confirm against the call site).
    , expiresIn    :: Maybe Integer
    -- ^ Optional expiry value; the unit is not established in this module.
    , refreshToken :: Maybe Text
    -- ^ Optional refresh token.
    }
  deriving (Show, Eq)

-- | Claims required for an <https://openid.net/specs/openid-connect-core-1_0.html#IDToken ID Token>,
--   plus recommended claims (nonce) and other custom claims.
--
--   All fields are strict. The type parameter @a@ is the type into which the
--   complete claims object is additionally parsed (see 'otherClaims').
data IdTokenClaims a = IdTokenClaims
    { iss         :: !Text
    -- ^ Value of the @iss@ claim.
    , sub         :: !Text
    -- ^ Value of the @sub@ claim.
    , aud         :: ![Text]
    -- ^ Value of the @aud@ claim. A single string in the JSON is
    --   represented as a one-element list.
    , exp         :: !IntDate
    -- ^ Value of the @exp@ claim.
    , iat         :: !IntDate
    -- ^ Value of the @iat@ claim.
    , nonce       :: !(Maybe ByteString)
    -- ^ UTF-8 encoding of the @nonce@ claim, if the claim is present.
    , otherClaims :: !a
    -- ^ Result of parsing the whole claims object as an @a@.
    }
  deriving (Show, Eq, Generic)

-- | Parses the claims from a JSON object. The keys @iss@, @sub@, @aud@,
--   @exp@ and @iat@ are mandatory; @nonce@ is optional. The same object is
--   also handed to the 'FromJSON' instance of @a@ to fill 'otherClaims', so
--   that parser sees every key, including the standard ones.
instance FromJSON a => FromJSON (IdTokenClaims a) where
    parseJSON = withObject "IdTokenClaims" $ \o ->
        IdTokenClaims
            <$> o .: "iss"
            <*> o .: "sub"
            -- "aud" is accepted either as an array of strings or as a
            -- single string, which is wrapped into a one-element list.
            <*> (o .: "aud" <|> ((:[]) <$> (o .: "aud")))
            <*> o .: "exp"
            <*> o .: "iat"
            -- The nonce is kept as UTF-8 bytes rather than as Text.
            <*> (fmap encodeUtf8 <$> o .:? "nonce")
            -- Re-parse the complete object for the caller-defined claims.
            <*> parseJSON (Object o)

-- | Decode an ID token 'Jwt' with the provider's key set and parse its
--   payload into 'IdTokenClaims'.
--
--   Decoding is attempted once for every entry in the provider
--   configuration's 'P.idTokenSigningAlgValuesSupported'. The first
--   successful decoding is used; if none succeeds, the first error is
--   reported.
--
--   Exceptions (thrown with 'throwIO'):
--
--   * 'UnsecuredJwt' when the token decodes as an unsecured JWT;
--   * 'JwtException' when no attempt succeeds (including the case where the
--     provider lists no algorithms at all);
--   * 'JsonException' when the payload is not valid JSON for
--     'IdTokenClaims'.
--
--   Note that this function itself does not compare any claim values
--   (such as @iss@, @aud@ or @exp@) against expected values.
validateIdToken :: (MonadIO m, FromJSON a) => OIDC -> Jwt -> m (IdTokenClaims a)
validateIdToken oidc jwt' = do
    let jwks  = P.jwkSet . oidcProvider $ oidc
        token = Jwt.unJwt jwt'
        algs  = P.idTokenSigningAlgValuesSupported
              . P.configuration
              $ oidcProvider oidc
    -- Every advertised algorithm is tried; the results are reduced to a
    -- single outcome afterwards.
    decoded <- selectDecodedResult <$> traverse (tryDecode jwks token) algs
    case decoded of
        Right (Unsecured payload)      -> liftIO . throwIO $ UnsecuredJwt payload
        Right (Jws (_header, payload)) -> parsePayload payload
        Right (Jwe (_header, payload)) -> parsePayload payload
        Left err                       -> liftIO . throwIO $ JwtException err
  where
    -- Attempt to decode the token with one algorithm entry. Entries the
    -- provider lists but which are marked 'P.Unsupported' become a
    -- 'Jwt.BadAlgorithm' error without any decoding attempt.
    tryDecode jwks token = \case
        P.JwsAlgJson  alg ->
            liftIO $ Jwt.decode jwks (Just $ Jwt.JwsEncoding alg) token
        P.Unsupported alg ->
            return $ Left $ Jwt.BadAlgorithm ("Unsupported algorithm: " <> alg)

    -- Pick the first success if there is one, otherwise the first failure.
    -- An empty input list yields a 'Jwt.KeyError'.
    selectDecodedResult :: [Either Jwt.JwtError b] -> Either Jwt.JwtError b
    selectDecodedResult xs = case partitionEithers xs of
        (_, k : _) -> Right k
        (e : _, _) -> Left e
        ([], [])   -> Left $ Jwt.KeyError "No Keys available for decoding"

    -- Decode a strict JSON payload, turning a parse failure into a
    -- 'JsonException'.
    parsePayload :: (MonadIO n, FromJSON b) => ByteString -> n b
    parsePayload payload = case eitherDecode $ BL.fromStrict payload of
        Right x   -> return x
        Left  err -> liftIO . throwIO . JsonException $ pack err