-- | Types and functions for testing @wai@ endpoints using the @tasty@ testing
-- framework.
--
-- A test is created with 'testWai', which runs a 'Session' (written with the
-- helpers below and the re-exported "Network.Wai.Test" API) against an
-- 'Application' and reports the outcome to @tasty@.
--
module Test.Tasty.Wai
  (
    -- * Types
    Sess (..)

    -- * Creation
  , testWai

    -- * Helpers
  , get
  , head
  , post
  , postWithHeaders
  , put
  , assertStatus'

    -- * Request Builders
  , buildRequest
  , buildRequestWithBody
  , buildRequestWithHeaders

    -- * Re-exported from wai-extra
  , module Network.Wai.Test
  ) where

import qualified Control.Exception    as E
import           Prelude              hiding (head)

import qualified Data.ByteString      as BS
import qualified Data.ByteString.Lazy as LBS
import           Data.Monoid          ((<>))

import           Network.HTTP.Types   (RequestHeaders, StdMethod)
import qualified Network.HTTP.Types   as HTTP

import           Test.HUnit.Lang      (HUnitFailure (HUnitFailure),
                                       formatFailureReason)

import           Test.Tasty.Providers (IsTest (..), Progress (..), TestName,
                                       TestTree, singleTest, testFailed,
                                       testPassed)
import           Test.Tasty.Runners   (formatMessage)

import           Network.Wai          (Application, Request, requestHeaders,
                                       requestMethod)
import           Network.Wai.Test

-- | Everything needed to execute a single @wai@ test case: the application
-- under test, the name of the test, and the session to run against the
-- application.
data Sess
  = S
      Application   -- ^ The application the session is run against.
      TestName      -- ^ Name of the test, used in the progress message.
      (Session ())  -- ^ The requests and assertions to execute.

instance IsTest Sess where
  -- This provider registers no options of its own.
  testOptions = mempty

  run _opts (S app name session) reportProgress = do
    -- There is no incremental progress to report, so announce once that the
    -- test is running and then just execute it.
    reportProgress (Progress ("Running " <> name) 0)

    -- The wai-extra assertions signal failure by throwing; only
    -- 'HUnitFailure' is caught here, so any other exception propagates out of
    -- 'run'. The value produced by the session itself is not inspected.
    outcome <- E.try (runSession session app)
    case outcome of
      Left (HUnitFailure _location reason) ->
        testFailed <$> formatMessage (formatFailureReason reason)
      Right _ ->
        pure (testPassed mempty)

-- | Build a 'Request' for the given HTTP method and route.
--
-- Starts from 'defaultRequest', sets 'requestMethod' from the 'StdMethod',
-- and then applies the route with 'setPath'.
buildRequest
  :: StdMethod      -- ^ HTTP method, e.g. 'HTTP.GET'.
  -> BS.ByteString  -- ^ Route to request.
  -> Request
buildRequest mth rpath = setPath withMethod rpath
  where
    withMethod = defaultRequest { requestMethod = HTTP.renderStdMethod mth }

-- | Like 'buildRequest', but pairs the request with body content to form an
-- 'SRequest'.
buildRequestWithBody
  :: StdMethod       -- ^ HTTP method.
  -> BS.ByteString   -- ^ Route to request.
  -> LBS.ByteString  -- ^ Body content.
  -> SRequest
buildRequestWithBody mth rpath body = SRequest (buildRequest mth rpath) body

-- | Like 'buildRequestWithBody', but also sets the request's
-- 'RequestHeaders'.
--
-- The given headers replace the request's header list wholesale rather than
-- being appended to it.
buildRequestWithHeaders
  :: StdMethod       -- ^ HTTP method.
  -> BS.ByteString   -- ^ Route to request.
  -> LBS.ByteString  -- ^ Body content.
  -> RequestHeaders  -- ^ Headers to set on the request.
  -> SRequest
buildRequestWithHeaders mthd pth bdy hdrs =
  let sreq  = buildRequestWithBody mthd pth bdy
      inner = simpleRequest sreq
   in sreq { simpleRequest = inner { requestHeaders = hdrs } }

-- | Run a test case against an 'Application'.
--
-- The 'Session' executed against the application is written with the helpers
-- in this module together with the functions re-exported from
-- "Network.Wai.Test" (@wai-extra@).
--
-- A small test case may look like:
--
-- @
-- import MyApp (app)
--
-- testWai app "List Topics" $ do
--       res <- get "fudge/view"
--       assertStatus' HTTP.status200 res
-- @
--
testWai :: Application -> TestName -> Session () -> TestTree
testWai a tn sess = singleTest tn (S a tn sess)

-- | Send a 'HTTP.HEAD' request for the given route within the current
-- 'Session' and return the response.
head :: BS.ByteString -> Session SResponse
head route = request (buildRequest HTTP.HEAD route)

-- | Send a 'HTTP.GET' request for the given route within the current
-- 'Session' and return the response.
get :: BS.ByteString -> Session SResponse
get route = request (buildRequest HTTP.GET route)

-- | Send a 'HTTP.POST' request for the given route, using the provided
-- 'LBS.ByteString' as the body content, and return the response.
post :: BS.ByteString -> LBS.ByteString -> Session SResponse
post r body = srequest (buildRequestWithBody HTTP.POST r body)

-- | Like 'post', but also sets the given 'RequestHeaders' on the request.
--
-- The headers replace the request's header list (see
-- 'buildRequestWithHeaders').
postWithHeaders
  :: BS.ByteString   -- ^ Route to request.
  -> LBS.ByteString  -- ^ Body content.
  -> RequestHeaders  -- ^ Headers to send.
  -> Session SResponse
postWithHeaders path body =
  srequest . buildRequestWithHeaders HTTP.POST path body

-- | Send a 'HTTP.PUT' request for the given route, using the provided
-- 'LBS.ByteString' as the body content, and return the response.
put :: BS.ByteString -> LBS.ByteString -> Session SResponse
put r body = srequest (buildRequestWithBody HTTP.PUT r body)

-- | Check a response's status code against an 'HTTP.Status' value such as
-- 'HTTP.status200', instead of a bare number as 'assertStatus' requires.
--
-- Only the numeric 'HTTP.statusCode' is compared; the status message is not.
assertStatus' :: HTTP.Status -> SResponse -> Session ()
assertStatus' = assertStatus . HTTP.statusCode