{-# LANGUAGE CPP #-}
-- | Redirect non-SSL requests to https
--
-- Since 3.0.7
module  Network.Wai.Middleware.ForceSSL
    ( forceSSL
    ) where

import Network.Wai
import Network.Wai.Request

#if __GLASGOW_HASKELL__ < 710
import Control.Applicative ((<$>))
import Data.Monoid (mempty)
#endif

import Data.Monoid ((<>))
import Network.HTTP.Types (hLocation, methodGet, status301, status307)

-- | For requests that don't appear secure, redirect to https
--
-- Requests that 'appearsSecure' accepts are handed to the wrapped
-- application untouched. Any other request is answered directly with the
-- redirect built by 'redirectResponse'. If no redirect can be built (the
-- request has no host value), the request is passed through to the wrapped
-- application rather than rejected.
--
-- Since 3.0.7
forceSSL :: Middleware
forceSSL app req sendResponse
    | appearsSecure req = passThrough
    -- Only insecure requests pay for building the redirect.
    | otherwise = maybe passThrough sendResponse (redirectResponse req)
  where
    -- Delegate to the wrapped application with the original request.
    passThrough = app req sendResponse

-- | Build a response redirecting @req@ to the same resource over HTTPS.
--
-- Returns 'Nothing' when 'requestHeaderHost' is 'Nothing', because then
-- there is no host to put in the target URL. Otherwise the response has an
-- empty body and a @Location@ header of
-- @https://\<host\>\<raw path\>\<raw query string\>@. Its status is 301 for
-- @GET@ requests and 307 for every other method.
--
-- NOTE(review): the host value is copied verbatim, so any port it carries
-- (e.g. @example.com:8080@) also ends up in the https URL. Confirm this is
-- intended for deployments on non-default ports.
redirectResponse :: Request -> Maybe Response
redirectResponse req = mkRedirect <$> requestHeaderHost req
  where
    -- Assemble the redirect for a known host; the body is left empty.
    mkRedirect host =
        responseBuilder redirectStatus [(hLocation, httpsUrl host)] mempty

    -- Same host, raw path and raw query string, with the scheme swapped.
    httpsUrl host = "https://" <> host <> (rawPathInfo req <> rawQueryString req)

    -- Non-GET methods get 307. Presumably this is because 307 requires
    -- clients to repeat the original method and body, whereas clients may
    -- re-issue a 301 as a GET.
    redirectStatus
        | requestMethod req /= methodGet = status307
        | otherwise = status301