{-# LANGUAGE DerivingVia, DeriveAnyClass, ScopedTypeVariables, ViewPatterns #-}
-- | Middlewares for handling Twirp error responses.
module Twirp.Middleware.Errors where

import Data.Aeson
import GHC.Generics
import Network.HTTP.Types
import Network.Wai

import qualified Data.ByteString as BS

-- | The body of a Twirp error response, sent JSON-encoded.
--
-- The 'FromJSON' / 'ToJSON' instances are the 'Generic' defaults, so the
-- record field names below are also the JSON object keys.
--
-- See: https://github.com/twitchtv/twirp/blob/master/docs/errors.md
data TwirpError = TwirpError
  { code :: String
    -- ^ Twirp error code, e.g. @"invalid_argument"@.
  , msg  :: String
    -- ^ Human-readable description of the error.
  }
  deriving stock (Generic, Show, Eq)
  deriving anyclass (ToJSON, FromJSON)

-- | Rewrite error responses to use Twirp's error codes and JSON encoding
-- when they don't already fit that model.
--
-- A response is rewritten only when its status is 4xx or 5xx /and/ its
-- @Content-Type@ does not start with @application/json@. The original
-- status is kept; the body is replaced with a JSON-encoded 'TwirpError'
-- whose code is chosen from the status (unmapped statuses become
-- @unknown@).
--
-- Headers describing the old body (@Content-Type@, @Content-Length@,
-- @Content-Encoding@, @Transfer-Encoding@) are dropped and a JSON
-- @Content-Type@ is set. All other original headers (e.g. CORS headers,
-- @WWW-Authenticate@, @Retry-After@) are preserved.
twirpErrorResponses :: Middleware
twirpErrorResponses = modifyResponse $ \response ->
  if nonTwirpError response
    then rewrite response
    else response
  where
    -- Replace the body of a non-Twirp error response, keeping its status
    -- and every header that does not describe the discarded body.
    rewrite :: Response -> Response
    rewrite r =
      responseLBS status (jsonContentType : keptHeaders) (encode err)
      where
        status = responseStatus r
        err = errorFor (statusCode status)
        keptHeaders =
          filter ((`notElem` bodyHeaders) . fst) (responseHeaders r)

    -- An error status whose body is not already JSON.
    nonTwirpError :: Response -> Bool
    nonTwirpError r = isError r && not (isJson r)

    isError :: Response -> Bool
    isError (responseStatus -> s) =
      statusIsClientError s || statusIsServerError s

    -- Prefix match so parameters like "; charset=utf-8" are accepted.
    isJson :: Response -> Bool
    isJson (responseHeaders -> hs) =
      maybe False ("application/json" `BS.isPrefixOf`) $ lookup hContentType hs

    jsonContentType :: Header
    jsonContentType = (hContentType, "application/json; charset=utf-8")

    -- Headers that would be wrong once the body is replaced. Content-Length
    -- in particular must go, or servers may trust the stale length.
    bodyHeaders :: [HeaderName]
    bodyHeaders =
      [hContentType, hContentLength, hContentEncoding, hTransferEncoding]

    -- Map an HTTP status code to a Twirp error. The 401/403/429/502/503/504
    -- mappings follow how Twirp clients interpret errors from HTTP
    -- intermediaries; 501 follows the spec's error-code table.
    errorFor :: Int -> TwirpError
    errorFor 400 = TwirpError "invalid_argument" "Bad Request"
    errorFor 401 = TwirpError "unauthenticated" "Unauthorized"
    errorFor 403 = TwirpError "permission_denied" "Forbidden"
    errorFor 404 = TwirpError "not_found" "Not found"
    errorFor 408 = TwirpError "canceled" "Request Timeout"
    errorFor 429 = TwirpError "resource_exhausted" "Too Many Requests"
    errorFor 500 = TwirpError "internal" "Internal Server Error"
    errorFor 501 = TwirpError "unimplemented" "Not Implemented"
    errorFor 502 = TwirpError "unavailable" "Bad Gateway"
    errorFor 503 = TwirpError "unavailable" "Service Unavailable"
    errorFor 504 = TwirpError "unavailable" "Gateway Timeout"
    errorFor _   = TwirpError "unknown" "Unknown"