{-# LANGUAGE BinaryLiterals #-}

module Network.QUIC.Packet.Header (
    isLong,
    isShort,
    protectFlags,
    unprotectFlags,
    encodeLongHeaderFlags,
    encodeShortHeaderFlags,
    decodeLongHeaderPacketType,
    encodePktNumLength,
    decodePktNumLength,
    versionNegotiationPacketType,
    retryPacketType,
) where

import Network.QUIC.Imports
import Network.QUIC.Types

{-# INLINE isLong #-}
-- | Report whether the first byte of a packet marks a long header.
--
-- The decision is made solely on the most significant bit (bit 7):
-- set means long header, clear means short header.
isLong :: Word8 -> Bool
isLong flags = testBit flags 7

{-# INLINE isShort #-}
-- | Report whether (still protected) flags belong to a short header.
--
-- This is the exact complement of 'isLong': bit 7 of the wrapped byte
-- must be clear.
isShort :: Flags Protected -> Bool
isShort (Flags flags) = not (isLong flags)

----------------------------------------------------------------

-- | Remove header protection from the flags byte.
--
-- The second argument is a mask byte; only the bits selected by
-- 'flagBits' (chosen from the protected byte, whose header-form bit is
-- never masked) are XORed into the flags.
unprotectFlags :: Flags Protected -> Word8 -> Flags Raw
unprotectFlags (Flags proFlags) mask1 = Flags flags
  where
    -- Restrict the mask to the bits that are protected for this header form.
    mask = mask1 .&. flagBits proFlags
    flags = proFlags `xor` mask

-- | Apply header protection to the flags byte.
--
-- The second argument is a mask byte; only the bits selected by
-- 'flagBits' are XORed into the flags, so bit 7 (the header form) is
-- left untouched and 'unprotectFlags' with the same mask undoes this.
protectFlags :: Flags Raw -> Word8 -> Flags Protected
protectFlags (Flags flags) mask1 = Flags proFlags
  where
    -- Restrict the mask to the bits that are protected for this header form.
    mask = mask1 .&. flagBits flags
    proFlags = flags `xor` mask

{-# INLINE flagBits #-}
{- FOURMOLU_DISABLE -}
-- | Select which bits of the flags byte take part in header protection:
-- the low four bits for a long header, the low five for a short header.
flagBits :: Word8 -> Word8
flagBits flags
    | isLong flags = 0b00001111 -- long header
    | otherwise    = 0b00011111 -- short header
{- FOURMOLU_ENABLE -}

----------------------------------------------------------------

-- | Optionally replace bit 6 of the flags byte with a random bit.
--
-- When the 'Bool' argument is 'True', bit 6 is cleared and then filled
-- with bit 6 of a byte obtained from 'getRandomOneByte'; all other bits
-- are preserved. When it is 'False' the flags are returned unchanged.
--
-- NOTE(review): the flag presumably reflects whether greasing of the
-- QUIC bit was negotiated with the peer -- confirm against the callers.
randomizeQuicBit :: Word8 -> Bool -> IO Word8
randomizeQuicBit flags quicBit
    | quicBit = do
        r <- getRandomOneByte
        return ((flags .&. 0b10111111) .|. (r .&. 0b01000000))
    | otherwise = return flags

{-# INLINE encodeShortHeaderFlags #-}
{- FOURMOLU_DISABLE -}
-- | Build the flags byte of a short header packet.
--
-- Arguments, in order:
--
-- * general flags, of which only bits 5..2 are used;
-- * packet-number-length flags, of which only bits 1..0 are used;
-- * whether bit 6 may be randomized (see 'randomizeQuicBit');
-- * the key phase, which sets bit 2 when 'True'.
--
-- Bit 7 is always clear and bit 6 is set before the optional
-- randomization.
--
-- NOTE(review): the mask applied to the general flags (@0b00111100@)
-- also covers bit 2, so a caller-supplied bit 2 is kept even when the
-- key phase argument is 'False' -- confirm this is intended.
encodeShortHeaderFlags
    :: Flags Raw -> Flags Raw -> Bool -> Bool -> IO (Flags Raw)
encodeShortHeaderFlags (Flags fg) (Flags pp) quicBit keyPhase =
    Flags <$> randomizeQuicBit flags quicBit
  where
    flags =
                        0b01000000
            .|. (fg .&. 0b00111100)
            .|. (pp .&. 0b00000011)
            .|. (if keyPhase then 0b00000100 else 0b00000000)
{- FOURMOLU_ENABLE -}

{-# INLINE encodeLongHeaderFlags #-}
-- | Build the flags byte of a long header packet.
--
-- The high nibble comes from 'longHeaderPacketType' for the given
-- version and packet type. Of the first 'Flags' argument only bits 3..2
-- are used, and of the second (packet number length) only bits 1..0.
-- The final 'Bool' is forwarded to 'randomizeQuicBit', which may
-- randomize bit 6 of the result.
encodeLongHeaderFlags
    :: Version
    -> LongHeaderPacketType
    -> Flags Raw
    -> Flags Raw
    -> Bool
    -> IO (Flags Raw)
encodeLongHeaderFlags ver typ (Flags fg) (Flags pp) quicBit =
    Flags <$> randomizeQuicBit flags quicBit
  where
    Flags tp = longHeaderPacketType ver typ
    flags =
        tp
            .|. (fg .&. 0b00001100)
            .|. (pp .&. 0b00000011)

{-# INLINE longHeaderPacketType #-}
{- FOURMOLU_DISABLE -}
-- | Base flags byte for a long header packet of the given type.
--
-- Bits 7 and 6 are always set; bits 5..4 carry the packet type code and
-- the low nibble is zero. 'Version2' uses a different type-code
-- assignment from every other version.
longHeaderPacketType :: Version -> LongHeaderPacketType -> Flags Raw
longHeaderPacketType Version2 InitialPacketType   = Flags 0b11010000
longHeaderPacketType Version2 RTT0PacketType      = Flags 0b11100000
longHeaderPacketType Version2 HandshakePacketType = Flags 0b11110000
longHeaderPacketType Version2 RetryPacketType     = Flags 0b11000000
longHeaderPacketType _        InitialPacketType   = Flags 0b11000000
longHeaderPacketType _        RTT0PacketType      = Flags 0b11010000
longHeaderPacketType _        HandshakePacketType = Flags 0b11100000
longHeaderPacketType _        RetryPacketType     = Flags 0b11110000
{- FOURMOLU_ENABLE -}

-- | Flags byte for a Retry packet of the given version.
--
-- The high nibble is the Retry type code from 'longHeaderPacketType'
-- (@0b1100@ for 'Version2', @0b1111@ otherwise); the low nibble is
-- filled with random bits from 'getRandomOneByte'.
retryPacketType :: Version -> IO (Flags Raw)
retryPacketType ver = do
    r <- getRandomOneByte
    -- Reuse the shared table so the type code is defined in one place.
    let Flags typeBits = longHeaderPacketType ver RetryPacketType
    return $ Flags (typeBits .|. (r .&. 0b00001111))

-- | Flags byte for a Version Negotiation packet.
--
-- Bit 7 is set (long header form); the remaining seven bits are random
-- bits taken from 'getRandomOneByte'.
versionNegotiationPacketType :: IO (Flags Raw)
versionNegotiationPacketType = do
    r <- getRandomOneByte
    let flags = 0b10000000 .|. (r .&. 0b01111111)
    return $ Flags flags

{-# INLINE decodeLongHeaderPacketType #-}
{- FOURMOLU_DISABLE -}
-- | Recover the long header packet type from the flags byte.
--
-- Only bits 5..4 are inspected, so the wildcard branch in each equation
-- matches exactly the one remaining code: @0b00@ for 'Version2' and
-- @0b11@ for every other version, both meaning 'RetryPacketType'.
-- This is the inverse of the type-code table in 'longHeaderPacketType'.
decodeLongHeaderPacketType :: Version -> Flags Protected -> LongHeaderPacketType
decodeLongHeaderPacketType Version2 (Flags flags) = case flags .&. 0b00110000 of
    0b00010000 -> InitialPacketType
    0b00100000 -> RTT0PacketType
    0b00110000 -> HandshakePacketType
    _          -> RetryPacketType
decodeLongHeaderPacketType _ (Flags flags) = case flags .&. 0b00110000 of
    0b00000000 -> InitialPacketType
    0b00010000 -> RTT0PacketType
    0b00100000 -> HandshakePacketType
    _          -> RetryPacketType
{- FOURMOLU_ENABLE -}

{-# INLINE encodePktNumLength #-}
-- | Encode a packet number length as flags: the stored value is the
-- length minus one.
--
-- No range check or masking is done here; 'encodeShortHeaderFlags' and
-- 'encodeLongHeaderFlags' keep only the low two bits of the result.
--
-- NOTE(review): presumably the argument is always in 1..4 -- confirm
-- against the callers.
encodePktNumLength :: Int -> Flags Raw
encodePktNumLength epnLen = Flags $ fromIntegral (epnLen - 1)

{-# INLINE decodePktNumLength #-}
-- | Decode the packet number length from unprotected flags: the low two
-- bits plus one, so the result is always in 1..4.
decodePktNumLength :: Flags Raw -> Int
decodePktNumLength (Flags flags) = fromIntegral (flags .&. 0b11) + 1