{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE CPP #-}
module Network.TLS.Record.State
( CryptState(..)
, CryptLevel(..)
, HasCryptLevel(..)
, MacState(..)
, RecordOptions(..)
, RecordState(..)
, newRecordState
, incrRecordState
, RecordM
, runRecordM
, getRecordOptions
, getRecordVersion
, setRecordIV
, withCompression
, computeDigest
, makeDigest
, getBulk
, getMacSequence
) where
import Control.Monad.State.Strict
import Network.TLS.Compression
import Network.TLS.Cipher
import Network.TLS.ErrT
import Network.TLS.Struct
import Network.TLS.Wire
import Network.TLS.Packet
import Network.TLS.MAC
import Network.TLS.Util
import Network.TLS.Imports
import Network.TLS.Types
import qualified Data.ByteString as B
-- | Cryptographic material held in a 'RecordState': bulk cipher state,
-- IV and MAC secret. 'newRecordState' starts with
-- 'BulkStateUninitialized' and empty byte strings.
data CryptState = CryptState
    { cstKey       :: !BulkState   -- ^ state of the bulk cipher
    , cstIV        :: !ByteString  -- ^ IV; replaced by 'setRecordIV'
    , cstMacSecret :: !ByteString  -- ^ key passed to the MAC in 'computeDigest'
    } deriving (Show)
-- | MAC bookkeeping: the sequence number that 'computeDigest' prepends to
-- every MAC input and that 'incrRecordState' bumps after each digest.
newtype MacState = MacState
    { msSequence :: Word64 -- ^ current record sequence number
    } deriving (Show)
-- | Read-only settings visible to every 'RecordM' computation
-- (see 'getRecordOptions').
data RecordOptions = RecordOptions
    { recordVersion :: Version -- ^ protocol version in effect; selects the MAC scheme in 'makeDigest'
    , recordTLS13   :: Bool    -- ^ TLS 1.3 flag; not consulted in this module
    }
-- | Level of key material currently installed in a 'RecordState'.
-- Constructor names mirror the TLS secret the keys presumably derive
-- from (confirm against the handshake code that sets 'stCryptLevel').
data CryptLevel
    = CryptInitial           -- ^ starting level used by 'newRecordState'
    | CryptMasterSecret
    | CryptEarlySecret
    | CryptHandshakeSecret
    | CryptApplicationSecret
    deriving (Eq, Show)
-- | Type-level tags that map to a 'CryptLevel'. The proxy argument only
-- carries the type @a@; being polymorphic in @proxy@, an instance cannot
-- inspect its value.
class HasCryptLevel a where
    getCryptLevel :: proxy a -> CryptLevel
-- | 'EarlySecret' tags map to 'CryptEarlySecret'.
instance HasCryptLevel EarlySecret where
    getCryptLevel = const CryptEarlySecret
-- | 'HandshakeSecret' tags map to 'CryptHandshakeSecret'.
instance HasCryptLevel HandshakeSecret where
    getCryptLevel = const CryptHandshakeSecret
-- | 'ApplicationSecret' tags map to 'CryptApplicationSecret'.
instance HasCryptLevel ApplicationSecret where
    getCryptLevel = const CryptApplicationSecret
-- | State threaded through 'RecordM': cipher, compression context, key
-- level, key material and MAC sequence number. Presumably one value per
-- direction (read/write) — confirm with the callers of 'runRecordM'.
data RecordState = RecordState
    { stCipher      :: Maybe Cipher  -- ^ 'Nothing' until a cipher is installed ('getBulk' and 'computeDigest' require 'Just')
    , stCompression :: Compression   -- ^ updated by 'withCompression'
    , stCryptLevel  :: !CryptLevel   -- ^ level of the installed keys
    , stCryptState  :: !CryptState   -- ^ bulk state, IV and MAC secret
    , stMacState    :: !MacState     -- ^ sequence number bumped by 'incrRecordState'
    } deriving (Show)
-- | Record-layer computation: reads 'RecordOptions', threads a
-- 'RecordState', and may fail with a 'TLSError'. A 'Left' result carries
-- no state, so a failing computation's state updates are discarded.
newtype RecordM a = RecordM
    { runRecordM :: RecordOptions
                 -> RecordState
                 -> Either TLSError (a, RecordState)
    }
-- | Map over the result of a successful step. Errors and the threaded
-- 'RecordState' pass through untouched.
instance Functor RecordM where
    fmap f m = RecordM $ \opt st ->
        case runRecordM m opt st of
            Left err       -> Left err
            Right (a, st2) -> Right (f a, st2)

-- | 'pure' is defined directly rather than as @pure = return@. This keeps
-- the Applicative/Monad instances canonical (no
-- @-Wnoncanonical-monad-instances@ warning), and 'return' uses its
-- default, @return = pure@.
instance Applicative RecordM where
    pure a = RecordM $ \_ st -> Right (a, st)
    (<*>)  = ap

-- | Sequence two steps: the options are shared, the state produced by the
-- first step feeds the second, and the first 'Left' short-circuits.
instance Monad RecordM where
    m1 >>= m2 = RecordM $ \opt st ->
        case runRecordM m1 opt st of
            Left err       -> Left err
            Right (a, st2) -> runRecordM (m2 a) opt st2
-- | Return the read-only 'RecordOptions'; the state passes through unchanged.
getRecordOptions :: RecordM RecordOptions
getRecordOptions = RecordM withOptions
  where
    withOptions opts st = Right (opts, st)
-- | The 'recordVersion' of the current 'RecordOptions'.
getRecordVersion :: RecordM Version
getRecordVersion = fmap recordVersion getRecordOptions
-- | 'RecordState' is the state threaded by 'RecordM'. None of these
-- operations fails or looks at the 'RecordOptions'.
instance MonadState RecordState RecordM where
    get       = RecordM (const (\st -> Right (st, st)))
    put newSt = RecordM (const (const (Right ((), newSt))))
#if MIN_VERSION_mtl(2,1,0)
    -- 'state' is only provided as a class method when mtl >= 2.1.
    state f   = RecordM (const (Right . f))
#endif
-- | Failures are 'Left' values. 'catchError' runs the handler with the
-- options and with the state as it was before the failed action, since a
-- 'Left' result carries no state.
instance MonadError TLSError RecordM where
    throwError err = RecordM $ \_ _ -> Left err
    catchError action handler = RecordM $ \opt st ->
        either (\err -> runRecordM (handler err) opt st)
               Right
               (runRecordM action opt st)
-- | Initial record state:
--
-- * no cipher and 'nullCompression';
-- * level 'CryptInitial';
-- * an uninitialised bulk state with empty IV and MAC secret;
-- * MAC sequence number 0.
newRecordState :: RecordState
newRecordState = RecordState
    { stCipher      = Nothing
    , stCompression = nullCompression
    , stCryptLevel  = CryptInitial
    , stCryptState  = emptyCrypt
    , stMacState    = MacState 0
    }
  where
    emptyCrypt = CryptState
        { cstKey       = BulkStateUninitialized
        , cstIV        = B.empty
        , cstMacSecret = B.empty
        }
-- | Advance the MAC sequence number by one; every other field is kept.
-- NOTE(review): 'Word64' addition wraps silently at 2^64. TLS forbids
-- sequence-number wrap, so any guard must live with the callers — confirm.
incrRecordState :: RecordState -> RecordState
incrRecordState ts =
    let next = msSequence (stMacState ts) + 1
     in ts { stMacState = MacState next }
-- | Replace the IV inside the state's 'CryptState'. The bulk state and MAC
-- secret are kept as they were.
setRecordIV :: ByteString -> RecordState -> RecordState
setRecordIV iv st = st { stCryptState = updated }
  where
    updated = (stCryptState st) { cstIV = iv }
-- | Run @f@ on the current compression context, store the context it
-- returns, and yield its accompanying result.
withCompression :: (Compression -> (Compression, a)) -> RecordM a
withCompression f = get >>= \st ->
    let (compression', result) = f (stCompression st)
     in put st { stCompression = compression' } >> return result
-- | Compute the record MAC for @content@ under header @hdr@, and return it
-- together with the state whose sequence number has been incremented.
--
-- The MAC input is @seq_num || header || content@, keyed with
-- 'cstMacSecret':
--
-- * versions below 'TLS10' use 'macSSL' and a header without the version
--   field ('encodeHeaderNoVer');
-- * all later versions use 'hmac' over the full header ('encodeHeader').
--
-- Errors (via @fromJust@) when the digest is forced while 'stCipher' is
-- 'Nothing'.
computeDigest :: Version -> RecordState -> Header -> ByteString -> (ByteString, RecordState)
computeDigest ver tstate hdr content = (mac secret macInput, incrRecordState tstate)
  where
    secret   = cstMacSecret (stCryptState tstate)
    hashAlg  = cipherHash (fromJust "cipher" (stCipher tstate))
    seqBytes = encodeWord64 (msSequence (stMacState tstate))
    isSSL    = ver < TLS10

    mac | isSSL     = macSSL hashAlg
        | otherwise = hmac hashAlg

    headerBytes | isSSL     = encodeHeaderNoVer hdr
                | otherwise = encodeHeader hdr

    macInput = B.concat [seqBytes, headerBytes, content]
-- | Compute the MAC for a record with 'computeDigest', using the version
-- from 'RecordOptions'. Commits the incremented sequence number and
-- returns the digest.
makeDigest :: Header -> ByteString -> RecordM ByteString
makeDigest hdr content = do
    current <- get
    ver     <- getRecordVersion
    let (mac, next) = computeDigest ver current hdr content
    put next
    pure mac
-- | The bulk algorithm of the installed cipher. Errors (via @fromJust@)
-- when 'stCipher' is 'Nothing'.
getBulk :: RecordM Bulk
getBulk = do
    st <- get
    return (cipherBulk (fromJust "cipher" (stCipher st)))
-- | The current MAC sequence number; the state is not modified.
getMacSequence :: RecordM Word64
getMacSequence = get >>= return . msSequence . stMacState