{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.QUIC.Client.Reader (
    readerClient,
    recvClient,
    ConnectionControl (..),
    controlConnection,
) where

import Data.List (intersect)
import Network.Socket (getSocketName)
import Network.UDP
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E

import Network.QUIC.Connection
import Network.QUIC.Connector
import Network.QUIC.Crypto
import Network.QUIC.Exception
import Network.QUIC.Imports
import Network.QUIC.Packet
import Network.QUIC.Parameters
import Network.QUIC.Qlog
import Network.QUIC.Recovery
import Network.QUIC.Types

-- | Receive loop for a client connection reading from one 'UDPSocket'.
--
-- It first waits until 'getSocketName' succeeds on the socket. It then
-- repeatedly receives a datagram, decodes it into packets and dispatches
-- each packet:
--
-- * broken packets are dropped;
-- * a Version Negotiation packet that lists neither our chosen version nor
--   'Negotiation' makes the main thread receive a 'VerNego' exception;
-- * ordinary (crypt) packets are written to the connection's receive queue;
-- * a Retry packet whose CIDs pass 'checkCIDs' resets the peer CID, the
--   initial keys and the token, and re-queues the packets released by
--   'releaseByRetry' for retransmission.
--
-- If nothing arrives within the connection's minimum idle timeout, the
-- socket is closed and the loop ends. readerClient dies when the socket is
-- closed (any exception is logged by 'handleLogUnit').
readerClient :: UDPSocket -> Connection -> IO ()
readerClient cs0@(UDPSocket s0 _ _) conn = handleLogUnit logAction $ do
    wait
    loop
  where
    -- Yield to other threads until the socket reports a local address.
    -- NOTE(review): if 'getSocketName' keeps failing (e.g. the socket was
    -- already closed), this presumably never terminates — confirm.
    wait :: IO ()
    wait = do
        bound <- E.handleAny (\_ -> return False) $ do
            _ <- getSocketName s0
            return True
        unless bound $ do
            yield
            wait

    loop :: IO ()
    loop = do
        ito <- readMinIdleTimeout conn
        -- The label now matches the function name (was "readeClient").
        mbs <- timeout ito "readerClient" $ recv cs0
        case mbs of
            -- Idle timeout: close the socket and stop reading.
            Nothing -> close cs0
            Just bs -> do
                now <- getTimeMicrosecond
                pkts <- decodePackets bs
                mapM_ (putQ now) pkts
                loop

    logAction msg = connDebugLog conn ("debug: readerClient: " <> msg)

    putQ _ (PacketIB BrokenPacket) = return ()
    putQ t (PacketIV pkt@(VersionNegotiationPacket dCID sCID peerVers)) = do
        qlogReceived conn pkt t
        myVerInfo <- getVersionInfo conn
        let myVer = chosenVersion myVerInfo
            myVers0 = otherVersions myVerInfo
        -- ignoring VN if the original version is included.
        when (myVer `notElem` peerVers && Negotiation `notElem` peerVers) $ do
            ok <- checkCIDs conn dCID (Left sCID)
            let myVers = filter (not . isGreasingVersion) myVers0
                nextVerInfo = case myVers `intersect` peerVers of
                    vers@(ver : _) | ok -> VersionInfo ver vers
                    _ -> brokenVersionInfo
            -- A failed CID check, or no version in common, sends
            -- 'brokenVersionInfo' to the main thread instead.
            E.throwTo (mainThreadId conn) $ VerNego nextVerInfo
    putQ t (PacketIC pkt lvl siz) =
        writeRecvQ (connRecvQ conn) $ mkReceivedPacket pkt t siz lvl
    putQ t (PacketIR pkt@(RetryPacket ver dCID sCID token ex)) = do
        qlogReceived conn pkt t
        ok <- checkCIDs conn dCID ex
        when ok $ do
            resetPeerCID conn sCID
            setPeerAuthCIDs conn $ \auth -> auth{retrySrcCID = Just sCID}
            initializeCoder conn InitialLevel $ initialSecrets ver sCID
            setToken conn token
            setRetried conn True
            releaseByRetry (connLDCC conn) >>= mapM_ put
      where
        put ppkt = putOutput conn $ OutRetrans ppkt

-- | Validate the connection IDs carried by a Version Negotiation or Retry
-- packet.
--
-- The destination CID must always equal our own CID ('getMyCID'). In
-- addition:
--
-- * @Left sCID@: the source CID must equal the current peer CID
--   ('getPeerCID');
-- * @Right (pseudo, tag)@: 'calculateIntegrityTag', computed from the
--   current version, the current peer CID and @pseudo@, must equal @tag@.
checkCIDs :: Connection -> CID -> Either CID (ByteString, ByteString) -> IO Bool
checkCIDs conn dCID evidence = do
    ourCID <- getMyCID conn
    peerCID <- getPeerCID conn
    peerSideOk <- case evidence of
        Left srcCID -> return (srcCID == peerCID)
        Right (pseudo, expectedTag) -> do
            curVer <- getVersion conn
            return (calculateIntegrityTag curVer peerCID pseudo == expectedTag)
    return (dCID == ourCID && peerSideOk)

-- | Take the next received packet from a client's receive queue.
-- This is a thin wrapper over 'readRecvQ'.
recvClient :: RecvQ -> IO ReceivedPacket
recvClient queue = readRecvQ queue

----------------------------------------------------------------

-- | How to control a connection (see 'controlConnection').
data ConnectionControl
    = -- | Send RETIRE_CONNECTION_ID for a peer CID obtained from
      -- 'waitPeerCID'.
      ChangeServerCID
    | -- | Send NEW_CONNECTION_ID for a freshly obtained local CID.
      ChangeClientCID
    | -- | Switch to a new socket made by 'natRebinding'.
      NATRebinding
    | -- | Switch to a new socket and then validate the path.
      ActiveMigration
    deriving (Show, Eq)

-- | Apply a 'ConnectionControl' action to a connection.
--
-- Only client connections are supported; for any other connection this
-- returns 'False' right away. For a client it first blocks in
-- 'waitEstablished' and then returns the result of performing the action.
-- That result is 'False' when a CID-dependent action gave up waiting for a
-- peer CID.
controlConnection :: Connection -> ConnectionControl -> IO Bool
controlConnection conn typ
    | not (isClient conn) = return False
    | otherwise = do
        waitEstablished conn
        controlConnection' conn typ

-- | Perform one 'ConnectionControl' action on an established client
-- connection.
--
-- 'ChangeServerCID' and 'ActiveMigration' first wait up to one second for a
-- peer CID and return 'False' if none arrives. Every other path returns
-- 'True'.
controlConnection' :: Connection -> ConnectionControl -> IO Bool
controlConnection' conn typ = case typ of
    ChangeServerCID -> do
        found <- awaitPeerCID "controlConnection' 1"
        case found of
            Nothing -> return False
            Just (CIDInfo seqNum _ _) -> do
                sendFrames conn RTT1Level [RetireConnectionID seqNum]
                return True
    ChangeClientCID -> do
        newCID <- getNewMyCID conn
        retirePriorTo <- (+ 1) <$> getMyCIDSeqNum conn
        sendFrames conn RTT1Level [NewConnectionID newCID retirePriorTo]
        return True
    NATRebinding -> do
        rebind conn (Microseconds 5000) -- nearly 0
        return True
    ActiveMigration -> do
        found <- awaitPeerCID "controlConnection' 2"
        case found of
            Nothing -> return False
            Just _ -> do
                rebind conn (Microseconds 5000000)
                validatePath conn found
                return True
  where
    -- Wait at most one second for a peer CID; the label is passed to
    -- 'timeout'.
    awaitPeerCID :: String -> IO (Maybe CIDInfo)
    awaitPeerCID label =
        timeout (Microseconds 1000000) label $ waitPeerCID conn -- fixme

-- | Move the connection onto a fresh socket produced by 'natRebinding' from
-- the current one. Install it with 'setSocket', start a 'readerClient' on it
-- (registered via 'addReader'), and close the previous socket after the
-- given delay using 'fire'.
rebind :: Connection -> Microseconds -> IO ()
rebind conn microseconds = do
    current <- getSocket conn
    fresh <- natRebinding current
    replaced <- setSocket conn fresh
    readerTid <- forkIO (readerClient fresh conn)
    addReader conn readerTid
    -- Close the socket returned by 'setSocket' (presumably the one it
    -- replaced) rather than 'current', just in case the two differ.
    fire conn microseconds $ close replaced