{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_GHC -Wall #-}
module Network.Http.Client.WebSocket
(
WSFrameHdr(..)
, wsFrameHdrSize
, wsFrameHdrToBuilder
, WSOpcode(..)
, WSOpcodeReserved(..)
, wsIsDataFrame
, writeWSFrame
, sendWSFragData
, readWSFrame
, receiveWSFrame
, wsUpgradeConnection
, SecWebSocketKey
, wsKeyToAcceptB64
, secWebSocketKeyFromB64
, secWebSocketKeyToB64
, secWebSocketKeyFromWords
, WsException(..)
) where
import Blaze.ByteString.Builder (Builder)
import qualified Blaze.ByteString.Builder as Builder
import Control.Exception
import Control.Monad (unless, when)
import qualified Crypto.Hash.SHA1 as SHA1
import qualified Data.Binary as Bin
import qualified Data.Binary.Get as Bin
import qualified Data.Binary.Put as Bin
import Data.Bits
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as B64
import qualified Data.ByteString.Lazy as BL
import qualified Data.CaseInsensitive as CI
import Data.IORef
import Data.Maybe (isJust)
import Data.Monoid (Monoid (..))
import Data.Typeable (Typeable)
import Data.Word
import Data.XOR (xor32LazyByteString, xor32StrictByteString')
import Network.Http.Client as HC
import qualified Network.Http.Connection as HC
import System.IO.Streams (InputStream, OutputStream)
import qualified System.IO.Streams as Streams
-- | Exception raised by this module when a WebSocket handshake check or a
-- framing rule is violated.
--
-- The payload is a human-readable message; every throw site in this module
-- prefixes it with the name of the raising function (e.g.
-- @"readWSFrame: ..."@).
data WsException = WsException String
    deriving (Show, Typeable)

instance Exception WsException
-- | A WebSocket frame header, i.e. everything that precedes the payload on
-- the wire.  All fields are strict.
data WSFrameHdr = WSFrameHdr
    { ws'FIN    :: !Bool
      -- ^ FIN bit: set on the last frame of a message.
    , ws'RSV1   :: !Bool
      -- ^ RSV1 bit.
    , ws'RSV2   :: !Bool
      -- ^ RSV2 bit.
    , ws'RSV3   :: !Bool
      -- ^ RSV3 bit.
    , ws'opcode :: !WSOpcode
      -- ^ Frame opcode.
    , ws'length :: !Word64
      -- ^ Payload length in bytes.
    , ws'mask   :: !(Maybe Word32)
      -- ^ Masking key; 'Nothing' when the MASK bit is clear.
    } deriving (Show)
-- | Number of octets the serialized header occupies on the wire.
--
-- That is 2 fixed octets, plus 0, 2 or 8 octets of extended payload length
-- (depending on 'ws'length'), plus 4 octets when a masking key is present.
wsFrameHdrSize :: WSFrameHdr -> Int
wsFrameHdrSize hdr = 2 + extendedLenOctets + maskOctets
  where
    payloadLen = ws'length hdr

    extendedLenOctets
        | payloadLen < 126    = 0
        | payloadLen <= 0xffff = 2
        | otherwise           = 8

    maskOctets = maybe 0 (const 4) (ws'mask hdr)
-- | Read and decode the next frame header from the connection's input.
--
-- Returns 'Nothing' when the input reaches end-of-stream before any header
-- byte has arrived. Leading empty chunks are skipped first, so an empty
-- chunk followed by EOF is also reported as 'Nothing' rather than as a
-- truncated header.
--
-- Bytes received past the end of the header are pushed back with
-- 'Streams.unRead', so the payload can be read from the same stream
-- afterwards.
--
-- Throws 'WsException' when the binary decoder fails. That covers a
-- malformed header, and also a stream that ends in the middle of one.
readWSFrameHdr :: Connection -> IO (Maybe WSFrameHdr)
readWSFrameHdr (HC.Connection { cIn = is }) = start
  where
    -- Wait for the first non-empty chunk (or EOF) before starting to decode.
    start :: IO (Maybe WSFrameHdr)
    start = do
        mchunk <- Streams.read is
        case mchunk of
            Nothing -> return Nothing
            Just chunk
                | BS.null chunk -> start
                | otherwise     ->
                    go (Bin.runGetIncremental Bin.get `Bin.pushChunk` chunk)

    -- Drive the incremental decoder, feeding it more input as requested.
    go :: Bin.Decoder WSFrameHdr -> IO (Maybe WSFrameHdr)
    go (Bin.Fail rest _ err) = do
        unless (BS.null rest) $
            Streams.unRead rest is
        throwIO $ WsException ("readWSFrameHdr: " ++ err)
    go partial@(Bin.Partial cont) = do
        mchunk <- Streams.read is
        case mchunk of
            Nothing -> go (cont Nothing)
            Just chunk
                | BS.null chunk -> go partial
                | otherwise     -> go (cont (Just chunk))
    go (Bin.Done rest _ x) = do
        -- Return over-read bytes (start of the payload) to the stream.
        unless (BS.null rest) $
            Streams.unRead rest is
        return (Just x)
-- | Read one frame and hand it to a callback as a payload stream.
--
-- The callback receives the decoded header and an 'InputStream' that yields
-- exactly 'ws'length' payload bytes, already XOR-ed with the frame's masking
-- key when one is present. Payload bytes the callback does not consume are
-- skipped afterwards, so the connection is left positioned at the next
-- frame. For non-empty frames the callback's result is forced to WHNF.
--
-- Returns 'Nothing' when no further frame header could be read (end of
-- input).
receiveWSFrame :: Connection -> (WSFrameHdr -> InputStream ByteString -> IO a) -> IO (Maybe a)
receiveWSFrame (conn@HC.Connection { cIn = is }) cont =
    readWSFrameHdr conn >>= maybe (return Nothing) onHeader
  where
    onHeader hdr
        | ws'length hdr == 0 =
            fmap Just (cont hdr =<< Streams.nullInput)
        | otherwise = do
            payload   <- Streams.takeExactly (fromIntegral (ws'length hdr)) is
            unmasked  <- xor32InputStream (maybe 0 id (ws'mask hdr)) payload
            result    <- cont hdr unmasked
            -- drain whatever the callback left unread
            Streams.skipToEof payload
            return $! Just $! result
-- | Send one message as a sequence of fragment frames.
--
-- The callback gets an 'OutputStream':
--
-- * Every non-empty chunk written to it is sent right away as its own
--   frame with FIN cleared. The first such frame uses the opcode of @hdr0@;
--   later ones use 'WSOpcode'Continuation'. Each chunk is XOR-ed with the
--   masking key of @hdr0@ (if any), starting from the key's first byte.
-- * Writing an empty chunk flushes the underlying connection.
--
-- When the callback returns, an empty frame with FIN set terminates the
-- message and the connection is flushed. If the callback wrote nothing,
-- that single frame carries @hdr0@'s opcode.
--
-- The RSV bits of @hdr0@ are copied to every frame. Its 'ws'FIN' and
-- 'ws'length' fields are ignored.
--
-- Throws 'WsException' if @hdr0@ carries a control opcode or
-- 'WSOpcode'Continuation'.
sendWSFragData :: Connection -> WSFrameHdr -> (OutputStream ByteString -> IO a) -> IO a
sendWSFragData _ hdr0 _
    | not (wsIsDataFrame (ws'opcode hdr0))
    = throwIO (WsException "sendWSFragData: sending control-frame requested")
    | ws'opcode hdr0 == WSOpcode'Continuation
    = throwIO (WsException "sendWSFragData: sending continuation frame requested")
sendWSFragData (HC.Connection { cOut = os }) hdr0 cont = do
    -- Opcode for the next frame: hdr0's opcode once, then Continuation.
    opcodeRef <- newIORef (ws'opcode hdr0)
    let go Nothing = return ()
        go (Just chunk)
          | BS.null chunk = Streams.write (Just Builder.flush) os
          | otherwise = do
              let (_,chunk') = xor32StrictByteString' (maybe 0 id $ ws'mask hdr0) chunk
              opcode <- readIORef opcodeRef
              writeIORef opcodeRef WSOpcode'Continuation
              let fraghdr = hdr0 { ws'FIN = False
                                 , ws'length = fromIntegral (BS.length chunk)
                                 , ws'opcode = opcode
                                 }
              Streams.write (Just $! wsFrameHdrToBuilder fraghdr `mappend` Builder.fromByteString chunk') os
    os' <- Streams.makeOutputStream go
    !res <- cont os'
    opcode <- readIORef opcodeRef
    -- The terminating frame has an empty payload, so a zero key suffices.
    -- Keep the MASK bit consistent with the data fragments, which use
    -- hdr0's mask setting.
    let final = hdr0 { ws'FIN = True
                     , ws'length = 0
                     , ws'opcode = opcode
                     , ws'mask = fmap (const 0) (ws'mask hdr0)
                     }
    Streams.write (Just $ wsFrameHdrToBuilder final `mappend` Builder.flush) os
    return $! res
-- | Write one complete frame (FIN set, RSV bits clear) and flush.
--
-- When a non-zero masking key is given, the payload is XOR-ed with it
-- before sending. A key of 'Just' 0 sets the MASK bit but leaves the
-- payload unchanged.
--
-- Throws 'WsException' (and writes nothing) if a control opcode is combined
-- with a payload of 126 bytes or more.
writeWSFrame :: Connection -> WSOpcode -> Maybe Word32 -> BL.ByteString -> IO ()
writeWSFrame (HC.Connection { cOut = os }) opcode mmask payload
    | not (wsIsDataFrame opcode) && payloadLen >= 126
    = throwIO (WsException "writeWSFrame: over-sized control-frame")
    | otherwise
    = Streams.write (Just $ mconcat [header, body, Builder.flush]) os
  where
    payloadLen :: Word64
    payloadLen = fromIntegral (BL.length payload)

    header = wsFrameHdrToBuilder WSFrameHdr
        { ws'FIN    = True
        , ws'RSV1   = False
        , ws'RSV2   = False
        , ws'RSV3   = False
        , ws'opcode = opcode
        , ws'length = payloadLen
        , ws'mask   = mmask
        }

    body = Builder.fromLazyByteString $ case mmask of
        Just msk | msk /= 0 -> xor32LazyByteString msk payload
        _                   -> payload
-- | Read one complete frame into memory.
--
-- The payload is returned XOR-ed with the frame's masking key when a
-- non-zero key is present.
--
-- Returns 'Nothing' when no further frame header could be read (end of
-- input). Throws 'WsException' if the frame's payload is longer than
-- @maxSize@ bytes. A frame of exactly @maxSize@ bytes is accepted.
--
-- NOTE(review): @maxSize@ is converted to 'Word64' with 'fromIntegral', so
-- a negative value wraps around and effectively disables the limit.
readWSFrame :: Int -> Connection -> IO (Maybe (WSFrameHdr,ByteString))
readWSFrame maxSize (conn@HC.Connection { cIn = is }) = do
    mhdr <- readWSFrameHdr conn
    case mhdr of
        Nothing -> return Nothing
        Just hdr
            | ws'length hdr == 0 ->
                return $ Just (hdr,BS.empty)
            -- Reject only frames strictly larger than the limit, matching
            -- the error message.
            | ws'length hdr > fromIntegral maxSize ->
                throwIO (WsException "readWSFrame: frame larger than maxSize")
            | otherwise -> do
                dat <- Streams.readExactly (fromIntegral (ws'length hdr)) is
                let dat' = case ws'mask hdr of
                        Nothing -> dat
                        Just 0  -> dat
                        Just m  -> snd (xor32StrictByteString' m dat)
                return $ Just (hdr,dat')
-- | Serialize a frame header into a 'Builder'.
--
-- Layout:
--
-- * octet 0: FIN, RSV1, RSV2, RSV3, 4-bit opcode;
-- * octet 1: MASK flag and a 7-bit length;
-- * then a 16- or 64-bit big-endian extended length when needed;
-- * then the 32-bit masking key, if present.
--
-- The shortest length encoding is always chosen.
wsFrameHdrToBuilder :: WSFrameHdr -> Builder
wsFrameHdrToBuilder WSFrameHdr{..} =
    Builder.fromWord8 firstOctet
        `mappend` Builder.fromWord8 secondOctet
        `mappend` extendedLen
        `mappend` maybe mempty Builder.fromWord32be ws'mask
  where
    bitIf cond v = if cond then v else 0

    firstOctet =
        bitIf ws'FIN 0x80 .|. bitIf ws'RSV1 0x40 .|. bitIf ws'RSV2 0x20
            .|. bitIf ws'RSV3 0x10 .|. encodeWSOpcode ws'opcode

    secondOctet = bitIf (isJust ws'mask) 0x80 .|. len7

    len7 :: Word8
    len7
        | ws'length < 126    = fromIntegral ws'length
        | ws'length <= 0xffff = 126
        | otherwise          = 127

    extendedLen = case len7 of
        126 -> Builder.fromWord16be (fromIntegral ws'length)
        127 -> Builder.fromWord64be ws'length
        _   -> mempty
-- | Binary wire encoding of a frame header.
--
-- 'Bin.put' writes the same octets as 'wsFrameHdrToBuilder'.
--
-- 'Bin.get' parses a header and additionally 'fail's on headers it treats
-- as malformed (see the checks below). 'readWSFrameHdr' turns such
-- failures into a 'WsException'.
instance Bin.Binary WSFrameHdr where
    put WSFrameHdr{..} = do
        -- Octet 0: FIN (bit 7), RSV1-3 (bits 6-4), opcode (bits 3-0).
        Bin.putWord8 $!
            (if ws'FIN then 0x80 else 0) .|.
            (if ws'RSV1 then 0x40 else 0) .|.
            (if ws'RSV2 then 0x20 else 0) .|.
            (if ws'RSV3 then 0x10 else 0) .|.
            (encodeWSOpcode ws'opcode)
        -- Octet 1: MASK flag (bit 7) and the 7-bit length field.
        Bin.putWord8 $!
            (if isJust ws'mask then 0x80 else 0) .|. len7
        -- Extended length: 16 bits after marker 126, 64 bits after 127.
        case len7 of
            126 -> Bin.putWord16be (fromIntegral ws'length)
            127 -> Bin.putWord64be ws'length
            _ -> return ()
        -- Optional 32-bit masking key.
        maybe (return ()) Bin.putWord32be ws'mask
      where
        -- Shortest length encoding: inline below 126, marker 126 up to
        -- 0xffff, marker 127 above that.
        len7 | ws'length < 126 = fromIntegral ws'length
             | ws'length <= 0xffff = 126
             | otherwise = 127
    get = do
        o0 <- Bin.getWord8
        let ws'FIN = testBit o0 7
            ws'RSV1 = testBit o0 6
            ws'RSV2 = testBit o0 5
            ws'RSV3 = testBit o0 4
            ws'opcode = decodeWSOpcode o0
        -- Control frames must not be fragmented.
        when (not ws'FIN && not (wsIsDataFrame ws'opcode)) $
            fail "invalid fragmented control-frame"
        o1 <- Bin.getWord8
        let len7 = o1 .&. 0x7f
            msk = o1 >= 0x80
        ws'length <- case len7 of
            127 -> do
                -- 64-bit length: not allowed for control frames. It must be
                -- minimally encoded (> 0xffff) and have its MSB clear.
                unless (wsIsDataFrame ws'opcode) $ fail "invalid 64-bit extended length (control-frame)"
                v <- Bin.getWord64be
                unless (v > 0xffff) $ fail "invalid 64-bit extended length (<= 0xffff)"
                unless (v < 0x8000000000000000) $ fail "invalid 64-bit extended length (MSB set)"
                return v
            126 -> do
                -- 16-bit length: not allowed for control frames. It must be
                -- minimally encoded (> 125).
                unless (wsIsDataFrame ws'opcode) $ fail "invalid 16-bit extended length (control-frame)"
                v <- Bin.getWord16be
                unless (v > 125) $ fail "invalid 16-bit extended length (<= 0x7d)"
                return (fromIntegral v)
            _ -> return (fromIntegral len7)
        -- The masking key follows the length only when MASK is set.
        ws'mask <- if msk
                   then Just `fmap` Bin.getWord32be
                   else return Nothing
        return WSFrameHdr{..}
-- | Frame opcode. The numeric wire values are assigned by 'encodeWSOpcode'
-- and 'decodeWSOpcode'.
data WSOpcode
    = WSOpcode'Continuation
      -- ^ 0x0: continuation fragment of a message
    | WSOpcode'Text
      -- ^ 0x1: text data frame
    | WSOpcode'Binary
      -- ^ 0x2: binary data frame
    | WSOpcode'Close
      -- ^ 0x8: close control frame
    | WSOpcode'Ping
      -- ^ 0x9: ping control frame
    | WSOpcode'Pong
      -- ^ 0xA: pong control frame
    | WSOpcode'Reserved !WSOpcodeReserved
      -- ^ one of the reserved opcode values
    deriving (Eq, Show)
-- | The reserved opcode values.
--
-- 0x3-0x7 are classified as data frames and 0xB-0xF as control frames by
-- 'wsIsDataFrame'. The number in each constructor name is the opcode value.
data WSOpcodeReserved
    = WSOpcode'Reserved3
    | WSOpcode'Reserved4
    | WSOpcode'Reserved5
    | WSOpcode'Reserved6
    | WSOpcode'Reserved7
    | WSOpcode'Reserved11
    | WSOpcode'Reserved12
    | WSOpcode'Reserved13
    | WSOpcode'Reserved14
    | WSOpcode'Reserved15
    deriving (Eq, Show)
-- | Whether an opcode denotes a data (non-control) frame.
--
-- Opcodes 0x0-0x7 (continuation, text, binary and the reserved data codes)
-- are data frames. Opcodes 0x8-0xF (close, ping, pong and the reserved
-- control codes) are control frames. The two groups differ exactly in the
-- most significant bit of the 4-bit opcode, which is what is tested here.
wsIsDataFrame :: WSOpcode -> Bool
wsIsDataFrame = not . (`testBit` 3) . encodeWSOpcode
-- | Decode the opcode from the low 4 bits of a frame's first octet.
--
-- The upper 4 bits (FIN and RSV flags) are masked off, so every input maps
-- to an opcode.
decodeWSOpcode :: Word8 -> WSOpcode
decodeWSOpcode x = case x .&. 0xf of
    0x0 -> WSOpcode'Continuation
    0x1 -> WSOpcode'Text
    0x2 -> WSOpcode'Binary
    0x3 -> WSOpcode'Reserved WSOpcode'Reserved3
    0x4 -> WSOpcode'Reserved WSOpcode'Reserved4
    0x5 -> WSOpcode'Reserved WSOpcode'Reserved5
    0x6 -> WSOpcode'Reserved WSOpcode'Reserved6
    0x7 -> WSOpcode'Reserved WSOpcode'Reserved7
    0x8 -> WSOpcode'Close
    0x9 -> WSOpcode'Ping
    0xA -> WSOpcode'Pong
    0xB -> WSOpcode'Reserved WSOpcode'Reserved11
    0xC -> WSOpcode'Reserved WSOpcode'Reserved12
    0xD -> WSOpcode'Reserved WSOpcode'Reserved13
    0xE -> WSOpcode'Reserved WSOpcode'Reserved14
    0xF -> WSOpcode'Reserved WSOpcode'Reserved15
    -- Unreachable: the value was masked to 4 bits above.
    v   -> error ("decodeWSOpcode: impossible, 4-bit value out of range: " ++ show v)
-- | The 4-bit wire value of an opcode (always in the range 0x0-0xF).
encodeWSOpcode :: WSOpcode -> Word8
encodeWSOpcode op = case op of
    WSOpcode'Continuation -> 0x0
    WSOpcode'Text         -> 0x1
    WSOpcode'Binary       -> 0x2
    WSOpcode'Close        -> 0x8
    WSOpcode'Ping         -> 0x9
    WSOpcode'Pong         -> 0xA
    WSOpcode'Reserved r   -> reservedCode r
  where
    reservedCode :: WSOpcodeReserved -> Word8
    reservedCode r = case r of
        WSOpcode'Reserved3  -> 0x3
        WSOpcode'Reserved4  -> 0x4
        WSOpcode'Reserved5  -> 0x5
        WSOpcode'Reserved6  -> 0x6
        WSOpcode'Reserved7  -> 0x7
        WSOpcode'Reserved11 -> 0xB
        WSOpcode'Reserved12 -> 0xC
        WSOpcode'Reserved13 -> 0xD
        WSOpcode'Reserved14 -> 0xE
        WSOpcode'Reserved15 -> 0xF
-- | The @Sec-WebSocket-Accept@ value expected for a given key.
--
-- It is the base64 encoding of the SHA-1 digest of the base64 key text with
-- the RFC 6455 GUID appended.
wsKeyToAcceptB64 :: SecWebSocketKey -> ByteString
wsKeyToAcceptB64 =
    B64.encode . SHA1.hash . (`BS.append` handshakeGuid) . secWebSocketKeyToB64
  where
    handshakeGuid :: ByteString
    handshakeGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
-- | A @Sec-WebSocket-Key@ value, held in its base64-encoded text form.
--
-- The constructor is not exported. Values are built with
-- 'secWebSocketKeyFromB64' or 'secWebSocketKeyFromWords'.
newtype SecWebSocketKey = WSKey ByteString
    deriving (Eq, Ord, Show)
-- | Parse the textual value of a @Sec-WebSocket-Key@ header.
--
-- Leading and trailing spaces/tabs are stripped first. The stripped value
-- is accepted only if all of these hold:
--
-- * it is 24 characters long;
-- * it is valid base64;
-- * it decodes to exactly 16 bytes, as RFC 6455 requires of the key nonce.
--
-- A 24-character value that decodes to 17 or 18 bytes is rejected. The key
-- is stored in its stripped base64 form.
secWebSocketKeyFromB64 :: ByteString -> Maybe SecWebSocketKey
secWebSocketKeyFromB64 key
    | BS.length key' /= 24 = Nothing
    | Right nonce <- B64.decode key'
    , BS.length nonce == 16 = Just $! WSKey key'
    | otherwise = Nothing
  where
    key' = fst (BS.spanEnd isOWS (BS.dropWhile isOWS key))

    -- optional whitespace: SP or HTAB
    isOWS :: Word8 -> Bool
    isOWS 0x09 = True
    isOWS 0x20 = True
    isOWS _    = False
-- | The key in its base64 text form, as sent in the request header.
secWebSocketKeyToB64 :: SecWebSocketKey -> ByteString
secWebSocketKeyToB64 key = case key of
    WSKey b64Text -> b64Text
-- | Build a key from a 128-bit nonce given as two 64-bit words.
--
-- The high word is serialized big-endian, followed by the low word, and the
-- resulting 16 bytes are base64-encoded.
secWebSocketKeyFromWords :: Word64 -> Word64 -> SecWebSocketKey
secWebSocketKeyFromWords hi lo = WSKey (B64.encode nonce)
  where
    nonce = Builder.toByteString
        (Builder.fromWord64be hi `mappend` Builder.fromWord64be lo)
-- | Wrap an input stream so that each chunk read through it is XOR-ed with
-- a 32-bit masking key.
--
-- The mask state returned by 'xor32StrictByteString'' for one chunk is kept
-- in an 'IORef' and used for the next chunk. Presumably this keeps the key
-- aligned across chunk boundaries.
--
-- A key of 0 is the identity, so the original stream is returned unchanged.
xor32InputStream :: Word32 -> InputStream ByteString -> IO (InputStream ByteString)
xor32InputStream msk0 is
    | msk0 == 0 = return is
    | otherwise = do
        maskState <- newIORef msk0
        let unmask chunk = do
                current <- readIORef maskState
                let (next, out) = xor32StrictByteString' current chunk
                writeIORef maskState next
                return $! out
        Streams.makeInputStream $
            Streams.read is >>= maybe (return Nothing) (fmap Just . unmask)
-- | Perform the client side of the WebSocket opening handshake.
--
-- Sends a @GET@ request for @resource@ with @Upgrade: websocket@,
-- @Connection: Upgrade@, @Sec-WebSocket-Version: 13@ and the given
-- @Sec-WebSocket-Key@. The @rqmod@ builder runs after these headers are
-- set, so it can add to them or override them.
--
-- The response is handled by 'HC.receiveUpgradeResponse'. Presumably it
-- calls @failedToUpgrade@ when the server does not switch protocols; verify
-- against that function.
--
-- In the upgrade case, the response headers are validated and @success@ is
-- called with the response and the connection. Throws 'WsException' when:
--
-- * @Connection@ is missing, or none of its comma-separated tokens equals
--   @upgrade@ (case-insensitively);
-- * @Upgrade@ is missing or is not @websocket@ (case-insensitively);
-- * @Sec-WebSocket-Accept@ is missing or does not match the value derived
--   from @wskey@.
wsUpgradeConnection :: Connection
                    -> ByteString
                    -> RequestBuilder α
                    -> SecWebSocketKey
                    -> (Response -> InputStream ByteString -> IO b)
                    -> (Response -> Connection -> IO b)
                    -> IO b
wsUpgradeConnection conn resource rqmod wskey failedToUpgrade success = do
    let rqToWS = HC.buildRequest1 $ do
            HC.http HC.GET resource
            HC.setHeader "upgrade" "websocket"
            HC.setHeader "connection" "Upgrade"
            HC.setHeader "sec-websocket-version" "13"
            HC.setHeader "sec-websocket-key" (secWebSocketKeyToB64 wskey)
            rqmod
    HC.sendRequest conn rqToWS HC.emptyBody
    HC.receiveUpgradeResponse conn failedToUpgrade $ \resp _is _os -> do
        case HC.getHeader resp "connection" of
            Nothing -> abort "missing 'connection' header"
            Just v
                | hasUpgradeToken v -> return ()
                | otherwise -> abort "'connection' header has non-'upgrade' value"
        case CI.mk `fmap` HC.getHeader resp "upgrade" of
            Nothing -> abort "missing 'upgrade' header"
            Just "websocket" -> return ()
            Just _ -> abort "'upgrade' header has non-'websocket' value"
        case HC.getHeader resp "sec-websocket-accept" of
            Nothing -> abort "missing 'sec-websocket-accept' header"
            Just wsacc
                | wsacc /= wsKeyToAcceptB64 wskey -> abort "sec-websocket-accept header mismatch"
                | otherwise -> return ()
        success resp conn
  where
    abort msg = throwIO (WsException ("wsUpgradeConnection: "++msg))

    -- The Connection header holds a comma-separated token list (e.g.
    -- "Upgrade, keep-alive"). It is acceptable if any token, with
    -- surrounding spaces/tabs removed, equals "upgrade" case-insensitively.
    hasUpgradeToken :: ByteString -> Bool
    hasUpgradeToken = any ((== "upgrade") . CI.mk . trimOWS) . BS.split 0x2c

    trimOWS :: ByteString -> ByteString
    trimOWS = fst . BS.spanEnd isOWS . BS.dropWhile isOWS

    isOWS :: Word8 -> Bool
    isOWS c = c == 0x20 || c == 0x09