{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
module Network.WebSockets.Extensions.PermessageDeflate
( defaultPermessageDeflate
, PermessageDeflate(..)
, negotiateDeflate
, makeMessageInflater
, makeMessageDeflater
) where
import Control.Applicative ((<$>))
import Control.Exception (throwIO)
import Control.Monad (foldM, unless)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Lazy.Char8 as BL8
import qualified Data.ByteString.Lazy.Internal as BL
import Data.Int (Int64)
import Data.Monoid
import qualified Data.Streaming.Zlib as Zlib
import Network.WebSockets.Connection.Options
import Network.WebSockets.Extensions
import Network.WebSockets.Extensions.Description
import Network.WebSockets.Http
import Network.WebSockets.Types
import Prelude
import Text.Read (readMaybe)
-- | Render a 'PermessageDeflate' configuration as a @permessage-deflate@
-- extension description.
--
-- The two no-context-takeover flags become valueless parameters, present
-- only when the flag is set. The two window-size parameters carry their
-- value in decimal and are omitted whenever that value equals 15.
toExtensionDescription :: PermessageDeflate -> ExtensionDescription
toExtensionDescription pmd = ExtensionDescription
    { extName   = "permessage-deflate"
    , extParams = concat
        [ flag "server_no_context_takeover" (serverNoContextTakeover pmd)
        , flag "client_no_context_takeover" (clientNoContextTakeover pmd)
        , window "server_max_window_bits" (serverMaxWindowBits pmd)
        , window "client_max_window_bits" (clientMaxWindowBits pmd)
        ]
    }
  where
    -- A bare parameter, emitted only when the flag is on.
    flag name on = [(name, Nothing) | on]

    -- A numeric parameter, skipped when it equals 15.
    window name bits
        | bits == 15 = []
        | otherwise  = [(name, Just (B8.pack (show bits)))]
-- | The single @Sec-WebSocket-Extensions@ header that advertises the given
-- @permessage-deflate@ configuration.
toHeaders :: PermessageDeflate -> Headers
toHeaders pmd = [(name, value)]
  where
    name  = "Sec-WebSocket-Extensions"
    value = encodeExtensionDescriptions [toExtensionDescription pmd]
-- | Negotiate per-message compression against the client's offered
-- extensions.
--
-- With no configuration ('Nothing'), compression is disabled: no response
-- headers are produced and 'Nothing' is handed to the inflater and deflater
-- builders. Otherwise the offers are scanned in order and the first match
-- wins:
--
--   * @x-webkit-deflate-frame@ is echoed back verbatim and the configuration
--     is used as is;
--   * @permessage-deflate@ has its parameters folded into the configuration
--     with 'setParam' (a 'Left' from any parameter aborts negotiation), and
--     the result is advertised via 'toHeaders'.
--
-- If nothing matches, compression is disabled as above.
negotiateDeflate
    :: SizeLimit -> Maybe PermessageDeflate -> NegotiateExtension
negotiateDeflate messageLimit pmd0 exts0 = do
    (headers, negotiated) <- case pmd0 of
        Nothing     -> Right ([], Nothing)
        Just config -> pickOffer config exts0
    return Extension
        { extHeaders = headers
        , extParse   = \parseRaw -> do
            inflate <- makeMessageInflater messageLimit negotiated
            -- Run the inflater over each 'Just' result of the raw parser;
            -- 'Nothing' results are passed through untouched.
            return $ parseRaw >>= maybe (return Nothing) (fmap Just . inflate)
        , extWrite   = \writeRaw -> do
            deflate <- makeMessageDeflater negotiated
            return $ \msgs -> writeRaw =<< mapM deflate msgs
        }
  where
    -- Walk the offers until one we support is found.
    pickOffer
        :: PermessageDeflate
        -> ExtensionDescriptions
        -> Either String (Headers, Maybe PermessageDeflate)
    pickOffer _ [] = Right ([], Nothing)
    pickOffer config (ext : rest) = case extName ext of
        "x-webkit-deflate-frame" -> Right
            ([("Sec-WebSocket-Extensions", "x-webkit-deflate-frame")], Just config)
        "permessage-deflate" -> do
            config' <- foldM setParam config (extParams ext)
            Right (toHeaders config', Just config')
        _ -> pickOffer config rest
-- | Fold one offered extension parameter into the configuration.
--
-- @server_no_context_takeover@ and @client_no_context_takeover@ switch the
-- matching flag on, regardless of any value attached.
-- @server_max_window_bits@ and @client_max_window_bits@ set the matching
-- window size. It is 15 when the parameter has no value; otherwise it is the
-- value as validated by 'parseWindow', whose 'Left' is propagated.
-- Any other parameter is ignored.
setParam
    :: PermessageDeflate -> ExtensionParam -> Either String PermessageDeflate
setParam pmd (name, value) = case name of
    "server_no_context_takeover" -> Right pmd {serverNoContextTakeover = True}
    "client_no_context_takeover" -> Right pmd {clientNoContextTakeover = True}
    "server_max_window_bits"     ->
        (\w -> pmd {serverMaxWindowBits = w}) <$> windowBits
    "client_max_window_bits"     ->
        (\w -> pmd {clientMaxWindowBits = w}) <$> windowBits
    _                            -> Right pmd
  where
    -- Only forced for the two window-size parameters.
    windowBits = maybe (Right 15) parseWindow value
-- | Parse a @*_max_window_bits@ parameter value, accepting only 8 to 15
-- inclusive.
--
-- The text is read as an unbounded 'Integer' and range-checked before it is
-- narrowed to 'Int'. Reading straight into 'Int' silently wraps on overflow,
-- which let inputs such as @"18446744073709551627"@ (2^64 + 11) slip through
-- as @11@.
parseWindow :: B.ByteString -> Either String Int
parseWindow bs8 = case readMaybe (B8.unpack bs8) :: Maybe Integer of
    Just w
        | w >= 8 && w <= 15 -> Right (fromInteger w)
        | otherwise         -> Left $ "Window out of bounds: " ++ show w
    Nothing -> Left $ "Can't parse window: " ++ show bs8
-- | Clamp a window size into the range 9..15 before it is handed to zlib.
--
-- NOTE(review): the floor of 9 (rather than the protocol's 8) presumably
-- exists because zlib does not support an 8-bit window for raw deflate
-- streams — confirm against the zlib manual.
fixWindowBits :: Int -> Int
fixWindowBits = max 9 . min 15
-- | The four octets @00 00 ff ff@. They are stripped from the end of
-- compressed output by 'maybeStrip' and appended again to incoming payloads
-- before inflation.
appTailL :: BL.ByteString
appTailL = BL.pack (replicate 2 0x00 ++ replicate 2 0xff)
-- | Drop a trailing 'appTailL' from the input when present; input that does
-- not end with it is returned unchanged.
maybeStrip :: BL.ByteString -> BL.ByteString
maybeStrip bytes
    | appTailL `BL.isSuffixOf` bytes =
        BL.take (BL.length bytes - BL.length appTailL) bytes
    | otherwise = bytes
-- | Message filter used when no compression was negotiated.
--
-- A data message with any RSV bit set is refused by throwing a
-- @CloseRequest 1002 "Protocol Error"@. Every other message is returned as
-- is.
rejectExtensions :: Message -> IO Message
rejectExtensions msg = case msg of
    DataMessage rsv1 rsv2 rsv3 _
        | or [rsv1, rsv2, rsv3] -> throwIO $ CloseRequest 1002 "Protocol Error"
    _ -> return msg
-- | Build the function that compresses outgoing messages.
--
-- Without a configuration, messages are only screened by 'rejectExtensions'.
-- With one, text and binary data messages whose RSV bits are all clear are
-- deflated and re-tagged with RSV1 set. Text messages also lose any cached
-- decoded text. Every other message passes through unchanged.
--
-- When @serverNoContextTakeover@ is set, a fresh compressor is created for
-- each message. Otherwise a single compressor is created up front and reused
-- for every message.
makeMessageDeflater
    :: Maybe PermessageDeflate -> IO (Message -> IO Message)
makeMessageDeflater Nothing = return rejectExtensions
makeMessageDeflater (Just pmd)
    | serverNoContextTakeover pmd =
        return $ \msg -> do
            ptr <- newDeflater
            compressWith (deflateBody ptr) msg
    | otherwise =
        (\ptr -> compressWith (deflateBody ptr)) <$> newDeflater
  where
    -- Negative window bits select zlib's raw-deflate mode; the size itself
    -- is clamped by 'fixWindowBits'.
    newDeflater :: IO Zlib.Deflate
    newDeflater = Zlib.initDeflate
        (pdCompressionLevel pmd)
        (Zlib.WindowBits (negate (fixWindowBits (serverMaxWindowBits pmd))))

    compressWith
        :: (BL.ByteString -> IO BL.ByteString)
        -> Message -> IO Message
    compressWith deflater msg = case msg of
        DataMessage False False False (Text payload _) -> do
            payload' <- deflater payload
            return (DataMessage True False False (Text payload' Nothing))
        DataMessage False False False (Binary payload) ->
            DataMessage True False False . Binary <$> deflater payload
        _ -> return msg

    -- Feed every input chunk (draining output after each), then flush, and
    -- finally drop the trailing 00 00 ff ff via 'maybeStrip'.
    deflateBody :: Zlib.Deflate -> BL.ByteString -> IO BL.ByteString
    deflateBody ptr body = do
        pieces  <- mapM (\c -> Zlib.feedDeflate ptr c >>= dePopper)
                        (BL.toChunks body)
        flushed <- dePopper (Zlib.flushDeflate ptr)
        return (maybeStrip (mconcat pieces <> flushed))
-- | Run a zlib popper until it reports completion, concatenating every
-- chunk it yields.
--
-- A zlib error is turned into a @CloseRequest 1002@ whose reason is the
-- 'show'n error.
dePopper :: Zlib.Popper -> IO BL.ByteString
dePopper popper = loop
  where
    loop = do
        res <- popper
        case res of
            Zlib.PRDone    -> return BL.empty
            Zlib.PRNext c  -> BL.chunk c <$> loop
            Zlib.PRError e -> throwIO (CloseRequest 1002 (BL8.pack (show e)))
-- | Build the function that decompresses incoming messages.
--
-- Without a configuration, messages are only screened by 'rejectExtensions'.
-- With one, text and binary data messages that have RSV1 set are inflated
-- and re-tagged with RSV1 cleared. RSV2/RSV3 are kept, and text messages
-- lose any cached decoded text. Every other message passes through
-- unchanged.
--
-- When @clientNoContextTakeover@ is set, a fresh decompressor is created for
-- each message. Otherwise a single decompressor is created up front and
-- reused for every message.
--
-- The decompressed size is checked against @messageLimit@ after every piece
-- of output zlib yields, and 'ParseException' is thrown as soon as it is
-- exceeded. Checking per output piece, rather than only after a whole input
-- chunk has been expanded, stops a small but highly compressible chunk from
-- being fully inflated in memory before it is rejected.
makeMessageInflater
    :: SizeLimit -> Maybe PermessageDeflate
    -> IO (Message -> IO Message)
makeMessageInflater _ Nothing = return rejectExtensions
makeMessageInflater messageLimit (Just pmd)
    | clientNoContextTakeover pmd =
        return $ \msg -> do
            ptr <- initInflate pmd
            inflateMessageWith (inflateBody ptr) msg
    | otherwise = do
        ptr <- initInflate pmd
        return $ \msg ->
            inflateMessageWith (inflateBody ptr) msg
  where
    -- Negative window bits select zlib's raw-inflate mode; the size itself
    -- is clamped by 'fixWindowBits'.
    initInflate :: PermessageDeflate -> IO Zlib.Inflate
    initInflate PermessageDeflate {..} =
        Zlib.initInflate
            (Zlib.WindowBits (- (fixWindowBits clientMaxWindowBits)))

    inflateMessageWith
        :: (BL.ByteString -> IO BL.ByteString)
        -> Message -> IO Message
    inflateMessageWith inflater (DataMessage True a b (Text x _)) = do
        x' <- inflater x
        return (DataMessage False a b (Text x' Nothing))
    inflateMessageWith inflater (DataMessage True a b (Binary x)) = do
        x' <- inflater x
        return (DataMessage False a b (Binary x'))
    inflateMessageWith _ x = return x

    -- Re-append the 00 00 ff ff tail ('appTailL'), then feed the payload
    -- chunk by chunk, size-checking the output as it is produced.
    inflateBody :: Zlib.Inflate -> BL.ByteString -> IO BL.ByteString
    inflateBody ptr =
        go 0 . BL.toChunks . (<> appTailL)
      where
        go :: Int64 -> [B.ByteString] -> IO BL.ByteString
        go size0 [] = do
            chunk <- Zlib.flushInflate ptr
            checkSize (fromIntegral (B.length chunk) + size0)
            return (BL.fromStrict chunk)
        go size0 (c : cs) = do
            popper <- Zlib.feedInflate ptr c
            drain size0 popper cs

        -- Pull decompressed pieces out of the popper one at a time, checking
        -- the running total after each, then move on to the remaining input
        -- chunks once the popper is exhausted.
        drain :: Int64 -> Zlib.Popper -> [B.ByteString] -> IO BL.ByteString
        drain size0 popper cs = do
            res <- popper
            case res of
                Zlib.PRDone -> go size0 cs
                Zlib.PRNext piece -> do
                    let size1 = size0 + fromIntegral (B.length piece)
                    checkSize size1
                    BL.chunk piece <$> drain size1 popper cs
                Zlib.PRError err ->
                    throwIO $ CloseRequest 1002 (BL8.pack (show err))

    checkSize :: Int64 -> IO ()
    checkSize size = unless (atMostSizeLimit size messageLimit) $ throwIO $
        ParseException $ "Message of size " ++ show size ++ " exceeded limit"