module Network.TLS.Context
(
TLSParams
, Context(..)
, Hooks(..)
, ctxEOF
, ctxHasSSLv2ClientHello
, ctxDisableSSLv2ClientHello
, ctxEstablished
, withLog
, ctxWithHooks
, contextModifyHooks
, setEOF
, setEstablished
, contextFlush
, contextClose
, contextSend
, contextRecv
, updateMeasure
, withMeasure
, withReadLock
, withWriteLock
, withStateLock
, withRWLock
, Information(..)
, contextGetInformation
, contextNew
, contextNewOnHandle
#ifdef INCLUDE_NETWORK
, contextNewOnSocket
#endif
, contextHookSetHandshakeRecv
, contextHookSetCertificateRecv
, contextHookSetLogging
, throwCore
, usingState
, usingState_
, runTxState
, runRxState
, usingHState
, getHState
, getStateRNG
) where
import Network.TLS.Backend
import Network.TLS.Context.Internal
import Network.TLS.Struct
import Network.TLS.Cipher (Cipher(..), CipherKeyExchangeType(..))
import Network.TLS.Credentials
import Network.TLS.State
import Network.TLS.Hooks
import Network.TLS.Record.State
import Network.TLS.Parameters
import Network.TLS.Measurement
import Network.TLS.Types (Role(..))
import Network.TLS.Handshake (handshakeClient, handshakeClientWith, handshakeServer, handshakeServerWith)
import Network.TLS.X509
import Network.TLS.RNG
import Data.Maybe (isJust)
import Control.Concurrent.MVar
import Control.Monad.State
import Data.IORef
import Data.Monoid (mappend)
#ifdef INCLUDE_NETWORK
import Network.Socket (Socket)
#endif
import System.IO (Handle)
-- | Role-specific parameters from which a TLS 'Context' can be built.
--
-- 'contextNew' relies on these methods to extract the configuration shared
-- by both roles and to pick the handshake driver of the role. Instances are
-- provided for 'ClientParams' and 'ServerParams'.
class TLSParams a where
    -- | The side of the connection these parameters describe.
    getTLSRole         :: a -> Role

    -- | The @(supported, shared, debug)@ triple common to both roles.
    getTLSCommonParams :: a -> CommonParams

    -- | The ciphers usable with these parameters, given an additional set of
    -- credentials (which an instance is free to ignore).
    getCiphers         :: a -> Credentials -> [Cipher]

    -- | Handshake driver for this role.
    doHandshake        :: a -> Context -> IO ()

    -- | Variant of 'doHandshake' that is also given a 'Handshake' message.
    doHandshakeWith    :: a -> Context -> Handshake -> IO ()
-- | Client-side parameters. The client offers every cipher it is configured
-- to support; the credentials passed to 'getCiphers' are not consulted.
instance TLSParams ClientParams where
    getTLSRole = const ClientRole
    getTLSCommonParams p = (clientSupported p, clientShared p, clientDebug p)
    getCiphers p = const (supportedCiphers (clientSupported p))
    doHandshake     = handshakeClient
    doHandshakeWith = handshakeClientWith
-- | Server-side parameters. Only ciphers whose key exchange can be carried out
-- with the available credentials (configured plus caller-supplied) and DHE
-- parameters are returned by 'getCiphers'.
instance TLSParams ServerParams where
    getTLSRole _ = ServerRole
    getTLSCommonParams p = (serverSupported p, serverShared p, serverDebug p)
    doHandshake     = handshakeServer
    doHandshakeWith = handshakeServerWith

    getCiphers sparams extraCreds =
        [ c | c <- supportedCiphers (serverSupported sparams)
            , usable (cipherKeyExchange c) ]
      where
        -- Caller-supplied credentials combined with those from the params.
        allCreds       = mappend extraCreds (sharedCredentials (serverShared sparams))
        sigAlgs        = credentialsListSigningAlgorithms allCreds

        haveDHE        = isJust (serverDHEParams sparams)
        haveRSASign    = SignatureRSA `elem` sigAlgs
        haveDSSSign    = SignatureDSS `elem` sigAlgs
        haveRSADecrypt = isJust (credentialsFindForDecrypting allCreds)

        usable kx = case kx of
            CipherKeyExchange_RSA         -> haveRSADecrypt
            CipherKeyExchange_DH_Anon     -> haveDHE
            CipherKeyExchange_DHE_RSA     -> haveRSASign && haveDHE
            CipherKeyExchange_DHE_DSS     -> haveDSSSign && haveDHE
            CipherKeyExchange_ECDHE_RSA   -> haveRSASign
            -- The remaining key exchanges are always rejected here, so
            -- ciphers using them are never returned.
            CipherKeyExchange_ECDHE_ECDSA -> False
            CipherKeyExchange_DH_DSS      -> False
            CipherKeyExchange_DH_RSA      -> False
            CipherKeyExchange_ECDH_ECDSA  -> False
            CipherKeyExchange_ECDH_RSA    -> False
