{-# LANGUAGE OverloadedStrings #-}
module Network.TLS.QUIC (
tlsQUICClient
, tlsQUICServer
, QUICCallbacks(..)
, CryptLevel(..)
, KeyScheduleEvent(..)
, EarlySecretInfo(..)
, HandshakeSecretInfo(..)
, ApplicationSecretInfo(..)
, EarlySecret
, HandshakeSecret
, ApplicationSecret
, TrafficSecrets
, ServerTrafficSecret(..)
, ClientTrafficSecret(..)
, NegotiatedProtocol
, HandshakeMode13(..)
, ExtensionRaw(..)
, ExtensionID
, extensionID_QuicTransportParameters
, errorTLS
, errorToAlertDescription
, errorToAlertMessage
, fromAlertDescription
, toAlertDescription
, hkdfExpandLabel
, hkdfExtract
, hashDigestSize
, quicMaxEarlyDataSize
, defaultSupported
) where
import Network.TLS.Backend
import Network.TLS.Context
import Network.TLS.Context.Internal
import Network.TLS.Core
import Network.TLS.Crypto (hashDigestSize)
import Network.TLS.Crypto.Types
import Network.TLS.Extension (extensionID_QuicTransportParameters)
import Network.TLS.Extra.Cipher
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Control
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.Imports
import Network.TLS.KeySchedule (hkdfExtract, hkdfExpandLabel)
import Network.TLS.Parameters
import Network.TLS.Record.Layer
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Types
import Data.Default.Class
-- | A 'Backend' whose operations do nothing.
--
-- Flushing, closing and sending are no-ops, and receiving always yields an
-- empty 'ByteString' regardless of the requested size.  It is the backend
-- handed to 'contextNew' by 'tlsQUICClient' and 'tlsQUICServer', which then
-- install a record layer driven by 'QUICCallbacks' instead.
nullBackend :: Backend
nullBackend = Backend
    { backendFlush = return ()
    , backendClose = return ()
    , backendSend  = \_ -> return ()
    , backendRecv  = \_ -> return ""   -- relies on OverloadedStrings
    }
-- | Argument given to 'quicInstallKeys' telling which secrets became
-- available.
data KeyScheduleEvent
    = InstallEarlyKeys (Maybe EarlySecretInfo)
      -- ^ Early secrets, when there are any.  Emitted by the client when
      -- it sends its ClientHello and by the server when it sends its
      -- ServerHello.
    | InstallHandshakeKeys HandshakeSecretInfo
      -- ^ Handshake secrets.  Emitted by the client once the ServerHello
      -- has been received and by the server when it sends its ServerHello.
    | InstallApplicationKeys ApplicationSecretInfo
      -- ^ Application secrets.  Emitted when the local side sends its
      -- Finished message.
-- | Callbacks supplied by the QUIC implementation and invoked by the TLS
-- handshake run by 'tlsQUICClient' and 'tlsQUICServer'.
data QUICCallbacks = QUICCallbacks
    { quicSend :: [(CryptLevel, ByteString)] -> IO ()
      -- ^ Used as the send operation of the record layer: receives the
      -- outgoing handshake bytes, each chunk tagged with its 'CryptLevel'.
    , quicRecv :: CryptLevel -> IO (Either TLSError ByteString)
      -- ^ Used as the receive operation of the record layer: called with
      -- the current receive 'CryptLevel' and returns either incoming
      -- handshake bytes or an error.
    , quicInstallKeys :: Context -> KeyScheduleEvent -> IO ()
      -- ^ Called whenever new secrets become available; see
      -- 'KeyScheduleEvent'.
    , quicNotifyExtensions :: Context -> [ExtensionRaw] -> IO ()
      -- ^ Called with the QUIC transport parameter extensions selected
      -- from the peer's extensions.  The list is never empty: the
      -- handshake is aborted instead when none is found.
    , quicDone :: Context -> IO ()
      -- ^ Called once, after 'handshake' has returned.
    }
-- | Fetch the 'CryptLevel' component of the context's transmit state, as
-- reported by 'getTxState'.
getTxLevel :: Context -> IO CryptLevel
getTxLevel ctx = do
    (_hash, _cipher, level, _secret) <- getTxState ctx
    return level
-- | Fetch the 'CryptLevel' component of the context's receive state, as
-- reported by 'getRxState'.
getRxLevel :: Context -> IO CryptLevel
getRxLevel ctx = do
    (_hash, _cipher, level, _secret) <- getRxState ctx
    return level
-- | Build the transparent record layer used in QUIC mode.
--
-- Outgoing data is annotated with the current transmit level and handed to
-- 'quicSend'; incoming data is obtained by asking 'quicRecv' for bytes at
-- the current receive level.
newRecordLayer :: Context -> QUICCallbacks
               -> RecordLayer [(CryptLevel, ByteString)]
newRecordLayer ctx callbacks =
    newTransparentRecordLayer currentTxLevel send recv
  where
    currentTxLevel :: IO CryptLevel
    currentTxLevel = getTxLevel ctx

    send :: [(CryptLevel, ByteString)] -> IO ()
    send = quicSend callbacks

    -- The receive level is re-read on every call.
    recv :: IO (Either TLSError ByteString)
    recv = getRxLevel ctx >>= quicRecv callbacks
-- | Run the client side of a TLS handshake on behalf of QUIC.
--
-- A 'Context' is created over 'nullBackend', switched to QUIC mode with
-- fragmentation disabled, and given a record layer that talks to QUIC
-- through the supplied callbacks.  After 'handshake' returns, 'quicDone' is
-- invoked and then 'recvData' is called once with its result discarded.
--
-- Throws 'Error_Protocol' with 'MissingExtension' (via 'throwCore') when the
-- extensions available at client Finished contain no QUIC transport
-- parameters.
tlsQUICClient :: ClientParams -> QUICCallbacks -> IO ()
tlsQUICClient cparams callbacks = do
    ctx0 <- contextNew nullBackend cparams
    let ctx1 = ctx0
            { -- Only the client-side hook does work; the server-side one
              -- is a no-op here.
              ctxHandshakeSync = HandshakeSync sync (\_ _ -> return ())
            , ctxFragmentSize  = Nothing
            , ctxQUICMode      = True
            }
        -- 'rl' and 'ctx2' are deliberately mutually recursive: the record
        -- layer needs the final context to read its levels, and the final
        -- context embeds that record layer.  Laziness ties the knot.
        rl   = newRecordLayer ctx2 callbacks
        ctx2 = updateRecordLayer rl ctx1
    handshake ctx2
    quicDone callbacks ctx2
    -- NOTE(review): presumably this keeps processing post-handshake
    -- messages (e.g. a new session ticket) -- confirm.
    void $ recvData ctx2
  where
    sync :: Context -> ClientState -> IO ()
    sync ctx (SendClientHello mEarlySecInfo) =
        quicInstallKeys callbacks ctx (InstallEarlyKeys mEarlySecInfo)
    sync ctx (RecvServerHello handSecInfo) =
        quicInstallKeys callbacks ctx (InstallHandshakeKeys handSecInfo)
    sync ctx (SendClientFinished exts appSecInfo) = do
        let qexts = filterQTP exts
        when (null qexts) $
            throwCore $ Error_Protocol
                ( "QUIC transport parameters are missing"
                , True
                , MissingExtension
                )
        -- Extensions are reported before the application keys are
        -- installed.
        quicNotifyExtensions callbacks ctx qexts
        quicInstallKeys callbacks ctx (InstallApplicationKeys appSecInfo)
-- | Run the server side of a TLS handshake on behalf of QUIC.
--
-- A 'Context' is created over 'nullBackend', switched to QUIC mode with
-- fragmentation disabled, and given a record layer that talks to QUIC
-- through the supplied callbacks.  After 'handshake' returns, 'quicDone' is
-- invoked.
--
-- Throws 'Error_Protocol' with 'MissingExtension' (via 'throwCore') when the
-- extensions available at ServerHello contain no QUIC transport parameters.
tlsQUICServer :: ServerParams -> QUICCallbacks -> IO ()
tlsQUICServer sparams callbacks = do
    ctx0 <- contextNew nullBackend sparams
    let ctx1 = ctx0
            { -- Only the server-side hook does work; the client-side one
              -- is a no-op here.
              ctxHandshakeSync = HandshakeSync (\_ _ -> return ()) sync
            , ctxFragmentSize  = Nothing
            , ctxQUICMode      = True
            }
        -- 'rl' and 'ctx2' are deliberately mutually recursive: the record
        -- layer needs the final context to read its levels, and the final
        -- context embeds that record layer.  Laziness ties the knot.
        rl   = newRecordLayer ctx2 callbacks
        ctx2 = updateRecordLayer rl ctx1
    handshake ctx2
    quicDone callbacks ctx2
  where
    sync :: Context -> ServerState -> IO ()
    sync ctx (SendServerHello exts mEarlySecInfo handSecInfo) = do
        let qexts = filterQTP exts
        when (null qexts) $
            throwCore $ Error_Protocol
                ( "QUIC transport parameters are missing"
                , True
                , MissingExtension
                )
        -- Order matters to the callbacks: extensions first, then early
        -- keys, then handshake keys.
        quicNotifyExtensions callbacks ctx qexts
        quicInstallKeys callbacks ctx (InstallEarlyKeys mEarlySecInfo)
        quicInstallKeys callbacks ctx (InstallHandshakeKeys handSecInfo)
    sync ctx (SendServerFinished appSecInfo) =
        quicInstallKeys callbacks ctx (InstallApplicationKeys appSecInfo)
-- | Keep only the extensions carrying QUIC transport parameters.
--
-- An extension is kept when its identifier is either
-- 'extensionID_QuicTransportParameters' or the literal @0xffa5@.
filterQTP :: [ExtensionRaw] -> [ExtensionRaw]
filterQTP = filter isQuicTransportParameters
  where
    isQuicTransportParameters (ExtensionRaw eid _) =
        eid == extensionID_QuicTransportParameters
            -- NOTE(review): 0xffa5 looks like the pre-standard (draft)
            -- codepoint for the same extension -- confirm, and drop it
            -- once draft peers no longer need to be supported.
            || eid == 0xffa5
-- | Abort with a TLS protocol error carrying the given message.
--
-- The error is raised through 'throwCore' as an 'Error_Protocol' whose
-- alert description is 'InternalError' and whose 'Bool' component is
-- 'True' (presumably the \"fatal\" flag -- confirm against 'TLSError').
errorTLS :: String -> IO a
errorTLS msg = throwCore (Error_Protocol (msg, True, InternalError))
-- | Map a 'TLSError' to the 'AlertDescription' of the first alert that
-- 'errorToAlert' produces for it.
--
-- Should 'errorToAlert' ever return no alert at all, 'InternalError' is
-- returned instead of crashing on an empty list.
errorToAlertDescription :: TLSError -> AlertDescription
errorToAlertDescription err =
    case errorToAlert err of
        (_level, desc) : _ -> desc
        []                 -> InternalError
-- | Encode an 'AlertDescription' as its 'Word8' value, using 'valOfType'.
fromAlertDescription :: AlertDescription -> Word8
fromAlertDescription = valOfType
-- | Decode a 'Word8' into an 'AlertDescription', using 'valToType'.
-- The result is 'Nothing' when 'valToType' does not recognise the value.
toAlertDescription :: Word8 -> Maybe AlertDescription
toAlertDescription = valToType
-- | Default 'Supported' settings for QUIC.
--
-- Starts from 'def' and overrides three fields: the only version is
-- 'TLS13', the ciphers are the three TLS 1.3 suites listed below, and the
-- groups are X25519, X448, P256, P384 and P521, in that order.
defaultSupported :: Supported
defaultSupported = def
    { supportedVersions = [TLS13]
    , supportedCiphers  =
        [ cipher_TLS13_AES256GCM_SHA384
        , cipher_TLS13_AES128GCM_SHA256
        , cipher_TLS13_AES128CCM_SHA256
        ]
    , supportedGroups   = [X25519, X448, P256, P384, P521]
    }
-- | Maximum early data size to advertise when TLS is used for QUIC:
-- @0xffffffff@ (4294967295).
--
-- NOTE(review): this appears to be the sentinel value that RFC 9001
-- section 4.6.1 requires for @max_early_data_size@ -- confirm.  The
-- literal only fits in 'Int' where 'Int' is wider than 32 bits.
quicMaxEarlyDataSize :: Int
quicMaxEarlyDataSize = 0xffffffff