{-# LANGUAGE OverloadedStrings #-}
{-|
    Module: Web.OIDC.Client.CodeFlow
    Maintainer: krdlab@gmail.com
    Stability: experimental
-}
module Web.OIDC.Client.CodeFlow
    (
      getAuthenticationRequestUrl
    , requestTokens

    -- * For testing
    , validateClaims
    , getCurrentIntDate
    ) where

import Control.Applicative ((<$>))
import Control.Monad (unless)
import Control.Monad.Catch (MonadThrow, throwM, MonadCatch, catch)
import Data.Aeson (decode)
import qualified Data.ByteString.Char8 as B
import Data.List (nub)
import Data.Text (Text)
import Data.Text.Encoding (decodeUtf8)
import Data.Time.Clock.POSIX (getPOSIXTime)
import Jose.Jwt (Jwt)
import qualified Jose.Jwt as Jwt
import Network.HTTP.Client (getUri, setQueryString, applyBasicAuth, urlEncodedBody, Request(..), Manager, httpLbs, responseBody)
import Network.URI (URI)

import Web.OIDC.Client.Settings (OIDC(..))
import qualified Web.OIDC.Client.Discovery.Provider as P
import qualified Web.OIDC.Client.Internal as I
import Web.OIDC.Client.Internal (parseUrl)
import Web.OIDC.Client.Tokens (Tokens(..), IdToken(..))
import Web.OIDC.Client.Types (Scope, ScopeValue(..), Code, State, Parameters, OpenIdException(..))

-- | Make URL for Authorization Request.
getAuthenticationRequestUrl
    :: (MonadThrow m, MonadCatch m)
    => OIDC
    -> Scope            -- ^ used to specify what are privileges requested for tokens. (use `ScopeValue`)
    -> Maybe State      -- ^ used for CSRF mitigation. (recommended parameter)
    -> Parameters       -- ^ Optional parameters
    -> m URI
getAuthenticationRequestUrl oidc scope state params = do
    req <- parseUrl endpoint `catch` I.rethrow
    return $ getUri $ setQueryString query req
  where
    endpoint  = oidcAuthorizationSeverUrl oidc
    query     = requireds ++ state' ++ params
    requireds =
        [ ("response_type", Just "code")
        , ("client_id",     Just $ oidcClientId oidc)
        , ("redirect_uri",  Just $ oidcRedirectUri oidc)
        , ("scope",         Just $ B.pack . unwords . nub . map show $ OpenId:scope)
        ]
    state' =
        case state of
            Just _  -> [("state", state)]
            Nothing -> []

-- TODO: error response

-- | Request and validate tokens.
--
-- This function requests ID Token and Access Token to a OP's token endpoint, and validates the received ID Token.
-- Returned `Tokens` value is a valid.
--
-- If a HTTP error has occurred or a tokens validation has failed, this function throws `OpenIdException`.
requestTokens :: OIDC -> Code -> Manager -> IO Tokens
requestTokens oidc code manager = do
    json <- getTokensJson `catch` I.rethrow
    case decode json of
        Just ts -> validate oidc ts
        Nothing -> error "failed to decode tokens json" -- TODO
  where
    getTokensJson = do
        req <- parseUrl endpoint
        let req' = applyBasicAuth cid sec $ urlEncodedBody body $ req { method = "POST" }
        res <- httpLbs req' manager
        return $ responseBody res
    endpoint = oidcTokenEndpoint oidc
    cid      = oidcClientId oidc
    sec      = oidcClientSecret oidc
    redirect = oidcRedirectUri oidc
    body     =
        [ ("grant_type",   "authorization_code")
        , ("code",         code)
        , ("redirect_uri", redirect)
        ]

validate :: OIDC -> I.TokensResponse -> IO Tokens
validate oidc tres = do
    let jwt' = I.idToken tres
    validateIdToken oidc jwt'
    claims' <- getClaims jwt'
    now <- getCurrentIntDate
    validateClaims
        (P.issuer . P.configuration . oidcProvider $ oidc)
        (decodeUtf8 . oidcClientId $ oidc)
        now
        claims'
    return Tokens {
          accessToken  = I.accessToken tres
        , tokenType    = I.tokenType tres
        , idToken      = IdToken { claims = I.toIdTokenClaims claims', jwt = jwt' }
        , expiresIn    = I.expiresIn tres
        , refreshToken = I.refreshToken tres
        }

validateIdToken :: OIDC -> Jwt -> IO ()
validateIdToken oidc jwt' = do
    let jwks = P.jwkSet . oidcProvider $ oidc
        token = Jwt.unJwt jwt'
    decoded <- Jwt.decode jwks Nothing token
    case decoded of
        Right _  -> return ()
        Left err -> throwM $ JwtExceptoin err

getClaims :: MonadThrow m => Jwt -> m Jwt.JwtClaims
getClaims jwt' = case Jwt.decodeClaims (Jwt.unJwt jwt') of
                Right (_, c) -> return c
                Left  cause  -> throwM $ JwtExceptoin cause

validateClaims :: Text -> Text -> Jwt.IntDate -> Jwt.JwtClaims -> IO ()
validateClaims issuer' clientId' now claims' = do
    iss' <- getIss claims'
    unless (iss' == issuer')
        $ throwM $ ValidationException "issuer"

    aud' <- getAud claims'
    unless (clientId' `elem` aud')
        $ throwM $ ValidationException "audience"

    exp' <- getExp claims'
    unless (now < exp')
        $ throwM $ ValidationException "expire"
  where
    getIss c = get Jwt.jwtIss c "'iss' claim was not found"
    getAud c = get Jwt.jwtAud c "'aud' claim was not found"
    getExp c = get Jwt.jwtExp c "'exp' claim was not found"
    get f v msg = case f v of
        Just v' -> return v'
        Nothing -> throwM $ ValidationException msg

getCurrentIntDate :: IO Jwt.IntDate
getCurrentIntDate = Jwt.IntDate <$> getPOSIXTime