-- | Create a new TLS 'Context' on top of a backend.
--
-- Steps, in order:
--
-- 1. The backend is set up with 'initializeBackend'.
-- 2. The RNG seed is chosen. The one from 'debugSeed' is used if present;
--    otherwise a fresh seed comes from 'seedNew' and is passed to
--    'debugPrintSeed'.
-- 3. All mutable per-connection state is allocated fresh. This covers the
--    TLS state, the tx/rx record states, the handshake state, the EOF and
--    established flags, measurements, hooks and the read/write/state locks.
--
-- No handshake is performed here. The role's handshake functions are only
-- stored in the returned context.
contextNew :: (MonadIO m, HasBackend backend, TLSParams params)
           => backend -- ^ transport, stored via 'getBackend'
           -> params  -- ^ client or server parameters
           -> m Context
contextNew backend params = liftIO $ do
    initializeBackend backend

    let (supported, shared, debug) = getTLSCommonParams params

    -- A deterministic seed from the debug parameters takes precedence. A
    -- freshly generated one is reported through 'debugPrintSeed', presumably
    -- so it can be fed back later via 'debugSeed' (TODO confirm).
    seed <- case debugSeed debug of
        Just determ -> return determ
        Nothing     -> do
            fresh <- seedNew
            debugPrintSeed debug fresh
            return fresh

    let role    = getTLSRole params
        st      = newTLSState (newStateRNG seed) role
        ciphers = getCiphers params

    stvar           <- newMVar st
    eof             <- newIORef False
    established     <- newIORef False
    stats           <- newIORef newMeasurement
    -- SSLv2 ClientHello compatibility starts enabled only for servers, the
    -- side that receives a ClientHello.
    sslv2Compat     <- newIORef (role == ServerRole)
    needEmptyPacket <- newIORef False
    hooks           <- newIORef defaultHooks
    tx              <- newMVar newRecordState
    rx              <- newMVar newRecordState
    hs              <- newMVar Nothing
    lockWrite       <- newMVar ()
    lockRead        <- newMVar ()
    lockState       <- newMVar ()

    return $ Context
        { ctxConnection       = getBackend backend
        , ctxShared           = shared
        , ctxSupported        = supported
        , ctxCiphers          = ciphers
        , ctxState            = stvar
        , ctxTxState          = tx
        , ctxRxState          = rx
        , ctxHandshake        = hs
        , ctxDoHandshake      = doHandshake params
        , ctxDoHandshakeWith  = doHandshakeWith params
        , ctxMeasurement      = stats
        , ctxEOF_             = eof
        , ctxEstablished_     = established
        , ctxSSLv2ClientHello = sslv2Compat
        , ctxNeedEmptyPacket  = needEmptyPacket
        , ctxHooks            = hooks
        , ctxLockWrite        = lockWrite
        , ctxLockRead         = lockRead
        , ctxLockState        = lockState
        }
-- | 'contextNew' specialised to a 'Handle' backend.
contextNewOnHandle :: (MonadIO m, TLSParams params)
                   => Handle -- ^ handle used as the transport backend
                   -> params -- ^ client or server parameters
                   -> m Context
contextNewOnHandle = contextNew
#ifdef INCLUDE_NETWORK
-- | 'contextNew' specialised to a 'Socket' backend.
contextNewOnSocket :: (MonadIO m, TLSParams params)
                   => Socket -- ^ socket used as the transport backend
                   -> params -- ^ client or server parameters
                   -> m Context
contextNewOnSocket = contextNew
#endif
-- | Replace the context's 'hookRecvHandshake' hook with the given function.
-- Every other hook is left unchanged.
contextHookSetHandshakeRecv :: Context -> (Handshake -> IO Handshake) -> IO ()
contextHookSetHandshakeRecv context f = contextModifyHooks context installHook
  where
    installHook current = current { hookRecvHandshake = f }
-- | Replace the context's 'hookRecvCertificates' hook with the given action.
-- Every other hook is left unchanged.
contextHookSetCertificateRecv :: Context -> (CertificateChain -> IO ()) -> IO ()
contextHookSetCertificateRecv context f =
    let withCertHook h = h { hookRecvCertificates = f }
    in  contextModifyHooks context withCertHook
-- | Replace the context's 'hookLogging' callbacks with the given 'Logging'
-- record. Every other hook is left unchanged.
contextHookSetLogging :: Context -> Logging -> IO ()
contextHookSetLogging context loggingCallbacks =
    contextModifyHooks context $ \h -> h { hookLogging = loggingCallbacks }