{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Network.QUIC.Client.Reader (
readerClient,
recvClient,
ConnectionControl (..),
controlConnection,
) where
import Data.List (intersect)
import Network.Socket (getSocketName)
import Network.UDP
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E
import Network.QUIC.Connection
import Network.QUIC.Connector
import Network.QUIC.Crypto
import Network.QUIC.Exception
import Network.QUIC.Imports
import Network.QUIC.Packet
import Network.QUIC.Parameters
import Network.QUIC.Qlog
import Network.QUIC.Recovery
import Network.QUIC.Types
-- | Receive loop for a client-side UDP socket.
--
-- The action first polls (yielding between attempts) until 'getSocketName'
-- succeeds on the underlying socket, then repeatedly receives a datagram,
-- decodes it into packets and dispatches each packet through @putQ@.
-- Everything runs under 'handleLogUnit', whose logger prefixes messages with
-- @"debug: readerClient: "@ and forwards them to 'connDebugLog'.
--
-- The loop stops, closing the socket, when no datagram arrives within the
-- duration returned by 'readMinIdleTimeout'.
readerClient :: UDPSocket -> Connection -> IO ()
readerClient cs0@(UDPSocket s0 _ _) conn = handleLogUnit logAction $ do
    wait
    loop
  where
    -- Poll until 'getSocketName' no longer throws; any exception is treated
    -- as "not ready yet".
    -- NOTE(review): presumably this waits for the socket to be bound. A
    -- socket on which 'getSocketName' never succeeds keeps this spinning.
    wait = do
        bound <- E.handleAny (\_ -> return False) $ do
            _ <- getSocketName s0
            return True
        unless bound $ do
            yield
            wait
    loop = do
        ito <- readMinIdleTimeout conn
        -- The label below is a runtime string and is kept exactly as written.
        mbs <- timeout ito "readeClient" $ recv cs0
        case mbs of
            -- Nothing received in time: close the socket and stop looping.
            Nothing -> close cs0
            Just bs -> do
                -- One timestamp is shared by all packets of the datagram.
                now <- getTimeMicrosecond
                pkts <- decodePackets bs
                mapM_ (putQ now) pkts
                loop
    logAction msg = connDebugLog conn ("debug: readerClient: " <> msg)
    -- Broken packets are dropped silently.
    putQ _ (PacketIB BrokenPacket) = return ()
    -- Version Negotiation packet.
    putQ t (PacketIV pkt@(VersionNegotiationPacket dCID sCID peerVers)) = do
        qlogReceived conn pkt t
        myVerInfo <- getVersionInfo conn
        let myVer = chosenVersion myVerInfo
            myVers0 = otherVersions myVerInfo
        -- The packet is ignored when the peer's list contains our chosen
        -- version or the 'Negotiation' version.
        when (myVer `notElem` peerVers && Negotiation `notElem` peerVers) $ do
            ok <- checkCIDs conn dCID (Left sCID)
            let myVers = filter (not . isGreasingVersion) myVers0
                -- Our non-greasing versions that the peer also lists; if the
                -- CID check failed or nothing is shared, report
                -- 'brokenVersionInfo' instead.
                nextVerInfo = case myVers `intersect` peerVers of
                    vers@(ver : _) | ok -> VersionInfo ver vers
                    _ -> brokenVersionInfo
            -- Hand the outcome to the connection's main thread.
            E.throwTo (mainThreadId conn) $ VerNego nextVerInfo
    -- Protected packet: queue it for the receiver together with its
    -- arrival time, size and encryption level.
    putQ t (PacketIC pkt lvl siz) =
        writeRecvQ (connRecvQ conn) $ mkReceivedPacket pkt t siz lvl
    -- Retry packet.
    putQ t (PacketIR pkt@(RetryPacket ver dCID sCID token ex)) = do
        qlogReceived conn pkt t
        ok <- checkCIDs conn dCID ex
        when ok $ do
            -- Switch to the source CID carried by the Retry packet, record
            -- it as the retry source CID, and re-derive the Initial coder
            -- from it.
            resetPeerCID conn sCID
            setPeerAuthCIDs conn $ \auth -> auth{retrySrcCID = Just sCID}
            initializeCoder conn InitialLevel $ initialSecrets ver sCID
            setToken conn token
            setRetried conn True
            -- Every packet returned by 'releaseByRetry' is queued again as
            -- a retransmission.
            releaseByRetry (connLDCC conn) >>= mapM_ put
      where
        put ppkt = putOutput conn $ OutRetrans ppkt
-- | Validate the connection IDs of a received packet against this connection.
--
-- The destination CID must always equal the result of 'getMyCID'.
--
-- * @Left sCID@: additionally, @sCID@ must equal the result of 'getPeerCID'.
-- * @Right (pseudo0, tag)@: additionally, 'calculateIntegrityTag' applied to
--   the connection's version, the peer CID and @pseudo0@ must equal @tag@.
--
-- Returns 'True' only when all applicable checks pass.
checkCIDs :: Connection -> CID -> Either CID (ByteString, ByteString) -> IO Bool
checkCIDs conn dCID (Left sCID) = do
    localCID <- getMyCID conn
    remoteCID <- getPeerCID conn
    return (dCID == localCID && sCID == remoteCID)
checkCIDs conn dCID (Right (pseudo0, tag)) = do
    localCID <- getMyCID conn
    remoteCID <- getPeerCID conn
    ver <- getVersion conn
    -- Recompute the tag locally and compare it with the received one.
    let ok = calculateIntegrityTag ver remoteCID pseudo0 == tag
    return (dCID == localCID && ok)
-- | Take the next received packet from the receive queue.
--
-- This is 'readRecvQ' under a client-specific name.
recvClient :: RecvQ -> IO ReceivedPacket
recvClient = readRecvQ
-- | Requests that can be passed to @controlConnection@.
data ConnectionControl
    = -- | Retire a connection ID announced by the peer.
      ChangeServerCID
    | -- | Announce a new connection ID of our own to the peer.
      ChangeClientCID
    | -- | Move the connection onto a fresh socket.
      NATRebinding
    | -- | Move the connection onto a fresh socket and validate the path.
      ActiveMigration
    deriving (Eq, Show)
-- | Apply a 'ConnectionControl' request to a connection.
--
-- When 'isClient' holds for the connection, this waits via 'waitEstablished'
-- and then performs the request, returning that result. Otherwise it returns
-- 'False' without doing anything.
controlConnection :: Connection -> ConnectionControl -> IO Bool
controlConnection conn typ
    | isClient conn = do
        waitEstablished conn
        controlConnection' conn typ
    | otherwise = return False
-- | Perform a 'ConnectionControl' request. Callers go through
-- 'controlConnection', which performs the client check and the wait first.
--
-- Returns 'False' only when waiting for a peer connection ID times out
-- ('ChangeServerCID' and 'ActiveMigration'); every other path returns 'True'.
controlConnection' :: Connection -> ConnectionControl -> IO Bool
controlConnection' conn ChangeServerCID = do
    -- Wait at most 1,000,000 microseconds for 'waitPeerCID'.
    mn <- timeout (Microseconds 1000000) "controlConnection' 1" $ waitPeerCID conn
    case mn of
        Nothing -> return False
        Just (CIDInfo n _ _) -> do
            -- Send a RetireConnectionID frame carrying the first field of
            -- the obtained 'CIDInfo'.
            sendFrames conn RTT1Level [RetireConnectionID n]
            return True
controlConnection' conn ChangeClientCID = do
    cidInfo <- getNewMyCID conn
    -- The second frame argument is the value of 'getMyCIDSeqNum' plus one.
    x <- (+ 1) <$> getMyCIDSeqNum conn
    sendFrames conn RTT1Level [NewConnectionID cidInfo x]
    return True
controlConnection' conn NATRebinding = do
    -- 5,000 microseconds is passed to 'rebind' as its delay argument.
    rebind conn $ Microseconds 5000
    return True
controlConnection' conn ActiveMigration = do
    -- Wait at most 1,000,000 microseconds for 'waitPeerCID'.
    mn <- timeout (Microseconds 1000000) "controlConnection' 2" $ waitPeerCID conn
    case mn of
        Nothing -> return False
        -- Only a 'Just' can reach this branch; it is passed on unchanged.
        mcidinfo -> do
            -- 5,000,000 microseconds is passed to 'rebind' as its delay.
            rebind conn $ Microseconds 5000000
            validatePath conn mcidinfo
            return True
-- | Move the connection onto a new socket obtained from 'natRebinding'.
--
-- The new socket is installed with 'setSocket', a fresh 'readerClient'
-- thread is forked for it and registered with 'addReader', and closing of
-- the socket returned by 'setSocket' is handed to 'fire' together with the
-- given duration.
rebind :: Connection -> Microseconds -> IO ()
rebind conn microseconds = do
    cs0 <- getSocket conn
    cs <- natRebinding cs0
    -- NOTE(review): 'setSocket' presumably returns the socket that was
    -- installed before the swap -- confirm against its definition.
    cs0' <- setSocket conn cs
    let reader = readerClient cs conn
    forkIO reader >>= addReader conn
    -- NOTE(review): 'fire' presumably runs the close after the given delay,
    -- keeping the other socket open in the meantime -- confirm.
    fire conn microseconds $ close cs0'