{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RankNTypes #-}
module Network.HTTP2.Client.RawConnection (
RawHttp2Connection (..)
, newRawHttp2Connection
, newRawHttp2ConnectionSocket
) where
import Control.Concurrent.Async.Lifted (Async, async, cancel, pollSTM)
import Control.Concurrent.STM (STM, atomically, check, orElse, retry, throwSTM)
import Control.Concurrent.STM.TVar (TVar, modifyTVar', newTVarIO, readTVar, writeTVar)
import qualified Control.Exception as Exception
import Control.Monad (forever, when)
import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import Data.ByteString.Lazy (fromChunks)
import Data.Monoid ((<>))
import qualified Network.HTTP2 as HTTP2
import Network.Socket hiding (recv)
import Network.Socket.ByteString
import qualified Network.TLS as TLS

import Network.HTTP2.Client.Exceptions
-- | Record of actions giving byte-level access to a transport that carries
-- HTTP2 traffic.
data RawHttp2Connection = RawHttp2Connection
    { _sendRaw :: [ByteString] -> ClientIO ()
    -- ^ Hands a list of chunks over to the connection for sending.
    , _nextRaw :: Int -> ClientIO ByteString
    -- ^ Takes a byte count and yields received bytes. For the connections
    -- built in this module the result holds exactly that many bytes, or is
    -- empty once the remote end has closed and too few bytes remain.
    , _close   :: ClientIO ()
    -- ^ Shuts the connection down.
    }
-- | Opens a TCP connection to the given host and port and initiates an HTTP2
-- client connection over it.
--
-- The host is resolved with 'getAddrInfo' and only the first returned address
-- is tried. 'NoDelay' is set on the socket before connecting. The connected
-- socket and the optional TLS parameters are then handed to
-- 'newRawHttp2ConnectionSocket'.
--
-- If configuring or connecting the socket throws, the socket is closed before
-- the exception propagates.
newRawHttp2Connection :: HostName
                      -- ^ Server's hostname.
                      -> PortNumber
                      -- ^ Server's port to connect to.
                      -> Maybe TLS.ClientParams
                      -- ^ TLS parameters; 'Nothing' requests a plain-text
                      -- connection.
                      -> ClientIO RawHttp2Connection
newRawHttp2Connection host port mparams = do
    let hints = defaultHints { addrFlags = [AI_NUMERICSERV], addrSocketType = Stream }
    rSkt <- lift $ do
        addr:_ <- getAddrInfo (Just hints) (Just host) (Just $ show port)
        -- bracketOnError closes the descriptor only when setup fails; a
        -- successfully connected socket is returned still open.
        Exception.bracketOnError
            (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr))
            close
            (\skt -> do
                setSocketOption skt NoDelay 1
                connect skt (addrAddress addr)
                pure skt)
    newRawHttp2ConnectionSocket rSkt mparams
-- | Initiates an HTTP2 client connection over the given socket.
--
-- With TLS parameters the transport is built by 'tlsRaw', otherwise by
-- 'plainTextRaw'. The HTTP2 connection preface is sent through the new
-- connection before it is returned. This function never calls @connect@, so
-- the socket is presumably expected to be connected already.
newRawHttp2ConnectionSocket
    :: Socket
    -> Maybe TLS.ClientParams
    -> ClientIO RawHttp2Connection
newRawHttp2ConnectionSocket skt mparams = do
    rawConn <- lift mkTransport
    _sendRaw rawConn [HTTP2.connectionPreface]
    pure rawConn
  where
    -- Pick the transport flavour from the presence of TLS parameters.
    mkTransport = case mparams of
        Nothing     -> plainTextRaw skt
        Just params -> tlsRaw skt params
-- | Builds a 'RawHttp2Connection' that talks directly over the socket,
-- without TLS.
--
-- A write worker pushes queued chunks out with 'sendMany' and a read worker
-- pulls bytes in with 'recv'. Closing cancels the read worker, then the write
-- worker, and finally closes the socket.
plainTextRaw :: Socket -> IO RawHttp2Connection
plainTextRaw skt = do
    (writer, enqueue) <- startWriteWorker (sendMany skt)
    (reader, dequeue) <- startReadWorker (recv skt)
    pure RawHttp2Connection
        { _sendRaw = lift . atomically . enqueue
        , _nextRaw = lift . atomically . dequeue
        , _close   = lift $ do
            cancel reader
            cancel writer
            close skt
        }
-- | Builds a 'RawHttp2Connection' that runs over TLS on the given socket.
--
-- A TLS context is created from the parameters, with the ALPN suggestion hook
-- replaced so that @h2@ and @h2-17@ are offered, and the handshake is
-- performed before any worker is started. A write worker sends queued chunks
-- with 'TLS.sendData' and a read worker receives with 'TLS.recvData' (the
-- requested read size is ignored).
--
-- Closing cancels the read worker, then the write worker, then says
-- 'TLS.bye'. The context is closed even when 'TLS.bye' throws; the exception
-- still propagates afterwards.
tlsRaw :: Socket -> TLS.ClientParams -> IO RawHttp2Connection
tlsRaw skt params = do
    tlsContext <- TLS.contextNew skt (modifyParams params)
    TLS.handshake tlsContext
    (b,putRaw) <- startWriteWorker (TLS.sendData tlsContext . fromChunks)
    (a,getRaw) <- startReadWorker (const $ TLS.recvData tlsContext)
    let doClose = lift $ do
            cancel a
            cancel b
            -- finally guarantees the context is released even if the
            -- close notification cannot be sent.
            TLS.bye tlsContext `Exception.finally` TLS.contextClose tlsContext
    return $ RawHttp2Connection (lift . atomically . putRaw) (lift . atomically . getRaw) doClose
  where
    -- Override only the ALPN suggestion; every other hook is kept as given.
    modifyParams prms = prms
        { TLS.clientHooks = (TLS.clientHooks prms)
            { TLS.onSuggestALPN = return $ Just [ "h2", "h2-17" ]
            }
        }
-- | Starts a background writer and returns it together with an STM action
-- for queueing outgoing chunks.
--
-- Queued chunks are appended to the end of a shared list, which the worker
-- ('writeWorkerLoop') drains and passes to the supplied send function.
startWriteWorker
    :: ([ByteString] -> IO ())
    -> IO (Async (), [ByteString] -> STM ())
startWriteWorker sendChunks = do
    pending <- newTVarIO []
    worker <- async (writeWorkerLoop pending sendChunks)
    pure (worker, \chunks -> modifyTVar' pending (++ chunks))
-- | Endless loop that empties the outgoing queue and sends what it took.
--
-- Each round blocks until the queue is non-empty, atomically swaps its
-- contents for an empty list, and passes the taken chunks to the send
-- function in a single call. The loop only ends through an exception.
writeWorkerLoop :: TVar [ByteString] -> ([ByteString] -> IO ()) -> IO ()
writeWorkerLoop outQ sendChunks = forever (drain >>= sendChunks)
  where
    -- Take everything currently queued, waiting while there is nothing.
    drain = atomically $ do
        queued <- readTVar outQ
        when (null queued) retry
        writeTVar outQ []
        pure queued
-- | Starts a background reader and returns it together with an STM action
-- for taking a given number of bytes out of the receive buffer.
--
-- The returned action first tries 'getRawWorker'. When that would block, it
-- yields an empty 'ByteString' instead if the reader has already seen
-- end-of-input, and blocks otherwise.
startReadWorker
    :: (Int -> IO ByteString)
    -> IO (Async (), (Int -> STM ByteString))
startReadWorker get = do
    remoteClosed <- newTVarIO False
    buf <- newTVarIO ""
    a <- async $ readWorkerLoop buf get (atomically $ writeTVar remoteClosed True)
    -- Succeeds with an empty result only after end-of-input was flagged.
    let eofFallback = do
            closed <- readTVar remoteClosed
            check closed
            pure ""
    pure (a, \len -> getRawWorker a buf len `orElse` eofFallback)
-- | Repeatedly reads from the source and appends what it gets to the buffer.
--
-- Each read asks for 4096 bytes. An empty read is treated as end-of-input:
-- the end-of-input action runs and the loop returns. Any other read is
-- appended to the end of the buffer and the loop continues.
readWorkerLoop :: TVar ByteString -> (Int -> IO ByteString) -> IO () -> IO ()
readWorkerLoop buf next onEof = do
    chunk <- next 4096
    if ByteString.null chunk
        then onEof
        else do
            atomically $ modifyTVar' buf (<> chunk)
            readWorkerLoop buf next onEof
-- | STM action that removes exactly the requested number of bytes from the
-- front of the buffer.
--
-- If the given worker has finished with an exception, that exception is
-- rethrown before the buffer is looked at. The transaction retries while the
-- buffer holds fewer bytes than requested.
getRawWorker :: Async () -> TVar ByteString -> Int -> STM ByteString
getRawWorker a buf amount = do
    pollSTM a >>= rethrowFailure
    buffered <- readTVar buf
    check (amount <= ByteString.length buffered)
    let (wanted, rest) = ByteString.splitAt amount buffered
    writeTVar buf rest
    pure wanted
  where
    -- Only a failed worker matters here; running or cleanly finished is fine.
    rethrowFailure (Just (Left e)) = throwSTM e
    rethrowFailure _               = pure ()