{-|
Module      : Network.Wai.Middleware.BearerTokenAuth
Description : Implements HTTP Bearer Token Authentication.
Copyright   : (c) Martin Bednar, 2022
License     : MIT
Maintainer  : bednam17@fit.cvut.cz
Stability   : experimental
Portability : POSIX

Implements Bearer Token Authentication as a WAI 'Middleware'.

This module is based on "Network.Wai.Middleware.HttpAuth" from the @wai-extra@ package.

-}
{-# LANGUAGE OverloadedStrings #-}

-- The implementation is based on 'Network.Wai.Middleware.HttpAuth'.

module Network.Wai.Middleware.BearerTokenAuth
  ( -- * Middleware
    --
    -- | You can choose from three functions to use this middleware:
    --
    -- 1. 'tokenListAuth' is the simplest to use and accepts a list of valid tokens;
    --
    -- 2. 'tokenAuth' can be used to perform a more sophisticated validation of the accepted token (such as database lookup);
    --
    -- 3. 'tokenAuth'' is similar to 'tokenAuth', but it also passes the 'Request' to the validation function.
    tokenListAuth
  , tokenAuth
  , tokenAuth'
    -- * Token validation
  , TokenValidator
  ) where

import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import Data.Word8 (isSpace, toLower)
import Network.HTTP.Types (hAuthorization, hContentType, status401)
import Network.Wai (Middleware, Request(requestHeaders), Response, responseLBS)

-- | An action that decides whether a bearer token is acceptable.
--
-- It receives the token extracted from the request's @Authorization@
-- header and returns 'True' to let the request through, or 'False' to
-- reject it with a @401 Unauthorized@ response. Because it runs in 'IO',
-- it can perform effectful checks such as a database lookup.
type TokenValidator = ByteString -> IO Bool

-- | Perform token authentication based on a fixed list of allowed tokens.
--
-- A request is accepted when the bearer token taken from its
-- @Authorization@ header is equal to one of the given tokens. Otherwise a
-- @401 Unauthorized@ response is sent and the wrapped application is not run.
--
-- > tokenListAuth ["secret1", "secret2"]
--
-- NOTE(review): tokens are compared with ordinary 'ByteString' equality
-- via 'elem', which is not a constant-time comparison. If timing side
-- channels are a concern, pass a constant-time check to 'tokenAuth' instead.
tokenListAuth
  :: [ByteString] -- ^ Tokens that are accepted
  -> Middleware
tokenListAuth tokens = tokenAuth (\tok -> pure (tok `elem` tokens))

-- | Performs token authentication with a custom validator.
--
-- The bearer token is taken from the request's @Authorization@ header.
-- If the token is accepted, the request is passed to the wrapped
-- application unchanged. If the header is missing, does not use the
-- @Bearer@ scheme, or the validator returns 'False', a
-- @401 Unauthorized@ response is sent instead.
--
-- > tokenAuth (\tok -> return $ tok == "abcd")
tokenAuth
  :: TokenValidator -- ^ Function that determines whether the token is valid
  -> Middleware
tokenAuth checker = tokenAuth' (const checker)

-- | Like 'tokenAuth', but also passes the 'Request' to the validator
-- function, so the decision can depend on the request (for example its
-- path or other headers).
--
-- The validator is consulted only when the request carries a @Bearer@
-- token. A request without one is rejected with @401 Unauthorized@
-- directly.
tokenAuth'
  :: (Request -> TokenValidator) -- ^ Function that determines whether the token is valid for the given request
  -> Middleware
tokenAuth' checkByReq app req sendRes = do
  authorized <- check (checkByReq req) req
  if authorized
    then app req sendRes -- authenticated: run the wrapped application
    else sendRes rspUnauthorized -- rejected: the application is never invoked

-- | Extract the bearer token from the request and run the validator on it.
--
-- Returns 'False' without calling the validator when the request has no
-- usable bearer token.
check :: TokenValidator -> Request -> IO Bool
check checkCreds req =
  maybe (pure False) checkCreds (extractBearerFromRequest req)

-- | The @401 Unauthorized@ response sent when authentication fails.
--
-- It carries a @WWW-Authenticate: Bearer@ challenge header, which tells
-- the client which authentication scheme is expected, and a short
-- plain-text body.
rspUnauthorized :: Response
rspUnauthorized =
  responseLBS
    status401
    [ (hContentType, "text/plain")
    , ("WWW-Authenticate", "Bearer")
    ]
    "Bearer token authentication is required"

-- | Look up the @Authorization@ header of a request and extract the bearer
-- token from its value.
--
-- Returns 'Nothing' when the header is absent or 'extractBearerAuth'
-- rejects its value.
extractBearerFromRequest :: Request -> Maybe ByteString
extractBearerFromRequest req =
  lookup hAuthorization (requestHeaders req) >>= extractBearerAuth

-- | Extract the bearer token from an __Authorization__ header value.
--
-- The value is split at the first whitespace character. The part before
-- it must equal @bearer@ (compared case-insensitively), and the whitespace
-- after it is skipped. Returns 'Nothing' if the scheme is not @Bearer@, or
-- if no token follows the scheme. An empty credential is therefore never
-- passed to a validator (RFC 6750 requires a non-empty token).
--
-- > extractBearerAuth "Bearer abc123"      == Just "abc123"
-- > extractBearerAuth "bEaReR   abc123"    == Just "abc123"
-- > extractBearerAuth "Bearer"             == Nothing
-- > extractBearerAuth "Bearer   "          == Nothing
-- > extractBearerAuth "Basic dXNlcjpwYXNz" == Nothing
--
-- Adapted from: https://hackage.haskell.org/package/wai-extra-3.1.11/docs/Network-Wai-Middleware-HttpAuth.html
extractBearerAuth :: ByteString -> Maybe ByteString
extractBearerAuth bs
  | S.map toLower scheme /= "bearer" = Nothing
  | S.null token = Nothing -- "Bearer" with no credential after it
  | otherwise = Just token
  where
    (scheme, rest) = S.break isSpace bs
    token = S.dropWhile isSpace rest