{-# LANGUAGE MultiParamTypeClasses #-}

module Network.TLS.Record.State (
    CryptState (..),
    CryptLevel (..),
    HasCryptLevel (..),
    MacState (..),
    RecordOptions (..),
    RecordState (..),
    newRecordState,
    incrRecordState,
    RecordM,
    runRecordM,
    getRecordOptions,
    getRecordVersion,
    setRecordIV,
    withCompression,
    computeDigest,
    makeDigest,
    getBulk,
    getMacSequence,
) where

import Control.Monad.State.Strict
import qualified Data.ByteString as B

import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.ErrT
import Network.TLS.Imports
import Network.TLS.MAC
import Network.TLS.Packet
import Network.TLS.Struct
import Network.TLS.Types
import Network.TLS.Wire

-- | Cryptographic material held by a 'RecordState' (see 'stCryptState').
data CryptState = CryptState
    { cstKey :: BulkState
    -- ^ Bulk-cipher state; 'BulkStateUninitialized' in 'newRecordState'.
    , cstIV :: ByteString
    -- ^ Current IV; replaced wholesale by 'setRecordIV'.
    , cstMacSecret :: ByteString
    -- ^ In TLS 1.2 or earlier, this holds the MAC secret.
    -- In TLS 1.3, this holds application traffic secret N.
    }
    deriving (Show)

-- | Sequence-number state of the record layer.
newtype MacState = MacState
    { msSequence :: Word64
    -- ^ Sequence number; encoded into the MAC input by 'computeDigest'
    -- and incremented by 'incrRecordState'.
    }
    deriving (Show)

-- | Options a 'RecordM' computation can read (via 'getRecordOptions')
-- but not modify.
data RecordOptions = RecordOptions
    { recordVersion :: Version
    -- ^ Version to use when sending/receiving.
    , recordTLS13 :: Bool
    -- ^ Whether TLS 1.3 record processing is in effect.
    }

-- | TLS encryption level.
data CryptLevel
    = -- | Unprotected traffic
      CryptInitial
    | -- | Protected with main secret (TLS < 1.3)
      CryptMainSecret
    | -- | Protected with early traffic secret (TLS 1.3)
      CryptEarlySecret
    | -- | Protected with handshake traffic secret (TLS 1.3)
      CryptHandshakeSecret
    | -- | Protected with application traffic secret (TLS 1.3)
      CryptApplicationSecret
    deriving (Eq, Show)

-- | Types that determine a 'CryptLevel'. The argument is only a proxy
-- carrying the type; its value is not needed.
class HasCryptLevel a where
    getCryptLevel :: proxy a -> CryptLevel
-- | 'EarlySecret' corresponds to 'CryptEarlySecret'.
instance HasCryptLevel EarlySecret where
    getCryptLevel _ = CryptEarlySecret
-- | 'HandshakeSecret' corresponds to 'CryptHandshakeSecret'.
instance HasCryptLevel HandshakeSecret where
    getCryptLevel _ = CryptHandshakeSecret
-- | 'ApplicationSecret' corresponds to 'CryptApplicationSecret'.
instance HasCryptLevel ApplicationSecret where
    getCryptLevel _ = CryptApplicationSecret

-- | State threaded through every 'RecordM' computation.
data RecordState = RecordState
    { stCipher :: Maybe Cipher
    -- ^ Cipher supplying the hash ('computeDigest') and bulk algorithm
    -- ('getBulk'); 'Nothing' in 'newRecordState'.
    , stCompression :: Compression
    -- ^ Compression state, updated through 'withCompression'.
    , stCryptLevel :: CryptLevel
    -- ^ Current encryption level.
    , stCryptState :: CryptState
    -- ^ Bulk state, IV and MAC/traffic secret.
    , stMacState :: MacState
    -- ^ Record sequence number.
    }
    deriving (Show)

-- | Record-layer computation: reads 'RecordOptions', threads a
-- 'RecordState', and may fail with a 'TLSError' (in which case no
-- resulting state is produced).
newtype RecordM a = RecordM
    { runRecordM
        :: RecordOptions
        -> RecordState
        -> Either TLSError (a, RecordState)
    }

-- | 'pure' returns its argument with the state unchanged and cannot fail;
-- '<*>' is derived from the 'Monad' instance.
instance Applicative RecordM where
    pure a = RecordM $ \_ st -> Right (a, st)
    (<*>) = ap

-- | Runs the first action, then feeds its result and updated state to the
-- continuation under the same options. A 'Left' short-circuits.
instance Monad RecordM where
    m1 >>= m2 = RecordM $ \opt st ->
        case runRecordM m1 opt st of
            Left err -> Left err
            Right (a, st2) -> runRecordM (m2 a) opt st2

-- | Maps over the result; errors and the resulting state pass through.
instance Functor RecordM where
    fmap f m = RecordM $ \opt st ->
        case runRecordM m opt st of
            Left err -> Left err
            Right (a, st2) -> Right (f a, st2)

-- | Return the 'RecordOptions' the computation is running with; the state
-- is left unchanged.
getRecordOptions :: RecordM RecordOptions
getRecordOptions = RecordM $ \opt st -> Right (opt, st)

-- | The 'recordVersion' field of the current 'RecordOptions'.
getRecordVersion :: RecordM Version
getRecordVersion = recordVersion <$> getRecordOptions

-- | Direct access to the threaded 'RecordState'. None of these operations
-- can fail; 'put' discards the previous state entirely.
instance MonadState RecordState RecordM where
    put x = RecordM $ \_ _ -> Right ((), x)
    get = RecordM $ \_ st -> Right (st, st)
    state f = RecordM $ \_ st -> Right (f st)

-- | 'throwError' aborts with no resulting state.
--
-- 'catchError' rolls the state back: on failure the handler runs from
-- the state the protected action /started/ with, so any state changes
-- that action made before failing are discarded. Successful results are
-- returned unchanged.
instance MonadError TLSError RecordM where
    throwError e = RecordM $ \_ _ -> Left e
    catchError m f = RecordM $ \opt st ->
        case runRecordM m opt st of
            Left err -> runRecordM (f err) opt st
            r -> r

-- | Initial record state: no cipher, 'nullCompression', 'CryptInitial',
-- an uninitialized bulk state with empty IV and MAC secret, and sequence
-- number 0.
newRecordState :: RecordState
newRecordState =
    RecordState
        { stCipher = Nothing
        , stCompression = nullCompression
        , stCryptLevel = CryptInitial
        , stCryptState =
            CryptState
                { cstKey = BulkStateUninitialized
                , cstIV = B.empty
                , cstMacSecret = B.empty
                }
        , stMacState = MacState 0
        }

-- | Advance the sequence number in 'stMacState' by one; all other fields
-- are kept.
--
-- The new counter is forced together with the returned record, so that
-- repeated increments do not accumulate a chain of @(+ 1)@ thunks in the
-- lazy 'stMacState' field.
--
-- NOTE(review): plain 'Word64' addition wraps to 0 at 'maxBound', and
-- nothing here prevents that. TLS forbids sequence-number wrap
-- (RFC 5246 §6.1), so confirm that callers rekey or close first.
incrRecordState :: RecordState -> RecordState
incrRecordState ts = next `seq` ts{stMacState = MacState next}
  where
    next = msSequence (stMacState ts) + 1

-- | Replace the IV in the record's 'CryptState', leaving the bulk state
-- and MAC secret untouched.
setRecordIV :: ByteString -> RecordState -> RecordState
setRecordIV iv st = st{stCryptState = cst{cstIV = iv}}
  where
    cst = stCryptState st

-- | Run a pure step against the current 'Compression'. The step's new
-- compression value is stored back into the state, and its result is
-- returned.
withCompression :: (Compression -> (Compression, a)) -> RecordM a
withCompression f = state $ \st ->
    let (nc, a) = f (stCompression st)
     in (a, st{stCompression = nc})

-- | Compute a record MAC and the successor state.
--
-- The digest is @hmac (cipherHash cipher)@ applied to 'cstMacSecret' and
-- to the concatenation of the encoded sequence number ('encodeWord64'),
-- the encoded header ('encodeHeader') and the content. The returned state
-- is the input state after 'incrRecordState'.
--
-- * The 'Version' argument is currently ignored.
-- * If 'stCipher' is 'Nothing', forcing the digest raises an 'error'
--   (the state component is still usable).
computeDigest
    :: Version -> RecordState -> Header -> ByteString -> (ByteString, RecordState)
computeDigest _ver tstate hdr content = (digest, incrRecordState tstate)
  where
    -- Evaluated lazily, only when the digest itself is demanded.
    cipher = case stCipher tstate of
        Just c -> c
        Nothing -> error "computeDigest: no cipher in record state"
    macSecret = cstMacSecret (stCryptState tstate)
    seqNum = msSequence (stMacState tstate)
    msg = B.concat [encodeWord64 seqNum, encodeHeader hdr, content]
    digest = hmac (cipherHash cipher) macSecret msg

-- | Compute the MAC of a record with the current state and store the
-- successor state (sequence number advanced) produced by 'computeDigest'.
makeDigest :: Header -> ByteString -> RecordM ByteString
makeDigest hdr content = do
    ver <- getRecordVersion
    -- computeDigest already has the @s -> (a, s)@ shape 'state' expects.
    state $ \st -> computeDigest ver st hdr content

-- | The bulk algorithm of the current cipher.
--
-- If no cipher is set, the returned value raises an 'error' when forced.
-- The action itself does not fail.
getBulk :: RecordM Bulk
getBulk = bulkOf <$> get
  where
    bulkOf st = case stCipher st of
        Just cipher -> cipherBulk cipher
        Nothing -> error "getBulk: no cipher in record state"

-- | The current record sequence number, without modifying the state.
getMacSequence :: RecordM Word64
getMacSequence = gets (msSequence . stMacState)