{-# LANGUAGE OverloadedStrings #-}

-- |
-- API to run the TLS handshake establishing a QUIC connection.
--
-- On the northbound API:
--
-- * QUIC starts a TLS client or server thread with 'tlsQUICClient' or
--   'tlsQUICServer'.
--
-- In the other direction, TLS invokes QUIC callbacks to use the QUIC transport:
--
-- * TLS uses 'quicSend' and 'quicRecv' to send and receive handshake message
--   fragments.
--
-- * TLS calls 'quicInstallKeys' to provide to QUIC the traffic secrets it
--   should use for encryption/decryption.
--
-- * TLS calls 'quicNotifyExtensions' to notify QUIC of the transport parameters
--   exchanged through the handshake protocol.
--
-- * TLS calls 'quicDone' when the handshake is done.
module Network.TLS.QUIC (
    -- * Handshakers
    tlsQUICClient,
    tlsQUICServer,

    -- * Callback
    QUICCallbacks (..),
    CryptLevel (..),
    KeyScheduleEvent (..),

    -- * Secrets
    EarlySecretInfo (..),
    HandshakeSecretInfo (..),
    ApplicationSecretInfo (..),
    EarlySecret,
    HandshakeSecret,
    ApplicationSecret,
    TrafficSecrets,
    ServerTrafficSecret (..),
    ClientTrafficSecret (..),

    -- * Negotiated parameters
    NegotiatedProtocol,
    HandshakeMode13 (..),

    -- * Extensions
    ExtensionRaw (..),
    ExtensionID (ExtensionID, EID_QuicTransportParameters),

    -- * Errors
    errorTLS,
    errorToAlertDescription,
    errorToAlertMessage,
    fromAlertDescription,
    toAlertDescription,

    -- * Hash
    hkdfExpandLabel,
    hkdfExtract,
    hashDigestSize,

    -- * Constants
    quicMaxEarlyDataSize,

    -- * Supported
    defaultSupported,
) where

import Network.TLS.Backend
import Network.TLS.Context
import Network.TLS.Context.Internal
import Network.TLS.Core
import Network.TLS.Crypto (hashDigestSize)
import Network.TLS.Crypto.Types
import Network.TLS.Extra.Cipher
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Control
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.Imports
import Network.TLS.KeySchedule (hkdfExpandLabel, hkdfExtract)
import Network.TLS.Parameters
import Network.TLS.Record.Layer
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Types

import Data.Default.Class

-- | A 'Backend' that performs no I/O.
--
-- Flushing, closing and sending are no-ops, and receiving always yields an
-- empty 'ByteString'.  It is the backend handed to 'contextNew' by
-- 'tlsQUICClient' and 'tlsQUICServer'; the record layer of those contexts is
-- subsequently replaced by one built from the QUIC callbacks.
nullBackend :: Backend
nullBackend =
    Backend
        { backendFlush = return ()
        , backendClose = return ()
        , backendSend = \_ -> return ()
        , backendRecv = \_ -> return ""
        }

-- | Argument given to 'quicInstallKeys' when encryption material is available.
--
-- Each constructor corresponds to one encryption level and carries the secret
-- information that TLS hands over to QUIC for that level.
data KeyScheduleEvent
    = -- | Key material and parameters for traffic at 0-RTT level.
      -- NOTE(review): the payload is optional; presumably 'Nothing' means
      -- that no 0-RTT material is available -- confirm against the handshake
      -- code.
      InstallEarlyKeys (Maybe EarlySecretInfo)
    | -- | Key material and parameters for traffic at handshake level.
      InstallHandshakeKeys HandshakeSecretInfo
    | -- | Key material and parameters for traffic at application level.
      InstallApplicationKeys ApplicationSecretInfo

-- | Callbacks implemented by QUIC and to be called by TLS at specific points
-- during the handshake.  TLS may invoke them from external threads but calls
-- are not concurrent.  Only a single callback function is called at a given
-- point in time.
data QUICCallbacks = QUICCallbacks
    { quicSend :: [(CryptLevel, ByteString)] -> IO ()
    -- ^ Called by TLS so that QUIC sends one or more handshake fragments. The
    -- content transiting on this API is the plaintext of the fragments and
    -- QUIC's responsibility is to encrypt this payload with the key material
    -- given for the specified level and an appropriate encryption scheme.
    --
    -- The size of the fragments may exceed QUIC datagram limits so QUIC may
    -- break them into smaller fragments.
    --
    -- The handshake protocol sometimes combines content at two levels in a
    -- single flight.  The TLS library does its best to provide this in the
    -- same @quicSend@ call and with a multi-valued argument.  QUIC can then
    -- decide how to transmit this optimally.
    , quicRecv :: CryptLevel -> IO (Either TLSError ByteString)
    -- ^ Called by TLS to receive from QUIC the next plaintext handshake
    -- fragment.  The argument specifies with which encryption level the
    -- fragment should be decrypted.
    --
    -- QUIC may return partial fragments to TLS.  TLS will then call
    -- @quicRecv@ again as long as necessary.  Note however that fragments
    -- must be returned in the correct sequence, i.e. the order the TLS peer
    -- emitted them.
    --
    -- The function may return an error to TLS if end of stream is reached or
    -- if a protocol error has been received, believing the handshake cannot
    -- proceed any longer.  If the TLS handshake protocol cannot recover from
    -- this error, the failure condition will be reported back to QUIC through
    -- the control interface.
    , quicInstallKeys :: Context -> KeyScheduleEvent -> IO ()
    -- ^ Called by TLS when new encryption material is ready to be used in the
    -- handshake.  The next 'quicSend' or 'quicRecv' may now use the
    -- associated encryption level (although the previous level is also
    -- possible: directions Send/Recv do not change at the same time).
    , quicNotifyExtensions :: Context -> [ExtensionRaw] -> IO ()
    -- ^ Called by TLS when QUIC-specific extensions have been received from
    -- the peer.
    , quicDone :: Context -> IO ()
    -- ^ Called when 'handshake' is done. 'tlsQUICServer' is
    -- finished after calling this hook. 'tlsQUICClient' calls
    -- 'recvData' after calling this hook to wait for new session
    -- tickets.
    }

-- | Build the record layer used in QUIC mode from the QUIC callbacks.
--
-- The result is a transparent record layer ('newTransparentRecordLayer')
-- whose output is a list of handshake fragments, each tagged with the
-- 'CryptLevel' it belongs to.
newRecordLayer
    :: QUICCallbacks
    -> RecordLayer [(CryptLevel, ByteString)]
newRecordLayer callbacks = newTransparentRecordLayer get send recv
  where
    -- Level used to tag outgoing fragments: the current transmit level.
    get :: Context -> IO CryptLevel
    get = getTxLevel

    -- Outgoing fragments are handed to QUIC unchanged.
    send :: [(CryptLevel, ByteString)] -> IO ()
    send = quicSend callbacks

    -- Incoming fragments are requested from QUIC at the current receive
    -- level of the context.
    recv :: Context -> IO (Either TLSError ByteString)
    recv ctx = getRxLevel ctx >>= quicRecv callbacks

-- | Run the TLS handshake for a QUIC client; QUIC is expected to run this
-- action in a dedicated thread.  The client will use the specified TLS
-- parameters and call the provided callback functions to send and receive
-- handshake data.
--
-- Once 'handshake' returns, 'quicDone' is invoked and the action then blocks
-- in 'recvData' to wait for new session tickets.
--
-- A protocol error with the @MissingExtension@ alert is thrown if the
-- extensions reported when the client Finished is sent contain no QUIC
-- transport parameters.
tlsQUICClient :: ClientParams -> QUICCallbacks -> IO ()
tlsQUICClient cparams callbacks = do
    -- The backend performs no I/O: all traffic goes through the record
    -- layer installed below.
    ctx0 <- contextNew nullBackend cparams
    let ctx1 =
            ctx0
                { -- Only the client-side hook is used; the server-side one
                  -- is a no-op.
                  ctxHandshakeSync = HandshakeSync sync (\_ _ -> return ())
                , ctxFragmentSize = Nothing
                , ctxQUICMode = True
                }
        rl = newRecordLayer callbacks
        ctx2 = updateRecordLayer rl ctx1
    handshake ctx2
    quicDone callbacks ctx2
    void $ recvData ctx2 -- waiting for new session tickets
  where
    -- Translate client handshake milestones into QUIC callbacks.
    sync :: Context -> ClientState -> IO ()
    sync ctx (SendClientHello mEarlySecInfo) =
        quicInstallKeys callbacks ctx (InstallEarlyKeys mEarlySecInfo)
    sync ctx (RecvServerHello handSecInfo) =
        quicInstallKeys callbacks ctx (InstallHandshakeKeys handSecInfo)
    sync ctx (SendClientFinished exts appSecInfo) = do
        let qexts = filterQTP exts
        -- The QUIC transport parameters extension is mandatory.
        when (null qexts) $
            throwCore $
                Error_Protocol "QUIC transport parameters are missing" MissingExtension
        quicNotifyExtensions callbacks ctx qexts
        quicInstallKeys callbacks ctx (InstallApplicationKeys appSecInfo)

-- | Run the TLS handshake for a QUIC server; QUIC is expected to run this
-- action in a dedicated thread.  The server will use the specified TLS
-- parameters and call the provided callback functions to send and receive
-- handshake data.
--
-- The action finishes after 'handshake' has returned and 'quicDone' has been
-- invoked.
--
-- A protocol error with the @MissingExtension@ alert is thrown if the
-- extensions reported when the server Hello is sent contain no QUIC
-- transport parameters.
tlsQUICServer :: ServerParams -> QUICCallbacks -> IO ()
tlsQUICServer sparams callbacks = do
    -- The backend performs no I/O: all traffic goes through the record
    -- layer installed below.
    ctx0 <- contextNew nullBackend sparams
    let ctx1 =
            ctx0
                { -- Only the server-side hook is used; the client-side one
                  -- is a no-op.
                  ctxHandshakeSync = HandshakeSync (\_ _ -> return ()) sync
                , ctxFragmentSize = Nothing
                , ctxQUICMode = True
                }
        rl = newRecordLayer callbacks
        ctx2 = updateRecordLayer rl ctx1
    handshake ctx2
    quicDone callbacks ctx2
  where
    -- Translate server handshake milestones into QUIC callbacks.
    sync :: Context -> ServerState -> IO ()
    sync ctx (SendServerHello exts mEarlySecInfo handSecInfo) = do
        let qexts = filterQTP exts
        -- The QUIC transport parameters extension is mandatory.
        when (null qexts) $
            throwCore $
                Error_Protocol "QUIC transport parameters are missing" MissingExtension
        quicNotifyExtensions callbacks ctx qexts
        -- Early keys are installed before handshake keys.
        quicInstallKeys callbacks ctx (InstallEarlyKeys mEarlySecInfo)
        quicInstallKeys callbacks ctx (InstallHandshakeKeys handSecInfo)
    sync ctx (SendServerFinished appSecInfo) =
        quicInstallKeys callbacks ctx (InstallApplicationKeys appSecInfo)

-- | Keep only the raw extensions whose identifier is
-- 'EID_QuicTransportParameters', preserving their order.
filterQTP :: [ExtensionRaw] -> [ExtensionRaw]
filterQTP = filter isQTP
  where
    isQTP :: ExtensionRaw -> Bool
    isQTP (ExtensionRaw eid _) = eid == EID_QuicTransportParameters

-- | Can be used by callbacks to signal an unexpected condition.  This will then
-- generate an "internal_error" alert in the TLS stack.
--
-- The message is wrapped in an 'Error_Protocol' carrying the 'InternalError'
-- alert description and thrown with 'throwCore', so the action never returns
-- normally.
errorTLS :: String -> IO a
errorTLS msg = throwCore $ Error_Protocol msg InternalError

-- | Return the alert that a TLS endpoint would send to the peer for the
-- specified library error.
--
-- This is the description component of the pair computed by 'errorToAlert';
-- the alert level is discarded.
errorToAlertDescription :: TLSError -> AlertDescription
errorToAlertDescription = snd . errorToAlert

-- | Decode an alert from the assigned value.
--
-- The byte is wrapped directly with the 'AlertDescription' constructor; no
-- validation of the value is performed here.
toAlertDescription :: Word8 -> AlertDescription
toAlertDescription = AlertDescription

-- | Default 'Supported' parameters for QUIC.
--
-- Starts from the library default ('def') and overrides the protocol
-- versions (TLS 1.3 only), the cipher list (TLS 1.3 AEAD ciphers) and the
-- key-exchange groups (elliptic-curve groups only).
defaultSupported :: Supported
defaultSupported =
    def
        { supportedVersions = [TLS13]
        , supportedCiphers =
            [ cipher_TLS13_AES256GCM_SHA384
            , cipher_TLS13_AES128GCM_SHA256
            , cipher_TLS13_AES128CCM_SHA256
            ]
        , supportedGroups = [X25519, X448, P256, P384, P521]
        }

-- | Max early data size for QUIC.
--
-- The value is @0xffffffff@ (2^32 - 1).
-- NOTE(review): the literal only fits when 'Int' is wider than 32 bits --
-- confirm the supported target platforms.
quicMaxEarlyDataSize :: Int
quicMaxEarlyDataSize = 0xffffffff