{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-deprecations #-}
module Network.Wai.Handler.Warp.Run where
import "iproute" Data.IP (toHostAddress, toHostAddress6)
import Control.Arrow (first)
import qualified Control.Concurrent as Conc (yield)
import Control.Exception as E
import qualified Data.ByteString as S
import Data.Char (chr)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.Streaming.Network (bindPortTCP)
import Foreign.C.Error (Errno(..), eCONNABORTED)
import GHC.IO.Exception (IOException(..))
import qualified Network.HTTP2 as H2
import Network.Socket (Socket, close, accept, withSocketsDo, SockAddr(SockAddrInet, SockAddrInet6), setSocketOption, SocketOption(..))
#if MIN_VERSION_network(3,1,1)
import Network.Socket (gracefulClose)
#endif
import qualified Network.Socket.ByteString as Sock
import Network.Wai
import Network.Wai.Internal (ResponseReceived (ResponseReceived))
import System.Environment (lookupEnv)
import qualified System.TimeManager as T
import System.Timeout (timeout)
import Network.Wai.Handler.Warp.Buffer
import Network.Wai.Handler.Warp.Counter
import qualified Network.Wai.Handler.Warp.Date as D
import qualified Network.Wai.Handler.Warp.FdCache as F
import qualified Network.Wai.Handler.Warp.FileInfoCache as I
import Network.Wai.Handler.Warp.HTTP2 (http2)
import Network.Wai.Handler.Warp.HTTP2.Types (isHTTP2)
import Network.Wai.Handler.Warp.Header
import Network.Wai.Handler.Warp.Imports hiding (readInt)
import Network.Wai.Handler.Warp.ReadInt
import Network.Wai.Handler.Warp.Recv
import Network.Wai.Handler.Warp.Request
import Network.Wai.Handler.Warp.Response
import Network.Wai.Handler.Warp.SendFile
import Network.Wai.Handler.Warp.Settings
import Network.Wai.Handler.Warp.Types
#if WINDOWS
import Network.Wai.Handler.Warp.Windows
#else
import Network.Socket (fdSocket)
#endif
-- | Build a 'Connection' backed by an accepted 'Socket'.
--
-- Each connection gets its own receive buffer pool (from 'newBufferPool')
-- and a write buffer of 'bufferSize' bytes, which 'connFree' releases.
-- The 'connHTTP2' flag starts out 'False'. When built against
-- network >= 3.1.1, 'connClose' reads that flag to choose the
-- graceful-close timeout: 'settingsGracefulCloseTimeout2' for HTTP/2 and
-- 'settingsGracefulCloseTimeout1' otherwise. A timeout of 0 means a plain
-- 'close'. Any exception raised by 'gracefulClose' is deliberately ignored.
socketConnection :: Settings -> Socket -> IO Connection
socketConnection set s = do
    pool <- newBufferPool
    buf <- allocateBuffer bufferSize
    h2Flag <- newIORef False
    let sendAllBytes = Sock.sendAll s
#if MIN_VERSION_network(3,1,1)
        closeSocket = do
            speaksH2 <- readIORef h2Flag
            let closeTimeout
                    | speaksH2 = settingsGracefulCloseTimeout2 set
                    | otherwise = settingsGracefulCloseTimeout1 set
            if closeTimeout == 0
                then close s
                else gracefulClose s closeTimeout `E.catch` \(E.SomeException _) -> return ()
#else
        closeSocket = close s
#endif
    return Connection
        { connSendMany = Sock.sendMany s
        , connSendAll = sendAllBytes
        , connSendFile = sendFile s buf bufferSize sendAllBytes
        , connClose = closeSocket
        , connFree = freeBuffer buf
        , connRecv = receive s pool
        , connRecvBuf = receiveBuf s
        , connWriteBuffer = buf
        , connBufferSize = bufferSize
        , connHTTP2 = h2Flag
        }
-- | Run an 'Application' on the given port.
-- This calls 'runSettings' with 'defaultSettings', overriding only
-- 'settingsPort'.
run :: Port -> Application -> IO ()
run port app = runSettings portSettings app
  where
    portSettings = defaultSettings { settingsPort = port }
-- | Run an 'Application' on the port present in the @PORT@
-- environment variable. Uses the 'Port' given when the variable is unset.
-- This calls 'runSettings' with 'defaultSettings'.
--
-- The value of @PORT@ must be a number, optionally surrounded by
-- whitespace. Anything else (e.g. @8080abc@) makes this action fail with
-- an 'IOError' rather than silently using the numeric prefix.
--
-- Since 3.0.9
runEnv :: Port -> Application -> IO ()
runEnv p app = do
    mp <- lookupEnv "PORT"
    maybe (run p app) runReadPort mp
  where
    runReadPort :: String -> IO ()
    runReadPort sp =
        -- 'reads' alone accepts any unparsed tail. Requiring 'lex' to see
        -- nothing but whitespace after the number rejects trailing garbage.
        case [p' | (p', rest) <- reads sp, ("", "") <- lex rest] of
            (p':_) -> run p' app
            [] -> fail $ "Invalid value in $PORT: " ++ sp
-- | Run an 'Application' with the given 'Settings'.
-- This opens a listen socket on the port (and host) defined in 'Settings',
-- marks it close-on-exec, and calls 'runSettingsSocket'. The listen socket
-- is closed by 'bracket' when serving ends.
runSettings :: Settings -> Application -> IO ()
runSettings set app = withSocketsDo $ bracket openListener close serveOn
  where
    openListener = bindPortTCP (settingsPort set) (settingsHost set)
    serveOn sock = do
        setSocketCloseOnExec sock
        runSettingsSocket set sock app
-- | This installs a shutdown handler for the given socket and
-- calls 'runSettingsConnection' with the default connection setup action,
-- which handles plain (non-cipher) HTTP.
-- When the listen socket in the second argument is closed, the accept
-- loop stops and live connections are gracefully shut down: the server
-- waits for them to finish, bounded by 'settingsGracefulShutdownTimeout'
-- when that is set.
--
-- The supplied socket can be a Unix named socket, which
-- can be used when reverse HTTP proxying into your application.
--
-- Note that the 'settingsPort' will still be passed to 'Application's via the
-- 'serverPort' record.
runSettingsSocket :: Settings -> Socket -> Application -> IO ()
runSettingsSocket set socket app = do
    settingsInstallShutdownHandler set closeListenSocket
    runSettingsConnection set getConn app
  where
    getConn = do
#if WINDOWS
        (s, sa) <- windowsThreadBlockHack $ accept socket
#else
        (s, sa) <- accept socket
#endif
        -- Until 'setupConn' hands back a 'Connection', nothing else owns
        -- @s@. Close it if any setup step throws, so the descriptor is not
        -- leaked. The exception is then rethrown unchanged.
        conn <- setupConn s `E.onException` close s
        return (conn, sa)

    setupConn s = do
        setSocketCloseOnExec s
        -- TCP_NODELAY is a TCP-level option. Setting it can fail on
        -- non-TCP sockets, such as the Unix named sockets this function
        -- also accepts, so that failure is deliberately ignored.
        setSocketOption s NoDelay 1 `E.catch` \(E.SomeException _) -> return ()
        socketConnection set s

    closeListenSocket = close socket
-- | Like 'runSettingsConnectionMaker', but for a connection source that
-- produces a ready 'Connection'. Each accepted connection is wrapped in a
-- trivial @return@ so that it fits the "connection maker" shape.
runSettingsConnection :: Settings -> IO (Connection, SockAddr) -> Application -> IO ()
runSettingsConnection set getConn =
    -- The lambda's tuple pattern is matched when the action runs. This
    -- keeps the tuple forced at the same point as a do-bind would.
    runSettingsConnectionMaker set (getConn >>= \(conn, sa) -> return (return conn, sa))
-- | Like 'runSettingsConnectionMakerSecure', but every connection produced
-- by the maker is tagged with the plain 'TCP' transport.
runSettingsConnectionMaker :: Settings -> IO (IO Connection, SockAddr) -> Application -> IO ()
runSettingsConnectionMaker set getConnMaker =
    runSettingsConnectionMakerSecure set (fmap (first tagTCP) getConnMaker)
  where
    tagTCP mkConn = fmap (\conn -> (conn, TCP)) mkConn
-- | Core server loop entry point.
--
-- Runs 'settingsBeforeMainLoop', creates the live-connection counter, and
-- sets up the shared per-server resources bundled into 'InternalInfo':
--
-- * the timeout manager: 'settingsManager' if one is supplied, otherwise a
--   fresh one that is stopped when the server exits;
-- * the date cache;
-- * the fd cache;
-- * the file-info cache.
--
-- It then runs 'acceptConnection'. The duration settings are multiplied by
-- 1000000 before use, which presumably converts seconds to microseconds.
runSettingsConnectionMakerSecure :: Settings -> IO (IO (Connection, Transport), SockAddr) -> Application -> IO ()
runSettingsConnectionMakerSecure set getConnMaker app = do
    settingsBeforeMainLoop set
    counter <- newCounter
    withInternalInfo $ acceptConnection set getConnMaker app counter
  where
    !fdCacheMicros = settingsFdCacheDuration set * 1000000
    !fileInfoCacheMicros = settingsFileInfoCacheDuration set * 1000000
    !timeoutMicros = settingsTimeout set * 1000000

    withTimeoutManager k =
        maybe (bracket (T.initialize timeoutMicros) T.stopManager k) k (settingsManager set)

    withInternalInfo k =
        withTimeoutManager $ \tm ->
        D.withDateCache $ \dc ->
        F.withFdCache fdCacheMicros $ \fdc ->
        I.withFileInfoCache fileInfoCacheMicros $ \fic ->
        k (InternalInfo tm dc fdc fic)
-- | Accept connections until accepting fails, then shut down gracefully.
--
-- The accept loop runs under 'mask_'. 'allowInterrupt' at the top of each
-- iteration is where a pending asynchronous exception can be delivered.
-- Each accepted connection is handed to 'fork'.
--
-- An 'IOException' with errno @ECONNABORTED@ is treated as transient, and
-- accepting is retried. Any other 'IOException' is reported through
-- 'settingsOnException' and ends the loop. This presumably happens when
-- the listen socket is closed by the shutdown handler. After the loop ends,
-- 'gracefulShutdown' waits for live connections.
acceptConnection :: Settings
                 -> IO (IO (Connection, Transport), SockAddr)
                 -> Application
                 -> Counter
                 -> InternalInfo
                 -> IO ()
acceptConnection set getConnMaker app counter ii = do
    mask_ loop
    gracefulShutdown set counter
  where
    loop = do
        allowInterrupt
        next <- acceptNext
        maybe (return ()) spawnAndContinue next

    spawnAndContinue (mkConn, addr) = do
        fork set mkConn addr app counter ii
        loop

    acceptNext = try getConnMaker >>= either onAcceptError (return . Just)

    onAcceptError e
        | ioe_errno e == Just (errnoCode eCONNABORTED) = acceptNext
        | otherwise = do
            settingsOnException set Nothing (toException e)
            return Nothing

    errnoCode (Errno code) = code
-- | Handle one accepted connection on a thread started via 'settingsFork'.
--
-- 'settingsFork' passes an @unmask@ function to the callback. Presumably
-- the callback starts with asynchronous exceptions masked (TODO confirm
-- against 'settingsFork'). Under that assumption the steps are:
--
-- * @mkConn@ creates the connection inside 'bracket'. This guarantees that
--   @cleanUp@ runs however serving ends; it closes the connection and then
--   frees its write buffer.
-- * A timeout handle is registered whose action closes the connection. It
--   is cancelled when serving finishes.
-- * Only then are exceptions re-enabled via @unmask@ for the actual serving.
--
-- The 'Bool' returned by 'settingsOnOpen' decides whether the connection
-- is served at all. The live-connection 'Counter' is incremented before
-- 'settingsOnOpen' runs. It is decremented (together with
-- 'settingsOnClose') afterwards. Any exception that escapes is passed to
-- 'settingsOnException' with no request.
fork :: Settings
     -> IO (Connection, Transport)
     -> SockAddr
     -> Application
     -> Counter
     -> InternalInfo
     -> IO ()
fork set mkConn addr app counter ii = settingsFork set $ \unmask ->
    -- Report anything that escapes serving; no request is known here.
    handle (settingsOnException set Nothing) $
    bracket mkConn cleanUp (serve unmask)
  where
    -- Close first; freeing the write buffer runs even if closing throws.
    cleanUp (conn, _) = connClose conn `finally` connFree conn

    -- The timeout handle is registered (and later cancelled) before
    -- exceptions are unmasked for the open/serve/close sequence.
    serve unmask (conn, transport) = bracket register cancel $ \th -> do
        unmask .
           bracket (onOpen addr) (onClose addr) $ \goingon ->
           when goingon $ serveConnection conn ii th addr transport set app
      where
        -- The timeout action registered for this connection closes it.
        register = T.registerKillThread (timeoutManager ii) (connClose conn)
        cancel = T.cancel

    -- Counter bookkeeping around the user-supplied open/close hooks.
    -- The result of settingsOnOpen becomes 'goingon' above.
    onOpen adr = increase counter >> settingsOnOpen set adr
    onClose adr _ = decrease counter >> settingsOnClose set adr
-- | Serve a single accepted connection as HTTP/2 or HTTP/1.x.
--
-- Protocol detection:
--
-- * If 'isHTTP2' holds for the 'Transport', no bytes are read up front.
-- * Otherwise the first chunk is read with 'connRecv'. A chunk that starts
--   with @\"PRI \"@ (the start of the HTTP/2 connection preface) marks the
--   connection as an HTTP/2 candidate.
--
-- HTTP/2 is used only when 'settingsHTTP2Enabled' is on. Otherwise the
-- chunk already read is pushed back into the HTTP/1.x source. The
-- connection is then served as HTTP/1.x, after optional PROXY protocol
-- parsing.
--
-- The @istatus@ 'IORef' records whether an error response may still be
-- sent. It is set to 'True' when data arrives, and reset to 'False' once
-- the application starts sending its own response.
--
-- In the HTTP/2 path, a TLS transport older than 1.2 gets a GOAWAY
-- (InadequateSecurity) frame, but 'http2' is still invoked afterwards.
-- NOTE(review): confirm that continuing after the GOAWAY is intended.
serveConnection :: Connection
                -> InternalInfo
                -> T.Handle
                -> SockAddr
                -> Transport
                -> Settings
                -> Application
                -> IO ()
serveConnection conn ii th origAddr transport settings app = do
    -- h2: whether this looks like HTTP/2; bs: bytes already consumed.
    (h2,bs) <- if isHTTP2 transport then
                   return (True, "")
                 else do
                   bs0 <- connRecv conn
                   if S.length bs0 >= 4 && "PRI " `S.isPrefixOf` bs0 then
                       return (True, bs0)
                     else
                       return (False, bs0)
    istatus <- newIORef False
    if settingsHTTP2Enabled settings && h2 then do
        -- Presumably makeReceiveN replays bs before reading more from the
        -- connection; TODO confirm in Recv.
        rawRecvN <- makeReceiveN bs (connRecv conn) (connRecvBuf conn)
        let recvN = wrappedRecvN th istatus (settingsSlowlorisSize settings) rawRecvN
            -- Every outgoing HTTP/2 write also tickles the timeout.
            sendBS x = connSendAll conn x >> T.tickle th
        checkTLS
        -- Recorded so connClose can pick the HTTP/2 graceful-close timeout.
        setConnHTTP2 conn True
        http2 conn transport ii origAddr settings recvN sendBS app
      else do
        src <- mkSource (wrappedRecv conn th istatus (settingsSlowlorisSize settings))
        -- Data (bs) has already been received, so an error response is
        -- allowed from here on.
        writeIORef istatus True
        leftoverSource src bs
        addr <- getProxyProtocolAddr src
        http1 True addr istatus src `E.catch` \e ->
          case () of
            ()
              -- These two end the connection quietly.
              | Just NoKeepAliveRequest <- fromException e -> return ()
              | Just (BadFirstLine _) <- fromException e -> return ()
              -- Anything else: try to tell the client, then rethrow.
              | otherwise -> do
                  _ <- sendErrorResponse (dummyreq addr) istatus e
                  throwIO e
  where
    -- Client address per 'settingsProxyProtocol'. None keeps the socket
    -- address. Required always parses a PROXY header. Optional parses one
    -- only if the first segment starts with "PROXY ". Otherwise it pushes
    -- the segment back.
    getProxyProtocolAddr src =
        case settingsProxyProtocol settings of
            ProxyProtocolNone ->
                return origAddr
            ProxyProtocolRequired -> do
                seg <- readSource src
                parseProxyProtocolHeader src seg
            ProxyProtocolOptional -> do
                seg <- readSource src
                if S.isPrefixOf "PROXY " seg
                    then parseProxyProtocolHeader src seg
                    else do
                        leftoverSource src seg
                        return origAddr

    -- Parse a PROXY protocol header line from the first segment only.
    -- NOTE(review): the whole header must arrive in that one segment;
    -- there is no further read here.
    parseProxyProtocolHeader src seg = do
        -- The header ends at the first CR (0x0d); fields are space-separated.
        let (header,seg') = S.break (== 0x0d) seg
            maybeAddr = case S.split 0x20 header of
                ["PROXY","TCP4",clientAddr,_,clientPort,_] ->
                    -- Accept the address only if it parses completely.
                    -- NOTE(review): clientPort goes through readInt without
                    -- range validation.
                    case [x | (x, t) <- reads (decodeAscii clientAddr), null t] of
                        [a] -> Just (SockAddrInet (readInt clientPort)
                                                  (toHostAddress a))
                        _ -> Nothing
                ["PROXY","TCP6",clientAddr,_,clientPort,_] ->
                    case [x | (x, t) <- reads (decodeAscii clientAddr), null t] of
                        [a] -> Just (SockAddrInet6 (readInt clientPort)
                                                   0
                                                   (toHostAddress6 a)
                                                   0)
                        _ -> Nothing
                ("PROXY":"UNKNOWN":_) ->
                    Just origAddr
                _ ->
                    Nothing
        case maybeAddr of
            Nothing -> throwIO (BadProxyHeader (decodeAscii header))
            Just a -> do
                -- Skip the 2 bytes after the header (presumably CRLF) and
                -- push the rest back for the HTTP parser.
                leftoverSource src (S.drop 2 seg')
                return a

    -- Byte-for-byte conversion: each byte becomes the Char with that code.
    decodeAscii = map (chr . fromEnum) . S.unpack

    -- No point answering a peer that already closed the connection.
    shouldSendErrorResponse se
        | Just ConnectionClosedByPeer <- fromException se = False
        | otherwise = True

    -- Sends 'settingsOnExceptionResponse' only if istatus is still True.
    -- Returns sendResponse's result when sending (used as keep-alive),
    -- otherwise False.
    sendErrorResponse req istatus e = do
        status <- readIORef istatus
        if shouldSendErrorResponse e && status
            then do
                sendResponse settings conn ii th req defaultIndexRequestHeader (return S.empty) (errorResponse e)
            else return False

    dummyreq addr = defaultRequest { remoteHost = addr }

    errorResponse e = settingsOnExceptionResponse settings e

    -- HTTP/1.x request loop: handle one request, recurse while keep-alive.
    -- Exceptions from processRequest are reported and end the loop.
    http1 firstRequest addr istatus src = do
        (req, mremainingRef, idxhdr, nextBodyFlush) <- recvRequest firstRequest settings conn ii th addr src transport
        keepAlive <- processRequest istatus src req mremainingRef idxhdr nextBodyFlush
            `E.catch` \e -> do
                settingsOnException settings (Just req) e
                return False
        when keepAlive $ http1 False addr istatus src

    -- Run the application for one request and decide on keep-alive.
    processRequest istatus src req mremainingRef idxhdr nextBodyFlush = do
        -- The timeout is paused while the application computes its
        -- response, and resumed inside the respond callback.
        T.pause th
        -- Filled by the respond callback, or by the error path below.
        -- Reading it unfilled raises the error message given here.
        keepAliveRef <- newIORef $ error "keepAliveRef not filled"
        r <- E.try $ app req $ \res -> do
            T.resume th
            -- The response has started, so no error response can follow.
            writeIORef istatus False
            keepAlive <- sendResponse settings conn ii th req idxhdr (readSource src) res
            writeIORef keepAliveRef keepAlive
            return ResponseReceived
        case r of
            Right ResponseReceived -> return ()
            Left e@(SomeException _)
              -- Failures while streaming the body are rethrown unwrapped,
              -- to be reported by http1's handler.
              | Just (ExceptionInsideResponseBody e') <- fromException e -> throwIO e'
              | otherwise -> do
                    keepAlive <- sendErrorResponse req istatus e
                    settingsOnException settings (Just req) e
                    writeIORef keepAliveRef keepAlive
        keepAlive <- readIORef keepAliveRef
        -- Presumably lets other threads run before the next recv on this
        -- connection (a performance measure); TODO confirm.
        Conc.yield
        if keepAlive
          then
            -- Keep-alive requires draining the unread request body first.
            -- With a settingsMaximumBodyFlush limit, a known remainder above
            -- the limit, or failure to finish within the limit, closes the
            -- connection instead.
            case settingsMaximumBodyFlush settings of
                Nothing -> do
                    flushEntireBody nextBodyFlush
                    T.resume th
                    return True
                Just maxToRead -> do
                    let tryKeepAlive = do
                            isComplete <- flushBody nextBodyFlush maxToRead
                            if isComplete then do
                                T.resume th
                                return True
                              else
                                return False
                    case mremainingRef of
                        Just ref -> do
                            remaining <- readIORef ref
                            if remaining <= maxToRead then
                                tryKeepAlive
                              else
                                return False
                        Nothing -> tryKeepAlive
          else
            return False

    -- Plain TCP is accepted as is. A TLS transport below version 1.2 gets
    -- a GOAWAY (InadequateSecurity) frame.
    checkTLS = case transport of
        TCP -> return ()
        tls -> unless (tls12orLater tls) $ goaway conn H2.InadequateSecurity "Weak TLS"

    -- TLS 1.2 is protocol version 3.3.
    tls12orLater tls = tlsMajorVersion tls == 3 && tlsMinorVersion tls >= 3
-- | Send an HTTP/2 GOAWAY frame with the given error code and debug
-- message. Both the frame's stream id and its last-stream-id are 0.
goaway :: Connection -> H2.ErrorCodeId -> ByteString -> IO ()
goaway conn etype debugmsg =
    connSendAll conn $ H2.encodeFrame (H2.encodeInfo id 0) (H2.GoAwayFrame 0 etype debugmsg)
-- | Read from the source and discard chunks until it yields an empty chunk,
-- which marks the end of the body.
flushEntireBody :: IO ByteString -> IO ()
flushEntireBody src = do
    chunk <- src
    if S.null chunk
        then return ()
        else flushEntireBody src
-- | Discard the rest of a body, reading at most the given number of bytes.
--
-- Returns 'True' if the end of the body (an empty chunk) is reached
-- without the running total exceeding the limit. Returns 'False' as soon
-- as a chunk pushes the total past the limit. Nothing further is read
-- after that.
flushBody :: IO ByteString
          -> Int
          -> IO Bool
flushBody src = go
  where
    go remaining = do
        chunk <- src
        let remaining' = remaining - S.length chunk
        if S.null chunk
            then return True
            else if remaining' < 0
                then return False
                else go remaining'
-- | Receive a chunk from the connection and update bookkeeping.
--
-- For a non-empty chunk, @istatus@ is set to 'True', and the timeout handle
-- is tickled only if the chunk is at least @slowlorisSize@ bytes.
-- Presumably this stops a peer that trickles tiny chunks from keeping the
-- connection alive. The chunk is returned unchanged.
wrappedRecv :: Connection -> T.Handle -> IORef Bool -> Int -> IO ByteString
wrappedRecv conn th istatus slowlorisSize = do
    chunk <- connRecv conn
    if S.null chunk
        then return chunk
        else do
            writeIORef istatus True
            when (S.length chunk >= slowlorisSize) $ T.tickle th
            return chunk
-- | Sized-read counterpart of 'wrappedRecv', used for HTTP/2.
--
-- For a non-empty chunk, @istatus@ is set to 'True'. The timeout is
-- tickled if the chunk is at least @slowlorisSize@ bytes, or if the
-- requested size was itself at most @slowlorisSize@. A small request can
-- only ever return a small chunk, so it is not penalised.
wrappedRecvN :: T.Handle -> IORef Bool -> Int -> (BufSize -> IO ByteString) -> (BufSize -> IO ByteString)
wrappedRecvN th istatus slowlorisSize readN bufsize = do
    chunk <- readN bufsize
    if S.null chunk
        then return chunk
        else do
            writeIORef istatus True
            let bigChunk = S.length chunk >= slowlorisSize
                smallRequest = bufsize <= slowlorisSize
            when (bigChunk || smallRequest) $ T.tickle th
            return chunk
-- | Mark the socket's file descriptor close-on-exec so it is not inherited
-- across @exec@. This is a no-op on Windows.
setSocketCloseOnExec :: Socket -> IO ()
#if WINDOWS
setSocketCloseOnExec _ = return ()
#else
setSocketCloseOnExec socket =
#if MIN_VERSION_network(3,0,0)
    fdSocket socket >>= F.setFileCloseOnExec . fromIntegral
#else
    F.setFileCloseOnExec (fromIntegral (fdSocket socket))
#endif
#endif
-- | Wait for the live-connection counter to drop to zero.
--
-- If 'settingsGracefulShutdownTimeout' is set, the wait is bounded by that
-- many seconds (converted to the microseconds 'timeout' expects). Whether
-- the wait completed or timed out is not reported.
gracefulShutdown :: Settings -> Counter -> IO ()
gracefulShutdown set counter =
    maybe drain (\secs -> void $ timeout (secs * 1000000) drain)
          (settingsGracefulShutdownTimeout set)
  where
    drain = waitForZero counter