{-# LANGUAGE OverloadedStrings #-}
module Web.OIDC.Client.CodeFlow
(
getAuthenticationRequestUrl
, getValidTokens
, prepareAuthenticationRequestUrl
, requestTokens
, validateClaims
, getCurrentIntDate
) where
import           Control.Monad                      (unless, when)
import           Control.Monad.Catch                (MonadCatch, MonadThrow,
                                                     catch, throwM)
import           Control.Monad.IO.Class             (MonadIO, liftIO)
import           Data.Aeson                         (FromJSON, eitherDecode)
import qualified Data.ByteString.Char8              as B
import           Data.List                          (nub)
import           Data.Maybe                         (isNothing)
import           Data.Monoid                        ((<>))
import           Data.Text                          (Text, pack, unpack)
import           Data.Text.Encoding                 (decodeUtf8With, encodeUtf8)
import           Data.Text.Encoding.Error           (lenientDecode)
import           Data.Time.Clock.POSIX              (getPOSIXTime)
import qualified Jose.Jwt                           as Jwt
import           Network.HTTP.Client                (Manager, Request (..),
                                                     getUri, httpLbs,
                                                     responseBody,
                                                     setQueryString,
                                                     urlEncodedBody)
import           Network.URI                        (URI)
import           Prelude                            hiding (exp)

import qualified Web.OIDC.Client.Discovery.Provider as P
import           Web.OIDC.Client.Internal           (parseUrl)
import qualified Web.OIDC.Client.Internal           as I
import           Web.OIDC.Client.Settings           (OIDC (..))
import           Web.OIDC.Client.Tokens             (IdTokenClaims (..), validateIdToken,
                                                     Tokens (..))
import           Web.OIDC.Client.Types              (Code, Nonce,
                                                     OpenIdException (..),
                                                     Parameters, Scope,
                                                     SessionStore (..), State,
                                                     openId)
-- | Build the URL for an Authorization Request, first generating a fresh
-- @state@ and @nonce@ through the given 'SessionStore'.
--
-- Both values come from 'sessionStoreGenerate' and are saved together with
-- 'sessionStoreSave' before the URL is built. This lets 'getValidTokens'
-- later look the nonce up by the @state@ it receives. The nonce is appended
-- to @params@ as the @nonce@ query parameter; the state is sent as @state@.
prepareAuthenticationRequestUrl
    :: (MonadThrow m, MonadCatch m)
    => SessionStore m
    -> OIDC
    -> Scope       -- ^ requested scope values
    -> Parameters  -- ^ extra query parameters
    -> m URI
prepareAuthenticationRequestUrl store oidc scope params = do
    state  <- sessionStoreGenerate store
    nonce' <- sessionStoreGenerate store
    sessionStoreSave store state nonce'
    getAuthenticationRequestUrl oidc scope (Just state) $
        params ++ [("nonce", Just nonce')]
-- | Exchange an authorization @code@ for validated tokens, using the nonce
-- that was saved in the 'SessionStore' under @state@.
--
-- Throws 'UnknownState' when 'sessionStoreGet' finds no nonce for the given
-- state. Otherwise it calls 'requestTokens' (whose HTTP, JSON and validation
-- failures propagate), then calls 'sessionStoreDelete', and returns the
-- tokens.
getValidTokens
    :: (MonadThrow m, MonadCatch m, MonadIO m, FromJSON a)
    => SessionStore m
    -> OIDC
    -> Manager
    -> State   -- ^ @state@ value received back from the IdP
    -> Code    -- ^ authorization code received from the IdP
    -> m (Tokens a)
getValidTokens store oidc mgr stateFromIdP code = do
    savedNonce <- sessionStoreGet store stateFromIdP
    -- Refuse to continue without a stored nonce: the claims check would
    -- otherwise compare against 'Nothing'.
    when (isNothing savedNonce) $ throwM UnknownState
    tokens <- liftIO $ requestTokens oidc savedNonce code mgr
    sessionStoreDelete store
    return tokens
-- | Build the URL for an Authorization Request without using a
-- 'SessionStore'.
--
-- The query contains, in this order:
--
-- * @response_type=code@, @client_id@, @redirect_uri@ and @scope@;
-- * @state@, when one is given;
-- * @params@.
--
-- The @scope@ value is the de-duplicated list of scope values with @openid@
-- prepended, separated by single spaces. 'HttpException's raised while
-- parsing the authorization endpoint URL are passed to 'I.rethrow'.
{-# WARNING getAuthenticationRequestUrl "This function doesn't manage state and nonce. Use prepareAuthenticationRequestUrl only unless your IdP doesn't support state and/or nonce." #-}
getAuthenticationRequestUrl
    :: (MonadThrow m, MonadCatch m)
    => OIDC
    -> Scope        -- ^ requested scope values
    -> Maybe State  -- ^ optional @state@ parameter
    -> Parameters   -- ^ extra query parameters, appended last
    -> m URI
getAuthenticationRequestUrl oidc scope state params = do
    req <- parseUrl endpoint `catch` I.rethrow
    return $ getUri $ setQueryString query req
  where
    endpoint = oidcAuthorizationServerUrl oidc
    query    = requireds ++ state' ++ params
    requireds =
        [ ("response_type", Just "code")
        , ("client_id",     Just $ oidcClientId oidc)
        , ("redirect_uri",  Just $ oidcRedirectUri oidc)
        , ("scope",         Just scopeValue)
        ]
    -- Encode each value as UTF-8. Char8 packing would keep only the low
    -- 8 bits of every character and silently mangle anything non-ASCII.
    scopeValue = B.intercalate " " . map encodeUtf8 . nub $ openId : scope
    state' = maybe [] (\s -> [("state", Just s)]) state
-- | Exchange an authorization @code@ for tokens at the provider's token
-- endpoint, then validate the returned ID token.
--
-- Sends a form-encoded @POST@ with @grant_type=authorization_code@, the code,
-- the client id and secret, and the redirect URI. Failures are reported as
-- follows:
--
-- * 'HttpException's are passed to 'I.rethrow';
-- * a body that does not decode throws 'JsonException';
-- * a decoded response is checked by 'validate' against @savedNonce@.
{-# WARNING requestTokens "This function doesn't manage state and nonce. Use getValidTokens only unless your IdP doesn't support state and/or nonce." #-}
requestTokens :: FromJSON a => OIDC -> Maybe Nonce -> Code -> Manager -> IO (Tokens a)
requestTokens oidc savedNonce code manager = do
    json <- getTokensJson `catch` I.rethrow
    either (throwM . JsonException . pack)
           (validate oidc savedNonce)
           (eitherDecode json)
  where
    getTokensJson = do
        req <- parseUrl (oidcTokenEndpoint oidc)
        let req' = urlEncodedBody body $ req { method = "POST" }
        responseBody <$> httpLbs req' manager
    body =
        [ ("grant_type",    "authorization_code")
        , ("code",          code)
        , ("client_id",     oidcClientId oidc)
        , ("client_secret", oidcClientSecret oidc)
        , ("redirect_uri",  oidcRedirectUri oidc)
        ]
-- | Validate a decoded token-endpoint response and assemble 'Tokens'.
--
-- The ID token is first checked with 'validateIdToken'. Its claims are then
-- checked with 'validateClaims' against:
--
-- * the provider's configured issuer;
-- * the client id, decoded from UTF-8 leniently;
-- * the current time;
-- * @savedNonce@.
validate :: FromJSON a => OIDC -> Maybe Nonce -> I.TokensResponse -> IO (Tokens a)
validate oidc savedNonce tres = do
    let jwt' = I.idToken tres
    claims' <- validateIdToken oidc jwt'
    now     <- getCurrentIntDate
    validateClaims expectedIssuer expectedClientId now savedNonce claims'
    return Tokens
        { accessToken  = I.accessToken tres
        , tokenType    = I.tokenType tres
        , idToken      = claims'
        , idTokenJwt   = jwt'
        , expiresIn    = I.expiresIn tres
        , refreshToken = I.refreshToken tres
        }
  where
    expectedIssuer   = P.issuer . P.configuration $ oidcProvider oidc
    expectedClientId = decodeUtf8With lenientDecode (oidcClientId oidc)
-- | Check the standard claims of a decoded ID token.
--
-- Arguments are the expected issuer, our client id, the current time, the
-- nonce saved for this session, and the claims. Checks run in this order;
-- the first failure throws 'ValidationException':
--
-- * the @iss@ claim must equal the expected issuer;
-- * the client id must be one of the @aud@ values;
-- * @now@ must be strictly before the @exp@ claim;
-- * the @nonce@ claim must equal the saved nonce (two 'Nothing's match).
validateClaims :: Text -> Text -> Jwt.IntDate -> Maybe Nonce -> IdTokenClaims a -> IO ()
validateClaims issuer' clientId' now savedNonce claims' = do
    let iss' = iss claims'
        aud' = aud claims'
    unless (iss' == issuer') $
        invalid $ "issuer from token \"" <> iss'
               <> "\" is different than expected issuer \"" <> issuer' <> "\""
    unless (clientId' `elem` aud') $
        invalid $ "our client \"" <> clientId'
               <> "\" isn't contained in the token's audience " <> pack (show aud')
    unless (now < exp claims') $
        invalid "received token has expired"
    unless (nonce claims' == savedNonce) $
        invalid "Inconsistent nonce"
  where
    invalid :: Text -> IO ()
    invalid = throwM . ValidationException
-- | The current POSIX time wrapped as a JWT 'Jwt.IntDate'. Used for
-- comparison with the @exp@ claim.
getCurrentIntDate :: IO Jwt.IntDate
getCurrentIntDate = fmap Jwt.IntDate getPOSIXTime