{-# LANGUAGE OverloadedStrings #-}

module Network.QUIC.Packet.Decode (
    decodePacket,
    decodePackets,
    decodeCryptPackets,
    decodeStatelessResetToken,
) where

import qualified Data.ByteString as BS
import qualified Data.ByteString.Short as Short
import qualified UnliftIO.Exception as E

import Network.QUIC.Imports
import Network.QUIC.Packet.Header
import Network.QUIC.Types

----------------------------------------------------------------

-- Server uses this.
-- | Decode every packet in a datagram and keep only the encrypted ones,
-- flattened to @(packet, encryption level, size)@ triples in wire order.
--
-- Anything that is not a 'PacketIC' (Version Negotiation, Retry, broken
-- packets) is silently dropped. The 'Int' is the size value that
-- 'decodePacket' stored in the 'PacketIC'.
decodeCryptPackets :: ByteString -> IO [(CryptPacket, EncryptionLevel, Int)]
decodeCryptPackets bs0 = do
    pkts <- decodePackets bs0
    -- A failed pattern match in a list comprehension skips the element,
    -- which is exactly the "keep only PacketIC" filter we want.
    return [(cpkt, lvl, siz) | PacketIC cpkt lvl siz <- pkts]

-- Client uses this.
-- | Decode all packets coalesced in one datagram, returned in the order in
-- which they appear on the wire.
--
-- Decoding repeatedly calls 'decodePacket' on the unconsumed suffix until
-- nothing is left. 'decodePacket' returns an empty suffix when it produces
-- a 'BrokenPacket', so a broken packet ends the list.
decodePackets :: ByteString -> IO [PacketI]
decodePackets bs0 = go [] bs0
  where
    go :: [PacketI] -> ByteString -> IO [PacketI]
    go acc bs
        | BS.null bs = return (reverse acc) -- fixme
        | otherwise = do
            (pkt, rest) <- decodePacket bs
            go (pkt : acc) rest

-- | Decode the first packet of a (possibly coalesced) datagram.
--
-- Returns the decoded packet together with the unconsumed suffix of the
-- input, i.e. the bytes after this packet.
--
-- Any 'BufferOverrun' raised while decoding (truncated input or an
-- inconsistent length field) is caught and turned into
-- @('PacketIB' 'BrokenPacket', "")@. The empty suffix means the rest of the
-- datagram is dropped.
decodePacket :: ByteString -> IO (PacketI, ByteString)
decodePacket bs = E.handle handler $ withReadBuffer bs $ \rbuf -> do
    save rbuf
    proFlags <- Flags <$> read8 rbuf
    pkt <-
        if isShort proFlags
            then decodeShort rbuf
            else decodeLong rbuf proFlags
    siz <- savingSize rbuf
    return (pkt, BS.drop siz bs)
  where
    -- Short header: the DCID is read with our own CID length
    -- ('myCIDLength'), and 'makeShortCrypt' consumes the rest of the buffer.
    decodeShort rbuf = do
        header <- Short . makeCID <$> extractShortByteString rbuf myCIDLength
        cpkt <- CryptPacket header <$> makeShortCrypt bs rbuf
        siz <- savingSize rbuf
        return $ PacketIC cpkt RTT1Level siz

    -- Long header: Version Negotiation is detected from the version field
    -- before the packet type bits are interpreted.
    decodeLong rbuf proFlags = do
        (ver, dCID, sCID) <- decodeLongHeader rbuf
        case ver of
            Negotiation -> decodeVersionNegotiationPacket rbuf dCID sCID
            _ -> case decodeLongHeaderPacketType ver proFlags of
                RetryPacketType ->
                    decodeRetryPacket rbuf proFlags ver dCID sCID
                RTT0PacketType ->
                    longCrypt rbuf (RTT0 ver dCID sCID) RTT0Level
                InitialPacketType -> do
                    tokenLen <- readBoundedLength rbuf
                    token <- extractByteString rbuf tokenLen
                    longCrypt rbuf (Initial ver dCID sCID token) InitialLevel
                HandshakePacketType ->
                    longCrypt rbuf (Handshake ver dCID sCID) HandshakeLevel

    -- Shared tail of the RTT0 / Initial / Handshake cases.
    longCrypt rbuf header lvl = do
        cpkt <- CryptPacket header <$> makeLongCrypt bs rbuf
        siz <- savingSize rbuf
        return $ PacketIC cpkt lvl siz

    -- Read a varint length and check it against the remaining bytes while
    -- it is still an Int64. Narrowing with 'fromIntegral' first could wrap
    -- when 'Int' is 32 bits, and a negative length makes
    -- 'extractByteString' copy backwards instead of failing.
    readBoundedLength rbuf = do
        n <- decodeInt' rbuf
        avail <- remainingSize rbuf
        if n < 0 || n > fromIntegral avail
            then E.throwIO BufferOverrun
            else return (fromIntegral n)

    handler BufferOverrun = return (PacketIB BrokenPacket, "")

-- | Build the 'Crypt' for a short-header packet.
--
-- No length field is read here: the packet is taken to run to the end of
-- the buffer. The current offset is recorded, then every remaining byte is
-- skipped.
--
-- The resulting 'Crypt' holds that offset and the whole input @bs@. Its
-- last two fields are @0@ and 'Nothing'.
makeShortCrypt :: ByteString -> ReadBuffer -> IO Crypt
makeShortCrypt bs rbuf = do
    remaining <- remainingSize rbuf
    start <- savingSize rbuf
    ff rbuf remaining
    pure (Crypt start bs 0 Nothing)

-- | Build the 'Crypt' for a long-header packet.
--
-- The steps are:
--
-- * read the variable-length Length field;
-- * record the offset just after it;
-- * skip @Length@ bytes.
--
-- The 'Crypt' holds that offset and the prefix of @bs@ that ends at the
-- end of this packet, so any coalesced packets that follow are excluded.
-- Its last two fields are @0@ and 'Nothing'.
--
-- The Length is range-checked against the remaining bytes while it is
-- still an Int64, and only then narrowed to 'Int'. Narrowing first can
-- wrap when 'Int' is 32 bits, letting a crafted Length become a small or
-- negative skip. An out-of-range Length raises 'BufferOverrun'.
makeLongCrypt :: ByteString -> ReadBuffer -> IO Crypt
makeLongCrypt bs rbuf = do
    len64 <- decodeInt' rbuf
    avail <- remainingSize rbuf
    if len64 < 0 || len64 > fromIntegral avail
        then E.throwIO BufferOverrun
        else do
            let len = fromIntegral len64
            here <- savingSize rbuf
            ff rbuf len
            let pkt = BS.take (here + len) bs
            return $ Crypt here pkt 0 Nothing

----------------------------------------------------------------

-- | Parse the long-header fields that follow the first byte.
--
-- The fields are, in order: a 32-bit version, the destination connection
-- ID, and the source connection ID. Each connection ID is prefixed by a
-- one-byte length.
decodeLongHeader :: ReadBuffer -> IO (Version, CID, CID)
decodeLongHeader rbuf =
    -- IO's Applicative runs these left to right, matching wire order.
    (,,) <$> (Version <$> read32 rbuf) <*> readCID <*> readCID
  where
    readCID :: IO CID
    readCID = do
        cidLen <- fromIntegral <$> read8 rbuf
        makeCID <$> extractShortByteString rbuf cidLen

-- | Decode the body of a Version Negotiation packet.
--
-- Every complete 4-byte word left in the buffer is read as a 'Version',
-- kept in wire order. A trailing fragment shorter than 4 bytes is ignored.
decodeVersionNegotiationPacket :: ReadBuffer -> CID -> CID -> IO PacketI
decodeVersionNegotiationPacket rbuf dCID sCID = do
    vers <- remainingSize rbuf >>= readVersions
    return . PacketIV $ VersionNegotiationPacket dCID sCID vers
  where
    -- The argument counts the bytes not yet read.
    readVersions :: Int -> IO [Version]
    readVersions n
        | n < 4 = return []
        | otherwise = (:) <$> (Version <$> read32 rbuf) <*> readVersions (n - 4)

-- | Decode the body of a Retry packet.
--
-- After the long header, the remaining bytes are a token followed by a
-- 16-byte tag. The packet also carries @pseudo@: every byte from the
-- buffer's saved position (the packet start, when called from
-- 'decodePacket') up to, but not including, the tag.
--
-- If fewer than 16 bytes remain, 'BufferOverrun' is thrown immediately.
-- Previously the token length @rsiz - 16@ went negative. 'extractByteString'
-- then switched to its backward-copy mode, read header bytes as the
-- "token" (possibly reaching before the buffer start), and only failed
-- later, at the tag read.
decodeRetryPacket
    :: ReadBuffer -> Flags Protected -> Version -> CID -> CID -> IO PacketI
decodeRetryPacket rbuf _proFlags version dCID sCID = do
    rsiz <- remainingSize rbuf
    if rsiz < tagLen
        then E.throwIO BufferOverrun
        else do
            token <- extractByteString rbuf (rsiz - tagLen)
            siz <- savingSize rbuf
            -- A negative length copies the 'siz' bytes preceding the cursor
            -- (back to the saved position) without moving it.
            pseudo <- extractByteString rbuf $ negate siz
            tag <- extractByteString rbuf tagLen
            return $
                PacketIR $
                    RetryPacket version dCID sCID token (Right (pseudo, tag))
  where
    tagLen :: Int
    tagLen = 16

----------------------------------------------------------------

-- | Extract a candidate Stateless Reset Token from a datagram.
--
-- The token is the datagram's last 16 bytes. Inputs shorter than 21 bytes
-- yield 'Nothing'.
decodeStatelessResetToken :: ByteString -> Maybe StatelessResetToken
decodeStatelessResetToken bs
    | n < 21 = Nothing
    | otherwise =
        Just . StatelessResetToken . Short.toShort $ BS.drop (n - 16) bs
  where
    n :: Int
    n = BS.length bs