{-# LANGUAGE CPP #-}
module Network.Wai.Middleware.ForceSSL
( forceSSL
) where
import Network.Wai
import Network.Wai.Request
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative ((<$>))
import Data.Monoid (mempty)
#endif
import Data.Monoid ((<>))
import Network.HTTP.Types (hLocation, methodGet, status301, status307)
-- | Middleware that redirects requests which do not appear to have arrived
-- over TLS to the same URL under @https://@.
--
-- Whether a request is secure is decided by 'appearsSecure' (from
-- "Network.Wai.Request").  Secure requests are passed to the wrapped
-- application unchanged, as are insecure requests for which no redirect
-- target can be built (for example because the request has no @Host@
-- header).
forceSSL :: Middleware
forceSSL app req sendResponse
  | appearsSecure req = app req sendResponse
  -- Insecure request: answer with a redirect when a target URL can be formed.
  | Just resp <- redirectResponse req = sendResponse resp
  -- No redirect possible: fall through to the wrapped application.
  | otherwise = app req sendResponse
-- | Build a redirect to the @https://@ version of the requested URL.
--
-- Returns 'Nothing' when the request carries no @Host@ header, since the
-- absolute target URL cannot be formed without it.
--
-- The target is @https://@ followed by the Host header value, the raw path
-- and the raw query string of the original request.
-- NOTE(review): the Host value is copied verbatim, so an explicit port in it
-- (e.g. @example.com:8080@) is carried over into the https URL.
redirectResponse :: Request -> Maybe Response
redirectResponse req = mkRedirect <$> requestHeaderHost req
  where
    -- Empty-bodied response whose Location header points at the https URL.
    mkRedirect host = responseBuilder status [(hLocation, location host)] mempty

    -- The "https://" literal is used as a ByteString, which relies on
    -- OverloadedStrings being enabled for this module (presumably via the
    -- package's default-extensions; not visible here, TODO confirm).
    location h = "https://" <> h <> rawPathInfo req <> rawQueryString req

    -- GET is redirected permanently (301).  Every other method gets 307,
    -- which obliges the client to repeat the request with the same method
    -- and body; under 301 clients may turn e.g. a POST into a GET.
    status
      | requestMethod req == methodGet = status301
      | otherwise = status307