{-# LANGUAGE OverloadedStrings #-}
module Network.OAuth2.JWT.Client.AuthorizationGrant (
GrantError (..)
, sign
, refresh
, local
, grant
) where
import qualified Control.Concurrent.MVar as MVar
import Control.Lens ((.~), (&))
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Trans.Bifunctor (BifunctorTrans (..))
import Control.Monad.Trans.Except (ExceptT (..), runExceptT)
import Crypto.JWT (JWK, JWTError)
import qualified Crypto.JWT as JWT
import qualified Data.Aeson as Aeson
import Data.Bifunctor as X (Bifunctor(..))
import qualified Data.ByteString.Lazy as LazyByteString
import qualified Data.HashMap.Strict as HashMap
import Data.String (IsString (..))
import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import Data.Time (UTCTime)
import qualified Data.Time as Time
import Network.OAuth2.JWT.Client.Data
import qualified Network.OAuth2.JWT.Client.Serial as Serial
import qualified Network.HTTP.Client as HTTP
import qualified Network.HTTP.Types as HTTP
-- | Ways in which obtaining an access token can fail.
data GrantError =
    SerialisationGrantError Text
    -- ^ The token endpoint answered with status 200, but its body could not
    -- be decoded into a token response; carries the decoder's error text.
  | JWTGrantError JWT.JWTError
    -- ^ Building or signing the JWT bearer assertion failed.
  | EndpointGrantError Text
    -- ^ The configured token endpoint could not be parsed into an HTTP
    -- request; carries the 'show'n parse exception.
  | StatusGrantError Int Text
    -- ^ The token endpoint answered with a status other than 200; carries
    -- the status code and the response body decoded as UTF-8 text.
  deriving (Eq, Show)
-- | Obtain an access token.
--
-- The token cached in the 'Store' is reused while it has not expired.
-- Otherwise a new token is requested from the token endpoint and cached
-- with an expiry of \"clock time before the request + @expires_in@\".
--
-- Refreshes are serialised through 'MVar.modifyMVar' on the store. The cache
-- is re-checked once the lock is held, so callers that were queued behind a
-- successful refresh reuse its token instead of issuing another request.
-- A refresh that returns 'Left' leaves the stored state unchanged. If the
-- refresh throws, 'MVar.modifyMVar' restores the previous contents and the
-- exception propagates.
grant :: Store -> IO (Either GrantError AccessToken)
grant (Store manager endpoint claims jwk store) = do
  -- Fast path: peek at the cached state. The clock is read *after* the read,
  -- because 'MVar.readMVar' may block while another caller holds the store.
  cached <- MVar.readMVar store
  checkedAt <- Time.getCurrentTime
  case local checkedAt cached of
    Just token ->
      pure . Right $ token
    Nothing ->
      MVar.modifyMVar store $ \state -> do
        -- Re-read the clock now that the lock is held. Waiting for another
        -- caller's refresh can take arbitrarily long, and an earlier
        -- timestamp would be stale both for this re-check and for the
        -- expiry recorded for a newly fetched token.
        now <- Time.getCurrentTime
        case local now state of
          Just token ->
            pure (state, Right token)
          Nothing -> do
            result <- runExceptT (refresh now manager endpoint claims jwk)
            pure $ case result of
              Left err ->
                (state, Left err)
              Right (Response token expiry) ->
                -- 'now' predates the HTTP request, so this expiry is never
                -- later than the token's real expiry.
                (HasToken token (Time.addUTCTime (getExpiresIn expiry) now), Right token)
-- | Look up a usable cached token.
--
-- Yields the stored token only when the state holds one whose expiry time is
-- strictly later than @now@. Both an empty state and an expired token give
-- 'Nothing'.
local :: UTCTime -> TokenState -> Maybe AccessToken
local now (HasToken token expiresAt)
  | now < expiresAt = Just token
  | otherwise = Nothing
local _ NoToken = Nothing
-- | Request a new access token using the JWT bearer grant
-- (@urn:ietf:params:oauth:grant-type:jwt-bearer@).
--
-- The assertion is signed with 'sign' at time @now@. It is then POSTed as a
-- form-encoded body ('HTTP.urlEncodedBody') to the token endpoint, with an
-- @Accept: application/json@ header.
--
-- Failures are reported as follows:
--
-- * signing fails: 'JWTGrantError';
-- * the endpoint cannot be parsed into a request: 'EndpointGrantError';
-- * a 200 body cannot be decoded: 'SerialisationGrantError';
-- * any other status: 'StatusGrantError' with the code and the body.
--
-- Exceptions thrown by the HTTP client itself (for example connection
-- failures) are not caught here and propagate to the caller.
refresh :: UTCTime -> HTTP.Manager -> TokenEndpoint -> Claims -> JWK -> ExceptT GrantError IO Response
refresh now manager endpoint claims jwk = do
  assertion <- firstT JWTGrantError $
    sign now claims jwk
  req <- ExceptT . pure . first (EndpointGrantError . Text.pack . show) $
    HTTP.parseRequest (Text.unpack . getTokenEndpoint $ endpoint)
  res <- liftIO $ flip HTTP.httpLbs manager $
    HTTP.urlEncodedBody
      [ ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
      , ("assertion", getAssertion assertion)
      ]
      (req { HTTP.requestHeaders = [("Accept", "application/json")] })
  case HTTP.statusCode . HTTP.responseStatus $ res of
    200 ->
      ExceptT . pure . first SerialisationGrantError $
        Serial.response (HTTP.responseBody res)
    status ->
      ExceptT . pure . Left $
        StatusGrantError status (decodeLenient . LazyByteString.toStrict . HTTP.responseBody $ res)
  where
    -- The error body comes from an external server and may not be valid
    -- UTF-8. The strict 'Text.decodeUtf8' would throw in that case and mask
    -- the status error. Invalid bytes are replaced with U+FFFD instead, the
    -- same behaviour as 'Data.Text.Encoding.Error.lenientDecode'.
    decodeLenient = Text.decodeUtf8With (\_ _ -> Just '\xFFFD')
-- | Build, sign and compact-serialise the JWT bearer assertion.
--
-- The claims set contains:
--
-- * the issuer, the optional subject and a single audience;
-- * @iat = now@ and @exp = now + expires@;
-- * an unregistered @scope@ claim holding the scopes joined by single
--   spaces.
--
-- The @custom@ entries are inserted after @scope@ via 'HashMap.fromList',
-- so a custom @scope@ entry replaces the generated one. The JWS header
-- selects RS256 and carries @typ: JWT@, and @jwk@ is the signing key.
sign :: UTCTime -> Claims -> JWK -> ExceptT JWTError IO Assertion
sign now (Claims issuer subject audience scopes expires custom) jwk =
  Assertion . LazyByteString.toStrict . JWT.encodeCompact
    <$> JWT.signClaims jwk jwsHeader claimsSet
  where
    textClaim =
      fromString . Text.unpack
    jwsHeader =
      JWT.newJWSHeader ((), JWT.RS256)
        & JWT.typ .~ Just (JWT.HeaderParam () "JWT")
    scopeClaim =
      ("scope", Aeson.toJSON . Text.intercalate " " $ getScope <$> scopes)
    claimsSet =
      JWT.emptyClaimsSet
        & JWT.claimIss .~ Just (textClaim . getIssuer $ issuer)
        & JWT.claimSub .~ fmap (textClaim . getSubject) subject
        & JWT.claimAud .~ Just (JWT.Audience [textClaim . getAudience $ audience])
        & JWT.claimIat .~ Just (JWT.NumericDate now)
        & JWT.claimExp .~ Just (JWT.NumericDate $ Time.addUTCTime (getExpiresIn expires) now)
        & JWT.unregisteredClaims .~ HashMap.fromList (scopeClaim : custom)