{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}

-- | Servant server authentication.
module Servant.Auth.Hmac.Server (
    HmacAuth,
    HmacAuthContextHandlers,
    HmacAuthContext,
    HmacAuthHandler,
    hmacAuthServerContext,
    hmacAuthHandler,
    hmacAuthHandlerMap,
) where

import Control.Monad.Except (throwError)
import Data.ByteString (ByteString)
import Data.Maybe (fromMaybe)
import Network.Wai (rawPathInfo, rawQueryString, requestHeaderHost, requestHeaders, requestMethod)
import Servant (Context (EmptyContext, (:.)))
import Servant.API (AuthProtect)
import Servant.Server (Handler, err401, errBody)
import Servant.Server.Experimental.Auth (AuthHandler, AuthServerData, mkAuthHandler)

import Servant.Auth.Hmac.Crypto (
    RequestPayload (..),
    SecretKey,
    Signature,
    keepWhitelistedHeaders,
    verifySignatureHmac,
 )

import qualified Network.Wai as Wai (Request)

-- | Servant combinator marking endpoints that require a valid HMAC signature.
type HmacAuth = AuthProtect "hmac-auth"

-- | A successful HMAC check yields no user value, just @()@.
type instance AuthServerData HmacAuth = ()

-- | Handler servant runs on the raw 'Wai.Request' of an 'HmacAuth' endpoint.
type HmacAuthHandler = AuthHandler Wai.Request ()

-- | Type-level list of the handlers stored in the server 'Context'.
type HmacAuthContextHandlers = '[HmacAuthHandler]

-- | Server 'Context' holding exactly one 'HmacAuthHandler'.
type HmacAuthContext = Context HmacAuthContextHandlers

{- | Build the servant server 'Context' required by 'HmacAuth'-protected APIs.

The context contains a single handler created by 'hmacAuthHandler' from the
given signing function and secret key.
-}
hmacAuthServerContext ::
    -- | Signing function
    (SecretKey -> ByteString -> Signature) ->
    -- | Secret key that was used for signing 'Request'
    SecretKey ->
    HmacAuthContext
hmacAuthServerContext signer sk = hmacAuthHandler signer sk :. EmptyContext

{- | Create 'HmacAuthHandler' from signing function and secret key.

This is 'hmacAuthHandlerMap' with 'pure' as the request mapper, so the
incoming 'Wai.Request' is verified without modification.
-}
hmacAuthHandler ::
    -- | Signing function
    (SecretKey -> ByteString -> Signature) ->
    -- | Secret key that was used for signing 'Request'
    SecretKey ->
    HmacAuthHandler
hmacAuthHandler = hmacAuthHandlerMap pure

{- | Like 'hmacAuthHandler', but allows specifying an additional mapping
function for 'Wai.Request'. This can be useful if you want to print the
incoming request (for logging purposes) or filter some headers (to match the
signature). The given function is applied before signature verification.

The request is accepted when 'verifySignatureHmac' returns 'Nothing'. When it
returns @Just msg@, the handler fails with 'err401' whose body is @msg@.
-}
hmacAuthHandlerMap ::
    -- | Request mapper
    (Wai.Request -> Handler Wai.Request) ->
    -- | Signing function
    (SecretKey -> ByteString -> Signature) ->
    -- | Secret key that was used for signing 'Request'
    SecretKey ->
    HmacAuthHandler
hmacAuthHandlerMap mapper signer sk = mkAuthHandler handler
  where
    handler :: Wai.Request -> Handler ()
    handler req = do
        newReq <- mapper req
        -- Verify the *mapped* request, so the mapper can e.g. drop headers
        -- that were never part of the signature.
        let payload = waiRequestToPayload newReq
        case verifySignatureHmac signer sk payload of
            Nothing -> pure ()
            Just msg -> throwError $ err401{errBody = msg}

----------------------------------------------------------------------------
-- Internals
----------------------------------------------------------------------------

-- getWaiRequestBody :: Wai.Request -> IO ByteString
-- getWaiRequestBody request = BS.concat <$> getChunks
--   where
--     getChunks :: IO [ByteString]
--     getChunks = requestBody request >>= \chunk ->
--         if chunk == BS.empty
--         then pure []
--         else (chunk:) <$> getChunks

{- | Convert a 'Wai.Request' into the 'RequestPayload' that is checked against
the signature.

* 'rpMethod' is the request method.
* 'rpHeaders' keeps only the headers accepted by 'keepWhitelistedHeaders'.
* 'rpRawUrl' is the @Host@ header (empty when absent), followed by the raw path
  and then the raw query string.

NOTE(review): the request body is never read here, so 'rpContent' is always
empty and the body is not part of what gets verified. Confirm that clients
also sign an empty body, or that this limitation is intended.
-}
waiRequestToPayload :: Wai.Request -> RequestPayload
waiRequestToPayload req =
    RequestPayload
        { rpMethod = requestMethod req
        , -- 'mempty' instead of a string literal: no reliance on OverloadedStrings.
          rpContent = mempty
        , rpHeaders = keepWhitelistedHeaders $ requestHeaders req
        , rpRawUrl = host <> rawPathInfo req <> rawQueryString req
        }
  where
    host = fromMaybe mempty (requestHeaderHost req)