{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RankNTypes #-}
module Network.HTTP2.Client.RawConnection (
RawHttp2Connection (..)
, newRawHttp2Connection
, newRawHttp2ConnectionUnix
, newRawHttp2ConnectionSocket
) where
import Control.Concurrent.Async.Lifted (Async, async, cancel, pollSTM)
import Control.Concurrent.STM (STM, atomically, check, orElse, retry, throwSTM)
import Control.Concurrent.STM.TVar (TVar, modifyTVar', newTVarIO, readTVar, writeTVar)
import Control.Exception (bracketOnError, finally)
import Control.Monad (forever, when)
import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import Data.ByteString.Lazy (fromChunks)
import Data.Monoid ((<>))
import qualified Network.HTTP2 as HTTP2
import Network.Socket hiding (recv)
import Network.Socket.ByteString
import qualified Network.TLS as TLS

import Network.HTTP2.Client.Exceptions
-- | Abstract handle over a byte-oriented transport carrying HTTP/2
-- (plain TCP, TLS, or a Unix domain socket). All operations run in
-- 'ClientIO'.
data RawHttp2Connection = RawHttp2Connection {
    _sendRaw :: [ByteString] -> ClientIO ()
  -- ^ Queues raw chunks for sending to the server. The constructors in
  -- this module append the chunks to an STM queue that a background
  -- writer thread drains.
  , _nextRaw :: Int -> ClientIO ByteString
  -- ^ Blocks until the requested number of bytes is buffered and returns
  -- exactly that many. Returns an empty chunk once the remote end has
  -- closed and the request can no longer be satisfied.
  , _close   :: ClientIO ()
  -- ^ Stops the background workers and closes the underlying transport.
  }
-- | Opens a TCP connection to @host:port@ and wraps it in a
-- 'RawHttp2Connection', over TLS when parameters are given.
--
-- Only the first address returned by 'getAddrInfo' is tried.
-- @TCP_NODELAY@ is enabled on the socket. If configuring or connecting
-- the socket fails, the socket is closed before the exception propagates.
newRawHttp2Connection :: HostName
                      -- ^ Server's hostname.
                      -> PortNumber
                      -- ^ Server's port to connect to.
                      -> Maybe TLS.ClientParams
                      -- ^ TLS parameters; 'Nothing' means plain text. When
                      -- present, the 'TLS.onSuggestALPN' hook is overwritten
                      -- to offer @["h2", "h2-17"]@.
                      -> ClientIO RawHttp2Connection
newRawHttp2Connection host port mparams = do
    let hints = defaultHints { addrFlags = [AI_NUMERICSERV], addrSocketType = Stream }
    rSkt <- lift $ do
        addrs <- getAddrInfo (Just hints) (Just host) (Just $ show port)
        addr <- case addrs of
            (a:_) -> pure a
            []    -> ioError $ userError $
                "newRawHttp2Connection: no address found for " ++ host
        -- Close the socket if any step fails before it is handed over;
        -- otherwise its file descriptor would leak.
        bracketOnError
            (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr))
            close
            (\skt -> do
                setSocketOption skt NoDelay 1
                connect skt (addrAddress addr)
                pure skt)
    newRawHttp2ConnectionSocket rSkt mparams
-- | Connects to the Unix domain socket at the given path and wraps it in a
-- 'RawHttp2Connection', over TLS when parameters are given.
--
-- If connecting fails, the socket is closed before the exception
-- propagates.
newRawHttp2ConnectionUnix :: String
                          -- ^ Filesystem path of the socket.
                          -> Maybe TLS.ClientParams
                          -- ^ TLS parameters; 'Nothing' means plain text. When
                          -- present, the 'TLS.onSuggestALPN' hook is
                          -- overwritten to offer @["h2", "h2-17"]@.
                          -> ClientIO RawHttp2Connection
newRawHttp2ConnectionUnix path mparams = do
    -- 'bracketOnError' releases the socket only when 'connect' throws.
    rSkt <- lift $
        bracketOnError (socket AF_UNIX Stream 0) close $ \skt -> do
            connect skt (SockAddrUnix path)
            pure skt
    newRawHttp2ConnectionSocket rSkt mparams
-- | Wraps an already-connected socket in a 'RawHttp2Connection' and
-- queues the HTTP/2 client connection preface on it.
--
-- With 'Nothing' the socket is used as plain text. With 'Just' TLS
-- parameters, a TLS handshake is performed first.
newRawHttp2ConnectionSocket
  :: Socket
  -- ^ A connected socket.
  -> Maybe TLS.ClientParams
  -- ^ TLS parameters. When present, the 'TLS.onSuggestALPN' hook is
  -- overwritten to offer @["h2", "h2-17"]@.
  -> ClientIO RawHttp2Connection
newRawHttp2ConnectionSocket skt mparams = do
    conn <- lift $ case mparams of
        Nothing     -> plainTextRaw skt
        Just params -> tlsRaw skt params
    -- Queue the client connection preface before handing the handle out.
    _sendRaw conn [HTTP2.connectionPreface]
    pure conn
-- | Builds a plain-text 'RawHttp2Connection' over a connected socket.
--
-- Starts one writer thread, which drains queued chunks with 'sendMany',
-- and one reader thread, which fills a buffer using 'recv'. Closing
-- cancels both threads and then closes the socket.
plainTextRaw :: Socket -> IO RawHttp2Connection
plainTextRaw skt = do
    (writer, putRaw) <- startWriteWorker (sendMany skt)
    (reader, getRaw) <- startReadWorker (recv skt)
    -- 'finally' guarantees the socket is closed even if cancelling a
    -- worker is interrupted by an exception.
    let doClose :: ClientIO ()
        doClose = lift $
            (cancel reader >> cancel writer) `finally` close skt
    return $ RawHttp2Connection (lift . atomically . putRaw)
                                (lift . atomically . getRaw)
                                doClose
-- | Builds a TLS 'RawHttp2Connection' over a connected socket.
--
-- Performs the TLS handshake with ALPN forced to @["h2", "h2-17"]@. It then
-- starts a writer thread, which sends queued chunks with 'TLS.sendData', and
-- a reader thread, which fills a buffer using 'TLS.recvData'.
tlsRaw :: Socket -> TLS.ClientParams -> IO RawHttp2Connection
tlsRaw skt params = do
    tlsContext <- TLS.contextNew skt (modifyParams params)
    TLS.handshake tlsContext

    (writer, putRaw) <- startWriteWorker (TLS.sendData tlsContext . fromChunks)
    -- The size requested by the reader loop is ignored here:
    -- 'TLS.recvData' takes no length argument.
    (reader, getRaw) <- startReadWorker (const $ TLS.recvData tlsContext)
    -- 'TLS.bye' may throw (e.g. if the peer has already gone away);
    -- 'finally' ensures 'TLS.contextClose' still runs.
    let doClose :: ClientIO ()
        doClose = lift $
            (cancel reader >> cancel writer >> TLS.bye tlsContext)
                `finally` TLS.contextClose tlsContext
    return $ RawHttp2Connection (lift . atomically . putRaw)
                                (lift . atomically . getRaw)
                                doClose
  where
    -- Offer HTTP/2 via ALPN regardless of what the caller configured.
    modifyParams :: TLS.ClientParams -> TLS.ClientParams
    modifyParams prms = prms {
        TLS.clientHooks = (TLS.clientHooks prms) {
            TLS.onSuggestALPN = return $ Just [ "h2", "h2-17" ]
          }
      }
-- | Spawns a background writer that repeatedly drains a shared queue and
-- hands the drained chunks to @sendChunks@.
--
-- Returns the writer thread and an STM action that appends chunks to the
-- end of the queue, preserving submission order.
startWriteWorker
  :: ([ByteString] -> IO ())
  -> IO (Async (), [ByteString] -> STM ())
startWriteWorker sendChunks = do
    queue <- newTVarIO []
    writer <- async (writeWorkerLoop queue sendChunks)
    let enqueue :: [ByteString] -> STM ()
        enqueue chunks = modifyTVar' queue (++ chunks)
    pure (writer, enqueue)
-- | Loops forever: waits until the queue is non-empty, atomically takes
-- every queued chunk (leaving the queue empty), then passes them to
-- @sendChunks@ in a single call. Never returns normally; it stops only
-- via an exception, e.g. when its owning 'Async' is cancelled.
writeWorkerLoop :: TVar [ByteString] -> ([ByteString] -> IO ()) -> IO ()
writeWorkerLoop outQ sendChunks = forever (drainQueue >>= sendChunks)
  where
    drainQueue :: IO [ByteString]
    drainQueue = atomically $ do
        pending <- readTVar outQ
        -- Block (STM retry) until at least one chunk has been queued.
        check (not (null pending))
        writeTVar outQ []
        pure pending
-- | Spawns a background reader that appends everything returned by @get@
-- to a shared buffer until @get@ yields an empty chunk.
--
-- Returns the reader thread and an STM action that, given a length @n@:
--
-- * rethrows the reader's exception if the reader died with one;
-- * otherwise returns exactly @n@ buffered bytes once they are available;
-- * or returns an empty chunk if the reader has reached end-of-stream and
--   the buffer cannot satisfy the request.
startReadWorker
  :: (Int -> IO ByteString)
  -> IO (Async (), (Int -> STM ByteString))
startReadWorker get = do
    remoteClosed <- newTVarIO False
    buf <- newTVarIO ""
    reader <- async $
        readWorkerLoop buf get (atomically (writeTVar remoteClosed True))
    let closedSentinel :: STM ByteString
        closedSentinel = do
            closed <- readTVar remoteClosed
            check closed
            pure ""
        takeBytes :: Int -> STM ByteString
        takeBytes len = getRawWorker reader buf len `orElse` closedSentinel
    pure (reader, takeBytes)
-- | Repeatedly calls @next@, requesting 4096 bytes each time, and appends
-- each chunk to the end of @buf@. As soon as @next@ returns an empty
-- chunk, runs @onEof@ and stops.
readWorkerLoop :: TVar ByteString -> (Int -> IO ByteString) -> IO () -> IO ()
readWorkerLoop buf next onEof = loop
  where
    loop :: IO ()
    loop = next chunkSize >>= \dat ->
        if ByteString.null dat
            then onEof
            else do
                atomically $ modifyTVar' buf (<> dat)
                loop

    chunkSize :: Int
    chunkSize = 4096
-- | STM action that takes exactly @amount@ bytes from the front of @buf@.
--
-- If the reader thread @a@ has already died with an exception, that
-- exception is rethrown. If fewer than @amount@ bytes are buffered, the
-- transaction retries until more data arrives.
getRawWorker :: Async () -> TVar ByteString -> Int -> STM ByteString
getRawWorker a buf amount = do
    -- Surface a crashed reader instead of waiting on a buffer that will
    -- never grow again.
    pollSTM a >>= maybe (pure ()) (either throwSTM pure)
    dat <- readTVar buf
    let (front, rest) = ByteString.splitAt amount dat
    if ByteString.length dat < amount
        then retry
        else do
            writeTVar buf rest
            pure front