{-# LANGUAGE OverloadedStrings #-}
module Network.TLS.Handshake.State13
( CryptLevel ( CryptEarlySecret
, CryptHandshakeSecret
, CryptApplicationSecret
)
, TrafficSecret
, getTxState
, getRxState
, setTxState
, setRxState
, clearTxState
, clearRxState
, setHelloParameters13
, transcriptHash
, wrapAsMessageHash13
, PendingAction(..)
, setPendingActions
, popPendingAction
) where
import Control.Concurrent.MVar
import Control.Monad.State
import qualified Data.ByteString as B
import Data.IORef
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Handshake.State
import Network.TLS.KeySchedule (hkdfExpandLabel)
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Imports
import Network.TLS.Types
import Network.TLS.Util
-- | Read the record state stored in the context's 'ctxTxState' 'MVar' and
-- return its hash, cipher, crypt level and secret.  See 'getXState'.
getTxState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getTxState = (`getXState` ctxTxState)
-- | Read the record state stored in the context's 'ctxRxState' 'MVar' and
-- return its hash, cipher, crypt level and secret.  See 'getXState'.
getRxState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getRxState = (`getXState` ctxRxState)
-- | Read the record state selected by @func@ and return, in order: the hash
-- of the installed cipher, the cipher itself, the crypt level, and the secret
-- kept in the 'cstMacSecret' field of the crypt state.
--
-- The 'MVar' is only read ('readMVar'); it is not modified.
--
-- The components of the result are bound lazily.  When no cipher is installed
-- ('stCipher' is 'Nothing') the action itself still succeeds; the failure is
-- raised only if the hash or cipher component is later forced.  The labelled
-- 'fromJust' helper is used instead of an incomplete @let Just ...@ pattern so
-- that the failure identifies what was missing (presumably via the label
-- string -- confirm against the helper's definition).
getXState :: Context
          -> (Context -> MVar RecordState)
          -> IO (Hash, Cipher, CryptLevel, ByteString)
getXState ctx func = do
    tx <- readMVar (func ctx)
    let usedCipher = fromJust "cipher" (stCipher tx)
        usedHash   = cipherHash usedCipher
        level      = stCryptLevel tx
        secret     = cstMacSecret (stCryptState tx)
    return (usedHash, usedCipher, level, secret)
-- | Values from which a crypt level and the raw secret bytes can be
-- extracted.
class TrafficSecret ty where
    -- | Return the crypt level associated with the value together with the
    -- secret bytes it carries.
    fromTrafficSecret :: ty -> (CryptLevel, ByteString)
-- | The level comes from 'getCryptLevel' applied to the value itself; the
-- bytes are the payload of the 'AnyTrafficSecret' constructor.
instance HasCryptLevel a => TrafficSecret (AnyTrafficSecret a) where
    fromTrafficSecret ts@(AnyTrafficSecret bytes) = (getCryptLevel ts, bytes)
-- | The level comes from 'getCryptLevel' applied to the value itself; the
-- bytes are the payload of the 'ClientTrafficSecret' constructor.
instance HasCryptLevel a => TrafficSecret (ClientTrafficSecret a) where
    fromTrafficSecret ts@(ClientTrafficSecret bytes) = (getCryptLevel ts, bytes)
-- | The level comes from 'getCryptLevel' applied to the value itself; the
-- bytes are the payload of the 'ServerTrafficSecret' constructor.
instance HasCryptLevel a => TrafficSecret (ServerTrafficSecret a) where
    fromTrafficSecret ts@(ServerTrafficSecret bytes) = (getCryptLevel ts, bytes)
-- | Install a new record state in the context's 'ctxTxState' 'MVar', with
-- the bulk cipher initialised in the 'BulkEncrypt' direction.
-- See 'setXState'.
setTxState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setTxState ctx h cipher ts =
    setXState ctxTxState BulkEncrypt ctx h cipher ts
-- | Install a new record state in the context's 'ctxRxState' 'MVar', with
-- the bulk cipher initialised in the 'BulkDecrypt' direction.
-- See 'setXState'.
setRxState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setRxState ctx h cipher ts =
    setXState ctxRxState BulkDecrypt ctx h cipher ts
-- | Unpack a traffic secret into its crypt level and raw bytes with
-- 'fromTrafficSecret', then hand everything over to @setXState'@.
setXState :: TrafficSecret ty
          => (Context -> MVar RecordState) -> BulkDirection
          -> Context -> Hash -> Cipher -> ty
          -> IO ()
setXState func encOrDec ctx h cipher ts =
    setXState' func encOrDec ctx h cipher level rawSecret
  where
    (level, rawSecret) = fromTrafficSecret ts
-- | Overwrite the record state held in the 'MVar' selected by @func@ with a
-- fresh state derived from @secret@.
--
-- The previous contents of the 'MVar' are ignored.  The new state records
-- @lvl@ and @cipher@, uses 'nullCompression', starts its sequence number at 0
-- and keeps @secret@ itself in 'cstMacSecret'.  The bulk key and the IV are
-- both expanded from @secret@ with 'hkdfExpandLabel' under the labels
-- @"key"@ and @"iv"@.
setXState' :: (Context -> MVar RecordState) -> BulkDirection
           -> Context -> Hash -> Cipher -> CryptLevel -> ByteString
           -> IO ()
setXState' func encOrDec ctx h cipher lvl secret =
    modifyMVar_ (func ctx) (const (return freshState))
  where
    bulkAlg = cipherBulk cipher

    -- Expand @secret@ under the given label; the fourth argument of
    -- 'hkdfExpandLabel' is always empty here.
    derive label len = hkdfExpandLabel h secret label "" len

    trafficKey = derive "key" (bulkKeySize bulkAlg)

    -- The IV is never shorter than 8 bytes, whatever the cipher reports.
    -- NOTE(review): this looks like the TLS 1.3 per-record nonce length rule
    -- (RFC 8446, section 5.3) -- confirm.
    trafficIV = derive "iv" (max 8 (bulkIVSize bulkAlg + bulkExplicitIV bulkAlg))

    freshState = RecordState
        { stCipher      = Just cipher
        , stCompression = nullCompression
        , stCryptLevel  = lvl
        , stCryptState  = CryptState
            { cstKey       = bulkInit bulkAlg encOrDec trafficKey
            , cstIV        = trafficIV
            , cstMacSecret = secret
            }
        , stMacState    = MacState { msSequence = 0 }
        }
-- | Remove the cipher from the record state in the context's 'ctxTxState'
-- 'MVar'.  See 'clearXState'.
clearTxState :: Context -> IO ()
clearTxState ctx = clearXState ctxTxState ctx
-- | Remove the cipher from the record state in the context's 'ctxRxState'
-- 'MVar'.  See 'clearXState'.
clearRxState :: Context -> IO ()
clearRxState ctx = clearXState ctxRxState ctx
-- | Set 'stCipher' to 'Nothing' in the record state selected by @func@.
-- Every other field of the record state is left as it was.
clearXState :: (Context -> MVar RecordState) -> Context -> IO ()
clearXState func ctx = modifyMVar_ (func ctx) (return . withoutCipher)
  where
    withoutCipher st = st { stCipher = Nothing }
-- | Record the negotiated cipher in the handshake state.
--
-- * If no pending cipher is set yet, store @cipher@, set the pending
--   compression to 'nullCompression', and convert the handshake digest from a
--   list of buffered messages into a running hash context for the cipher's
--   hash.  Returns @Right ()@.
-- * If a pending cipher is already set and equals @cipher@, do nothing and
--   return @Right ()@.
-- * If a different pending cipher is already set, return an
--   'Error_Protocol' with 'IllegalParameter'; the state is not modified.
--
-- Calls 'error' (when the new digest is forced) if no pending cipher is set
-- but the digest is already a 'HandshakeDigestContext'.
setHelloParameters13 :: Cipher -> HandshakeM (Either TLSError ())
setHelloParameters13 cipher = do
    hst <- get
    case hstPendingCipher hst of
        Just oldcipher
            | cipher == oldcipher -> return accepted
            | otherwise           -> return (Left cipherChanged)
        Nothing -> do
            put (installCipher hst)
            return accepted
  where
    accepted = Right ()

    cipherChanged =
        Error_Protocol ( "TLS 1.3 cipher changed after hello retry"
                       , True
                       , IllegalParameter
                       )

    installCipher st = st
        { hstPendingCipher      = Just cipher
        , hstPendingCompression = nullCompression
        , hstHandshakeDigest    = startDigest (hstHandshakeDigest st)
        }

    -- The buffered messages are reversed before being hashed, so the last
    -- element of the stored list is fed to the hash first.
    startDigest (HandshakeDigestContext _) =
        error "cannot initialize digest with another digest"
    startDigest (HandshakeMessages bytes) =
        HandshakeDigestContext
            (foldl hashUpdate (hashInit (cipherHash cipher)) (reverse bytes))
-- | Pass 'foldHandshakeDigest' the pending cipher's hash and a function that
-- prefixes a digest with a four-byte header: the byte 254, two zero bytes,
-- and the digest length as a single byte.
--
-- NOTE(review): the header looks like the synthetic @message_hash@ handshake
-- message of RFC 8446, section 4.4.1 -- confirm.
wrapAsMessageHash13 :: HandshakeM ()
wrapAsMessageHash13 =
    getPendingCipher >>= \cipher ->
        foldHandshakeDigest (cipherHash cipher) addHeader
  where
    -- The length is narrowed to one byte with 'fromIntegral'.
    addHeader digest =
        B.append "\254\0\0" (B.cons (fromIntegral (B.length digest)) digest)
-- | Finalise the running handshake digest and return the resulting hash.
--
-- The handshake state is fetched with 'getHState' and unwrapped with the
-- labelled 'fromJust' (label @"HState"@).  Calls 'error' if the digest is
-- still a list of buffered messages ('HandshakeMessages') rather than a
-- hash context.
transcriptHash :: MonadIO m => Context -> m ByteString
transcriptHash ctx = do
    mhst <- getHState ctx
    let hst = fromJust "HState" mhst
    case hstHandshakeDigest hst of
        HandshakeMessages _            -> error "un-initialized handshake digest"
        HandshakeDigestContext hashCtx -> return (hashFinal hashCtx)
-- | Overwrite the list stored in the context's 'ctxPendingActions' 'IORef'.
setPendingActions :: Context -> [PendingAction] -> IO ()
setPendingActions = writeIORef . ctxPendingActions
-- | Remove and return the first element of the list stored in the context's
-- 'ctxPendingActions' 'IORef', or return 'Nothing' (leaving the reference
-- untouched) when the list is empty.
--
-- The read and the write-back are two separate 'IORef' operations.
popPendingAction :: Context -> IO (Maybe PendingAction)
popPendingAction ctx = readIORef pendingRef >>= pop
  where
    pendingRef = ctxPendingActions ctx

    pop []            = return Nothing
    pop (next : rest) = do
        writeIORef pendingRef rest
        return (Just next)