{-# LANGUAGE OverloadedStrings #-}
module Network.WebSockets.Connection
( PendingConnection (..)
, acceptRequest
, AcceptRequest(..)
, defaultAcceptRequest
, acceptRequestWith
, rejectRequest
, RejectRequest(..)
, defaultRejectRequest
, rejectRequestWith
, Connection (..)
, ConnectionOptions (..)
, defaultConnectionOptions
, receive
, receiveDataMessage
, receiveData
, send
, sendDataMessage
, sendDataMessages
, sendTextData
, sendTextDatas
, sendBinaryData
, sendBinaryDatas
, sendClose
, sendCloseCode
, sendPing
, forkPingThread
, CompressionOptions (..)
, PermessageDeflate (..)
, defaultPermessageDeflate
, SizeLimit (..)
) where
import qualified Data.ByteString.Builder as Builder
import Control.Applicative ((<$>))
import Control.Concurrent (forkIO,
threadDelay)
import Control.Exception (AsyncException,
fromException,
handle,
throwIO)
import Control.Monad (foldM, unless,
when)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import Data.IORef (IORef,
newIORef,
readIORef,
writeIORef)
import Data.List (find)
import Data.Maybe (catMaybes)
import qualified Data.Text as T
import Data.Word (Word16)
import Prelude
import Network.WebSockets.Connection.Options
import Network.WebSockets.Extensions as Extensions
import Network.WebSockets.Extensions.PermessageDeflate
import Network.WebSockets.Extensions.StrictUnicode
import Network.WebSockets.Http
import Network.WebSockets.Protocol
import Network.WebSockets.Stream (Stream)
import qualified Network.WebSockets.Stream as Stream
import Network.WebSockets.Types
-- | A connection whose opening handshake request has been read but not yet
-- answered. Complete it with 'acceptRequest' / 'acceptRequestWith', or refuse
-- it with 'rejectRequest' / 'rejectRequestWith'.
data PendingConnection = PendingConnection
    { pendingOptions  :: !ConnectionOptions
      -- ^ Options used for the 'Connection' once the request is accepted.
    , pendingRequest  :: !RequestHead
      -- ^ Head of the client's handshake request.
    , pendingOnAccept :: !(Connection -> IO ())
      -- ^ Callback that 'acceptRequestWith' runs on the newly created
      -- 'Connection' just before returning it.
    , pendingStream   :: !Stream
      -- ^ Stream the handshake response is written to and over which
      -- messages are subsequently read and written.
    }
-- | Settings for the handshake response produced by 'acceptRequestWith'.
data AcceptRequest = AcceptRequest
    { acceptSubprotocol :: !(Maybe B.ByteString)
      -- ^ Subprotocol to announce. When 'Just', it is returned to the
      -- client in a @Sec-WebSocket-Protocol@ response header.
    , acceptHeaders     :: !Headers
      -- ^ Extra headers added to the handshake response.
    }
-- | Accept settings with no subprotocol and no extra response headers.
defaultAcceptRequest :: AcceptRequest
defaultAcceptRequest = AcceptRequest
    { acceptSubprotocol = Nothing
    , acceptHeaders     = []
    }
-- | Encode an HTTP response and write it to the pending connection's stream.
sendResponse :: PendingConnection -> Response -> IO ()
sendResponse pc =
    Stream.write (pendingStream pc) . Builder.toLazyByteString . encodeResponse
-- | Accept a pending connection with 'defaultAcceptRequest', i.e. without a
-- subprotocol and without extra response headers.
acceptRequest :: PendingConnection -> IO Connection
acceptRequest = flip acceptRequestWith defaultAcceptRequest
-- | Accept a pending connection, completing the WebSocket handshake using
-- the subprotocol and extra headers from the given 'AcceptRequest'.
--
-- Steps:
--
-- 1. Select the first entry of 'protocols' that is 'compatible' with the
--    request. If there is none, a 400 response listing the supported
--    versions in @Sec-WebSocket-Version@ is sent and 'NotSupported' is
--    thrown.
-- 2. Parse the client's requested extensions. If this fails, a
--    @400 Bad Request@ is sent and the parse error is re-thrown.
-- 3. If compression is enabled in the connection options, negotiate
--    permessage-deflate. On failure, a 400 response whose reason phrase is
--    the negotiation error is sent and 'NotSupported' is thrown.
-- 4. Add the strict-Unicode extension when 'connectionStrictUnicode' is set.
-- 5. Build and send the handshake response. If it cannot be built, a
--    @400 Bad Request@ is sent and the error is re-thrown.
-- 6. Wrap the raw message parser and writer with the negotiated extensions,
--    run 'pendingOnAccept' on the new 'Connection', and return it.
acceptRequestWith :: PendingConnection -> AcceptRequest -> IO Connection
acceptRequestWith pc ar = case find (flip compatible request) protocols of
    Nothing -> do
        sendResponse pc $ response400 versionHeader ""
        throwIO NotSupported
    Just protocol -> do
        -- Extensions requested by the client.
        rqExts <- either rejectAndThrow return $
            getRequestSecWebSocketExtensions request

        -- permessage-deflate, only when compression is enabled server-side.
        pmdExt <- case connectionCompressionOptions options of
            NoCompression                     -> return Nothing
            PermessageDeflateCompression pmd0 ->
                case negotiateDeflate (connectionMessageDataSizeLimit options) (Just pmd0) rqExts of
                    Left err   -> do
                        rejectRequestWith pc defaultRejectRequest {rejectMessage = B8.pack err}
                        throwIO NotSupported
                    Right pmd1 -> return (Just pmd1)

        let unicodeExt
                | connectionStrictUnicode options = Just strictUnicode
                | otherwise                       = Nothing
            exts     = catMaybes [pmdExt, unicodeExt]
            subproto = maybe [] (\p -> [("Sec-WebSocket-Protocol", p)]) $ acceptSubprotocol ar
            headers  = subproto ++ acceptHeaders ar ++ concatMap extHeaders exts

        -- Previously a failure here was thrown without answering the client.
        -- RFC 6455 requires an HTTP error response for a bad handshake.
        either rejectAndThrow (sendResponse pc) $
            finishRequest protocol request headers

        parseRaw <- decodeMessages
            protocol
            (connectionFramePayloadSizeLimit options)
            (connectionMessageDataSizeLimit options)
            (pendingStream pc)
        writeRaw <- encodeMessages protocol ServerConnection (pendingStream pc)
        -- Each extension wraps the writer/parser produced by the previous one.
        write    <- foldM (\x ext -> extWrite ext x) writeRaw exts
        parse    <- foldM (\x ext -> extParse ext x) parseRaw exts

        sentRef <- newIORef False
        let connection = Connection
                { connectionOptions   = options
                , connectionType      = ServerConnection
                , connectionProtocol  = protocol
                , connectionParse     = parse
                , connectionWrite     = write
                , connectionSentClose = sentRef
                }

        pendingOnAccept pc connection
        return connection
  where
    options = pendingOptions pc
    request = pendingRequest pc
    versionHeader = [("Sec-WebSocket-Version",
        B.intercalate ", " $ concatMap headerVersions protocols)]

    -- Answer the client with a plain 400 Bad Request, then re-throw the
    -- handshake error to the caller.
    rejectAndThrow err = do
        rejectRequestWith pc defaultRejectRequest
        throwIO err
-- | Parameters of the HTTP response sent by 'rejectRequestWith' to refuse a
-- pending connection.
--
-- All fields are strict. 'rejectHeaders' used to be the only lazy one, which
-- was inconsistent with the rest of this record and with 'AcceptRequest'.
data RejectRequest = RejectRequest
    { rejectCode    :: !Int
      -- ^ HTTP status code.
    , rejectMessage :: !B.ByteString
      -- ^ Reason phrase sent with the status code.
    , rejectHeaders :: !Headers
      -- ^ Additional response headers.
    , rejectBody    :: !B.ByteString
      -- ^ Response body.
    }
-- | A plain @400 Bad Request@ rejection with no extra headers and an empty
-- body.
defaultRejectRequest :: RejectRequest
defaultRejectRequest = RejectRequest
    { rejectBody    = ""
    , rejectHeaders = []
    , rejectMessage = "Bad Request"
    , rejectCode    = 400
    }
-- | Refuse a pending connection by sending an HTTP response built from the
-- given 'RejectRequest' (status code, reason phrase, headers and body).
rejectRequestWith
    :: PendingConnection
    -> RejectRequest
    -> IO ()
rejectRequestWith pc reject = sendResponse pc (Response responseHead body)
  where
    responseHead = ResponseHead
        { responseCode    = rejectCode reject
        , responseMessage = rejectMessage reject
        , responseHeaders = rejectHeaders reject
        }
    body = rejectBody reject
-- | Refuse a pending connection using 'defaultRejectRequest'
-- (@400 Bad Request@), with the given bytes as the response body.
rejectRequest
    :: PendingConnection
    -> B.ByteString
    -> IO ()
rejectRequest pc body = rejectRequestWith pc reject
  where
    reject = defaultRejectRequest { rejectBody = body }
-- | An established WebSocket connection.
data Connection = Connection
    { connectionOptions   :: !ConnectionOptions
      -- ^ Options the connection was set up with.
    , connectionType      :: !ConnectionType
      -- ^ Which side of the connection this is. Connections created by
      -- 'acceptRequestWith' use 'ServerConnection'.
    , connectionProtocol  :: !Protocol
      -- ^ Protocol selected during the handshake.
    , connectionParse     :: !(IO (Maybe Message))
      -- ^ Reads the next message. 'receive' treats 'Nothing' as the
      -- connection having been closed.
    , connectionWrite     :: !([Message] -> IO ())
      -- ^ Writes a batch of messages.
    , connectionSentClose :: !(IORef Bool)
      -- ^ Set to 'True' by 'sendAll' when a Close message is sent, so that
      -- 'receiveDataMessage' does not echo a second Close back to the peer.
    }
-- | Read the next message of any kind (data or control) from the connection.
--
-- Throws 'ConnectionClosed' when the connection's parser yields 'Nothing'.
receive :: Connection -> IO Message
receive conn = connectionParse conn >>= maybe (throwIO ConnectionClosed) return
-- | Receive the next data message, handling control messages along the way:
--
-- * @Ping@ is answered with a @Pong@ carrying the same payload, then the
--   function keeps waiting for a data message.
-- * @Pong@ runs 'connectionOnPong' from the connection's options, then the
--   function keeps waiting.
-- * @Close@ is echoed back unless this side has already sent a Close, and
--   'CloseRequest' is thrown with the peer's code and reason.
--
-- 'ConnectionClosed' propagates from 'receive' if the connection ends.
receiveDataMessage :: Connection -> IO DataMessage
receiveDataMessage conn = receive conn >>= dispatch
  where
    dispatch (DataMessage _ _ _ payload)   = return payload
    dispatch msg@(ControlMessage control) = onControl msg control

    onControl msg (Close code reason) = do
        alreadySent <- readIORef (connectionSentClose conn)
        unless alreadySent $ send conn msg
        throwIO (CloseRequest code reason)
    onControl _ (Pong _) = do
        connectionOnPong (connectionOptions conn)
        receiveDataMessage conn
    onControl _ (Ping payload) = do
        send conn (ControlMessage (Pong payload))
        receiveDataMessage conn
-- | Receive the next data message and convert it with 'fromDataMessage'.
-- Control messages are handled as in 'receiveDataMessage'.
receiveData :: WebSocketsData a => Connection -> IO a
receiveData = fmap fromDataMessage . receiveDataMessage
-- | Send a single message. This is 'sendAll' with a one-element batch.
send :: Connection -> Message -> IO ()
send conn msg = sendAll conn [msg]
-- | Write a batch of messages with a single call to 'connectionWrite'.
--
-- An empty batch does nothing. If any message in the batch is a Close, the
-- 'connectionSentClose' flag is set to 'True' before the batch is written.
sendAll :: Connection -> [Message] -> IO ()
sendAll conn msgs
    | null msgs = return ()
    | otherwise = do
        when (any isClose msgs) $
            writeIORef (connectionSentClose conn) True
        connectionWrite conn msgs
  where
    isClose msg = case msg of
        ControlMessage (Close _ _) -> True
        _                          -> False
-- | Send a single data message. See 'sendDataMessages'.
sendDataMessage :: Connection -> DataMessage -> IO ()
sendDataMessage conn dm = sendDataMessages conn [dm]
-- | Send several data messages as one batch. Each is wrapped in a
-- 'DataMessage' with all three 'Bool' flags 'False' (presumably the frame's
-- reserved bits; confirm against "Network.WebSockets.Types").
sendDataMessages :: Connection -> [DataMessage] -> IO ()
sendDataMessages conn dms =
    sendAll conn [DataMessage False False False dm | dm <- dms]
-- | Send a single value as a text message. See 'sendTextDatas'.
sendTextData :: WebSocketsData a => Connection -> a -> IO ()
sendTextData conn x = sendTextDatas conn [x]
-- | Send several values as text messages in one batch. Each value is
-- converted with 'toLazyByteString' and wrapped in a 'Text' message whose
-- optional second field is 'Nothing'.
sendTextDatas :: WebSocketsData a => Connection -> [a] -> IO ()
sendTextDatas conn xs =
    sendDataMessages conn [Text (toLazyByteString x) Nothing | x <- xs]
-- | Send a single value as a binary message. See 'sendBinaryDatas'.
sendBinaryData :: WebSocketsData a => Connection -> a -> IO ()
sendBinaryData conn x = sendBinaryDatas conn [x]
-- | Send several values as binary messages in one batch. Each value is
-- converted with 'toLazyByteString' and wrapped in 'Binary'.
sendBinaryDatas :: WebSocketsData a => Connection -> [a] -> IO ()
sendBinaryDatas conn xs =
    sendDataMessages conn [Binary (toLazyByteString x) | x <- xs]
-- | Send a Close message with status code 1000 (normal closure in RFC 6455)
-- and the given reason. See 'sendCloseCode'.
sendClose :: WebSocketsData a => Connection -> a -> IO ()
sendClose conn reason = sendCloseCode conn 1000 reason
-- | Send a Close message with an explicit status code and a reason, which is
-- converted with 'toLazyByteString'. Sending it marks the connection as
-- having sent a Close (see 'sendAll').
sendCloseCode :: WebSocketsData a => Connection -> Word16 -> a -> IO ()
sendCloseCode conn code reason =
    send conn (ControlMessage (Close code (toLazyByteString reason)))
-- | Send a Ping control message whose payload is the value converted with
-- 'toLazyByteString'.
sendPing :: WebSocketsData a => Connection -> a -> IO ()
sendPing conn payload =
    send conn (ControlMessage (Ping (toLazyByteString payload)))
-- | Fork a background thread that sends a Ping on the connection every @n@
-- seconds. The payloads are @"1"@, @"2"@, @"3"@, ..., the decimal text of a
-- counter that starts at 1.
--
-- A non-positive @n@ disables pinging, and no thread is forked.
--
-- The forked thread loops until an exception reaches it:
--
-- * 'AsyncException's (for example 'ThreadKilled') are re-thrown.
-- * Any other exception silently ends the thread. Presumably this is meant
--   for 'sendPing' failing once the connection has gone away.
forkPingThread :: Connection -> Int -> IO ()
forkPingThread conn n
    | n <= 0    = return ()
    | otherwise = do
        _ <- forkIO (ignore `handle` go 1)
        return ()
  where
    -- Interval in microseconds, the unit 'threadDelay' expects. It is
    -- clamped to 'maxBound' because a plain @n * 1000000@ can overflow
    -- 'Int' for large @n@ (about 36 minutes on a 32-bit 'Int'). An
    -- overflowed, negative delay makes 'threadDelay' return at once,
    -- which would turn this thread into a busy ping loop.
    delayMicros :: Int
    delayMicros
        | n > maxBound `div` microsPerSecond = maxBound
        | otherwise                          = n * microsPerSecond

    microsPerSecond :: Int
    microsPerSecond = 1000 * 1000

    go :: Int -> IO ()
    go i = do
        threadDelay delayMicros
        sendPing conn (T.pack $ show i)
        go (i + 1)

    ignore e = case fromException e of
        Just async -> throwIO (async :: AsyncException)
        Nothing    -> return ()