{-# LANGUAGE FlexibleInstances #-}
-- |
-- Module      : Network.TLS.Handshake.Key
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
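-- Key operations used during the handshake: RSA encryption/decryption,
-- signing and signature verification, (EC)DHE and FFDHE key generation,
-- public-key suitability checks, and NSS-format key logging.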
--
module Network.TLS.Handshake.Key
    ( encryptRSA
    , signPrivate
    , decryptRSA
    , verifyPublic
    , generateDHE
    , generateECDHE
    , generateECDHEShared
    , generateFFDHE
    , generateFFDHEShared
    , versionCompatible
    , isDigitalSignaturePair
    , checkDigitalSignatureKey
    , getLocalPublicKey
    , satisfiesEcPredicate
    , logKey
    ) where

import Control.Monad.State.Strict

import qualified Data.ByteString as B

import Network.TLS.Handshake.State
import Network.TLS.State (withRNG, getVersion)
import Network.TLS.Crypto
import Network.TLS.Types
import Network.TLS.Context.Internal
import Network.TLS.Imports
import Network.TLS.Struct
import Network.TLS.X509

-- | Encrypt @content@ with the peer's public key (read from the handshake
-- state via 'getRemotePublicKey'), drawing randomness from the context RNG.
--
-- A 'KxError' from 'kxEncrypt' is not recovered from: it aborts with
-- 'error' (message prefixed by @"rsa encrypt failed: "@).
-- NOTE(review): this is an internal failure, so reporting it as a TLS error
-- might be preferable to calling 'error'.
encryptRSA :: Context -> ByteString -> IO ByteString
encryptRSA ctx content = do
    publicKey <- usingHState ctx getRemotePublicKey
    usingState_ ctx $ do
        v <- withRNG $ kxEncrypt publicKey content
        case v of
            Left err       -> error ("rsa encrypt failed: " ++ show err)
            Right econtent -> return econtent

-- | Sign @content@ with the local key pair from the handshake state
-- ('getLocalPublicPrivateKeys'), using the context RNG.
--
-- The 'Role' argument is currently ignored.  A 'KxError' from 'kxSign'
-- aborts with 'error' (message prefixed by @"sign failed: "@).
signPrivate :: Context -> Role -> SignatureParams -> ByteString -> IO ByteString
signPrivate ctx _ params content = do
    (publicKey, privateKey) <- usingHState ctx getLocalPublicPrivateKeys
    usingState_ ctx $ do
        r <- withRNG $ kxSign privateKey publicKey params content
        case r of
            Left err        -> error ("sign failed: " ++ show err)
            Right signature -> return signature

-- | Decrypt @econtent@ with the local private key (from
-- 'getLocalPublicPrivateKeys'), using the context RNG.
--
-- The input is first adjusted according to the negotiated version. Below
-- 'TLS10' it is used as is. From 'TLS10' on, its first two bytes are dropped.
-- Presumably these are the length prefix that TLS puts on the encrypted
-- premaster secret. Their value is not checked.
-- Decryption errors are returned as 'Left', not thrown.
decryptRSA :: Context -> ByteString -> IO (Either KxError ByteString)
decryptRSA ctx econtent = do
    (_, privateKey) <- usingHState ctx getLocalPublicPrivateKeys
    usingState_ ctx $ do
        ver <- getVersion
        let cipher = if ver < TLS10 then econtent else B.drop 2 econtent
        withRNG $ kxDecrypt privateKey cipher

-- | Check signature @sign@ over @econtent@ against the peer's public key
-- (read from the handshake state via 'getRemotePublicKey').
verifyPublic :: Context -> SignatureParams -> ByteString -> ByteString -> IO Bool
verifyPublic ctx params econtent sign = do
    publicKey <- usingHState ctx getRemotePublicKey
    return $ kxVerify publicKey params econtent sign

-- | Generate a Diffie-Hellman key pair for the given parameters, using the
-- context RNG.
generateDHE :: Context -> DHParams -> IO (DHPrivate, DHPublic)
generateDHE ctx dhp = usingState_ ctx $ withRNG $ dhGenerateKeyPair dhp

-- | Generate a key pair for the given 'Group' with 'groupGenerateKeyPair',
-- using the context RNG.
generateECDHE :: Context -> Group -> IO (GroupPrivate, GroupPublic)
generateECDHE ctx grp = usingState_ ctx $ withRNG $ groupGenerateKeyPair grp

-- | Given the peer's group public value, produce our own public value and
-- the shared key via 'groupGetPubShared', using the context RNG.
-- Yields 'Nothing' exactly when 'groupGetPubShared' does.
generateECDHEShared :: Context -> GroupPublic -> IO (Maybe (GroupPublic, GroupKey))
generateECDHEShared ctx pub = usingState_ ctx $ withRNG $ groupGetPubShared pub

-- | Generate a finite-field Diffie-Hellman key pair for a named 'Group'.
-- Returns the group's 'DHParams' together with the pair. Uses the context RNG.
generateFFDHE :: Context -> Group -> IO (DHParams, DHPrivate, DHPublic)
generateFFDHE ctx grp = usingState_ ctx $ withRNG $ dhGroupGenerateKeyPair grp

-- | Given a named 'Group' and the peer's DH public value, produce our own
-- public value and the shared key via 'dhGroupGetPubShared', using the
-- context RNG.  Yields 'Nothing' exactly when 'dhGroupGetPubShared' does.
generateFFDHEShared :: Context -> Group -> DHPublic -> IO (Maybe (DHPublic, DHKey))
generateFFDHEShared ctx grp pub = usingState_ ctx $ withRNG $ dhGroupGetPubShared grp pub

-- | Whether a public key is of a kind usable for digital signatures:
-- RSA, DSA, EC, Ed25519 or Ed448.  Any other key kind yields 'False'.
isDigitalSignatureKey :: PubKey -> Bool
isDigitalSignatureKey (PubKeyRSA _)     = True
isDigitalSignatureKey (PubKeyDSA _)     = True
isDigitalSignatureKey (PubKeyEC  _)     = True
isDigitalSignatureKey (PubKeyEd25519 _) = True
isDigitalSignatureKey (PubKeyEd448   _) = True
isDigitalSignatureKey _                 = False

-- | Whether a public key type may be used with the given protocol version:
--
-- * RSA: every version;
-- * DSA: up to and including 'TLS12';
-- * EC: 'TLS10' and later;
-- * Ed25519 / Ed448: 'TLS12' and later;
-- * anything else: never.
versionCompatible :: PubKey -> Version -> Bool
versionCompatible (PubKeyRSA _)     _ = True
versionCompatible (PubKeyDSA _)     v = v <= TLS12
versionCompatible (PubKeyEC _)      v = v >= TLS10
versionCompatible (PubKeyEd25519 _) v = v >= TLS12
versionCompatible (PubKeyEd448 _)   v = v >= TLS12
versionCompatible _                 _ = False

-- | Test whether the argument is a public key supported for signature at the
-- specified TLS version.  This also accepts a key for RSA encryption.  This
-- test is performed by clients or servers before verifying a remote
-- Certificate Verify.
--
-- Throws (via 'throwCore') an 'Error_Protocol' with 'HandshakeFailure' if
-- the key kind is not a signature key at all ('isDigitalSignatureKey'), or
-- with 'IllegalParameter' if it is not allowed at @usedVersion@
-- ('versionCompatible').
checkDigitalSignatureKey :: MonadIO m => Version -> PubKey -> m ()
checkDigitalSignatureKey usedVersion key = do
    unless (isDigitalSignatureKey key) $
        throwCore $ Error_Protocol "unsupported remote public key type" HandshakeFailure
    unless (key `versionCompatible` usedVersion) $
        throwCore $ Error_Protocol (show usedVersion ++ " has no support for " ++ pubkeyType key) IllegalParameter

-- | Test whether the argument is a matching key pair supported for signature.
-- This also accepts material for RSA encryption.  This test is performed by
-- servers or clients before using a credential from the local configuration.
--
-- "Matching" means both halves use the same algorithm. The key contents are
-- not compared, so this does not prove the private key belongs to the public
-- key. For EC pairs, the private key must additionally satisfy
-- 'kxSupportedPrivKeyEC'.
isDigitalSignaturePair :: (PubKey, PrivKey) -> Bool
isDigitalSignaturePair keyPair =
    case keyPair of
        (PubKeyRSA      _, PrivKeyRSA      _)  -> True
        (PubKeyDSA      _, PrivKeyDSA      _)  -> True
        (PubKeyEC       _, PrivKeyEC       k)  -> kxSupportedPrivKeyEC k
        (PubKeyEd25519  _, PrivKeyEd25519  _)  -> True
        (PubKeyEd448    _, PrivKeyEd448    _)  -> True
        _                                      -> False

-- | The public half of the local key pair stored in the handshake state
-- ('getLocalPublicPrivateKeys').
getLocalPublicKey :: MonadIO m => Context -> m PubKey
getLocalPublicKey ctx =
    usingHState ctx (fst <$> getLocalPublicPrivateKeys)

-- | Test whether the public key satisfies a predicate about the elliptic curve.
-- When the public key is not suitable for ECDSA, like RSA for instance, the
-- predicate is not used and the result is 'True'.
--
-- For an EC key whose curve is not recognised as a 'Group' (i.e.
-- 'findEllipticCurveGroup' returns 'Nothing'), the result is 'False'.
satisfiesEcPredicate :: (Group -> Bool) -> PubKey -> Bool
satisfiesEcPredicate p (PubKeyEC ecPub) =
    maybe False p $ findEllipticCurveGroup ecPub
satisfiesEcPredicate _ _                = True

----------------------------------------------------------------

-- | Secrets that can be written to the key log. Each one is paired with the
-- label that identifies it in the NSS key log format.
class LogLabel a where
    labelAndKey :: a -> (String, ByteString)

instance LogLabel MasterSecret where
    labelAndKey (MasterSecret key) = ("CLIENT_RANDOM", key)

instance LogLabel (ClientTrafficSecret EarlySecret) where
    labelAndKey (ClientTrafficSecret key) = ("CLIENT_EARLY_TRAFFIC_SECRET", key)

instance LogLabel (ServerTrafficSecret HandshakeSecret) where
    labelAndKey (ServerTrafficSecret key) = ("SERVER_HANDSHAKE_TRAFFIC_SECRET", key)

instance LogLabel (ClientTrafficSecret HandshakeSecret) where
    labelAndKey (ClientTrafficSecret key) = ("CLIENT_HANDSHAKE_TRAFFIC_SECRET", key)

instance LogLabel (ServerTrafficSecret ApplicationSecret) where
    labelAndKey (ServerTrafficSecret key) = ("SERVER_TRAFFIC_SECRET_0", key)

instance LogLabel (ClientTrafficSecret ApplicationSecret) where
    labelAndKey (ClientTrafficSecret key) = ("CLIENT_TRAFFIC_SECRET_0", key)

-- | Write a secret to the context's key logger ('ctxKeyLogger') in the
-- NSS Key Log Format, as one line of the form
-- @\<label\> \<client random\> \<secret\>@.
-- The label comes from 'labelAndKey'. The two values are hex-dumped.
--
-- The client random is read from the handshake state. If there is no
-- handshake state, nothing is logged.
--
-- See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format
logKey :: LogLabel a => Context -> a -> IO ()
logKey ctx logkey = do
    mhst <- getHState ctx
    case mhst of
      Nothing  -> return ()
      Just hst -> do
          let cr = unClientRandom $ hstClientRandom hst
              (label, key) = labelAndKey logkey
          ctxKeyLogger ctx $ label ++ " " ++ dump cr ++ " " ++ dump key
  where
    -- Hex-dump a value and strip its first and last characters.
    -- NOTE(review): this assumes showBytesHex wraps its hex digits in
    -- delimiters (for example the quotes that 'show' adds). init/tail are
    -- partial and would fail on an empty result. Confirm this against
    -- showBytesHex.
    dump = init . tail . showBytesHex