{-# LANGUAGE CPP #-}
module Network.TLS.Context
(
TLSParams
, Context(..)
, Hooks(..)
, Established(..)
, ctxEOF
, ctxHasSSLv2ClientHello
, ctxDisableSSLv2ClientHello
, ctxEstablished
, withLog
, ctxWithHooks
, contextModifyHooks
, setEOF
, setEstablished
, contextFlush
, contextClose
, contextSend
, contextRecv
, updateMeasure
, withMeasure
, withReadLock
, withWriteLock
, withStateLock
, withRWLock
, Information(..)
, contextGetInformation
, contextNew
, contextNewOnHandle
#ifdef INCLUDE_NETWORK
, contextNewOnSocket
#endif
, contextHookSetHandshakeRecv
, contextHookSetHandshake13Recv
, contextHookSetCertificateRecv
, contextHookSetLogging
, throwCore
, usingState
, usingState_
, runTxState
, runRxState
, usingHState
, getHState
, getStateRNG
, tls13orLater
) where
import Network.TLS.Backend
import Network.TLS.Context.Internal
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.State
import Network.TLS.Hooks
import Network.TLS.Record.State
import Network.TLS.Parameters
import Network.TLS.Measurement
import Network.TLS.Types (Role(..))
import Network.TLS.Handshake (handshakeClient, handshakeClientWith, handshakeServer, handshakeServerWith)
import Network.TLS.PostHandshake (requestCertificateServer, postHandshakeAuthClientWith, postHandshakeAuthServerWith)
import Network.TLS.X509
import Network.TLS.RNG
import Control.Concurrent.MVar
import Control.Monad.State.Strict
import Data.IORef
#ifdef INCLUDE_NETWORK
import Network.Socket (Socket)
#endif
import System.IO (Handle)
-- | Parameter sets from which a TLS 'Context' can be built.
--
-- An instance provides the common configuration triple, the 'Role' of the
-- resulting context, and the role-specific actions that 'contextNew' stores
-- in the context's @ctxDo*@ fields.
class TLSParams params where
    -- | The (supported, shared, debug) settings triple.
    getTLSCommonParams :: params -> CommonParams
    -- | Role of every context created from these parameters.
    getTLSRole :: params -> Role
    -- | Handshake entry point, stored by 'contextNew' as @ctxDoHandshake@.
    doHandshake :: params -> Context -> IO ()
    -- | Handshake driver that is handed a 'Handshake' message (presumably one
    -- already received from the peer — confirm); stored as @ctxDoHandshakeWith@.
    doHandshakeWith :: params -> Context -> Handshake -> IO ()
    -- | Certificate-request action, stored as @ctxDoRequestCertificate@.
    -- NOTE(review): the meaning of the 'Bool' result is defined by the
    -- implementations (client parameters always give 'False') — confirm.
    doRequestCertificate :: params -> Context -> IO Bool
    -- | TLS 1.3 post-handshake authentication step driven by a 'Handshake13'
    -- message; stored as @ctxDoPostHandshakeAuthWith@.
    doPostHandshakeAuthWith :: params -> Context -> Handshake13 -> IO ()
-- | Client-side parameters.
--
-- Handshakes are delegated to 'handshakeClient' / 'handshakeClientWith'. A
-- client never requests a certificate, so 'doRequestCertificate' always
-- yields 'False'.
instance TLSParams ClientParams where
    -- Build the triple with the function (reader) applicative.
    getTLSCommonParams = (,,) <$> clientSupported <*> clientShared <*> clientDebug
    getTLSRole = const ClientRole
    doHandshake = handshakeClient
    doHandshakeWith = handshakeClientWith
    doRequestCertificate _ _ = pure False
    doPostHandshakeAuthWith = postHandshakeAuthClientWith
-- | Server-side parameters.
--
-- Handshakes are delegated to 'handshakeServer' / 'handshakeServerWith', and
-- certificate requests to 'requestCertificateServer'.
instance TLSParams ServerParams where
    -- Build the triple with the function (reader) applicative.
    getTLSCommonParams = (,,) <$> serverSupported <*> serverShared <*> serverDebug
    getTLSRole = const ServerRole
    doHandshake = handshakeServer
    doHandshakeWith = handshakeServerWith
    doRequestCertificate = requestCertificateServer
    doPostHandshakeAuthWith = postHandshakeAuthServerWith
-- | Create a new TLS 'Context' over a connection backend.
--
-- The backend is initialised first with 'initializeBackend'. The RNG seed
-- comes from 'debugSeed' when one is configured. Otherwise a fresh seed is
-- produced by 'seedNew' and passed to 'debugPrintSeed'.
--
-- Every mutable cell of the context is allocated fresh on each call: TLS
-- state, record states, handshake slot, locks, flags, hooks and pending
-- lists. The role-specific actions come from the 'TLSParams' instance.
contextNew :: (MonadIO m, HasBackend backend, TLSParams params)
           => backend -- ^ connection backend the context will use
           -> params  -- ^ client or server parameters
           -> m Context
contextNew backend params = liftIO $ do
    initializeBackend backend
    let (supported, shared, debug) = getTLSCommonParams params
        role = getTLSRole params
        -- Only used when no deterministic seed was supplied.
        freshSeed = do
            s <- seedNew
            debugPrintSeed debug s
            return s
    seed <- maybe freshSeed return (debugSeed debug)

    -- MVar-protected state and locks.
    stateVar     <- newMVar (newTLSState (newStateRNG seed) role)
    txVar        <- newMVar newRecordState
    rxVar        <- newMVar newRecordState
    handshakeVar <- newMVar Nothing
    writeMutex   <- newMVar ()
    readMutex    <- newMVar ()
    stateMutex   <- newMVar ()

    -- IORef-held flags, counters, hooks and pending lists.
    eofRef         <- newIORef False
    establishedRef <- newIORef NotEstablished
    measurementRef <- newIORef newMeasurement
    -- Reception of an SSLv2-style ClientHello is enabled only for servers.
    sslv2Ref       <- newIORef (role == ServerRole)
    emptyPacketRef <- newIORef False
    hooksRef       <- newIORef defaultHooks
    pendingRef     <- newIORef []
    certReqRef     <- newIORef []

    pure Context
        { ctxConnection              = getBackend backend
        , ctxShared                  = shared
        , ctxSupported               = supported
        , ctxState                   = stateVar
        , ctxTxState                 = txVar
        , ctxRxState                 = rxVar
        , ctxHandshake               = handshakeVar
        , ctxDoHandshake             = doHandshake params
        , ctxDoHandshakeWith         = doHandshakeWith params
        , ctxDoRequestCertificate    = doRequestCertificate params
        , ctxDoPostHandshakeAuthWith = doPostHandshakeAuthWith params
        , ctxMeasurement             = measurementRef
        , ctxEOF_                    = eofRef
        , ctxEstablished_            = establishedRef
        , ctxSSLv2ClientHello        = sslv2Ref
        , ctxNeedEmptyPacket         = emptyPacketRef
        , ctxHooks                   = hooksRef
        , ctxLockWrite               = writeMutex
        , ctxLockRead                = readMutex
        , ctxLockState               = stateMutex
        , ctxPendingActions          = pendingRef
        , ctxCertRequests            = certReqRef
        , ctxKeyLogger               = debugKeyLogger debug
        }
-- | Create a new 'Context' on a 'Handle'.
--
-- Deprecated: this is exactly 'contextNew' specialised to 'Handle' backends.
contextNewOnHandle :: (MonadIO m, TLSParams params)
                   => Handle -- ^ handle used as the connection backend
                   -> params -- ^ client or server parameters
                   -> m Context
contextNewOnHandle hdl ps = contextNew hdl ps
{-# DEPRECATED contextNewOnHandle "use contextNew" #-}
#ifdef INCLUDE_NETWORK
-- | Create a new 'Context' on a 'Socket'.
--
-- Deprecated: this is exactly 'contextNew' specialised to 'Socket' backends.
-- Only compiled when @INCLUDE_NETWORK@ is defined.
contextNewOnSocket :: (MonadIO m, TLSParams params)
                   => Socket -- ^ socket used as the connection backend
                   -> params -- ^ client or server parameters
                   -> m Context
contextNewOnSocket = contextNew
{-# DEPRECATED contextNewOnSocket "use contextNew" #-}
#endif
-- | Replace the 'hookRecvHandshake' field of the context's hooks with the
-- given function. All other hooks are left untouched.
contextHookSetHandshakeRecv :: Context -> (Handshake -> IO Handshake) -> IO ()
contextHookSetHandshakeRecv ctx handler = contextModifyHooks ctx install
  where
    install h = h { hookRecvHandshake = handler }
-- | Replace the 'hookRecvHandshake13' field of the context's hooks with the
-- given function. All other hooks are left untouched.
contextHookSetHandshake13Recv :: Context -> (Handshake13 -> IO Handshake13) -> IO ()
contextHookSetHandshake13Recv ctx handler = contextModifyHooks ctx install
  where
    install h = h { hookRecvHandshake13 = handler }
-- | Replace the 'hookRecvCertificates' field of the context's hooks with the
-- given callback. All other hooks are left untouched.
contextHookSetCertificateRecv :: Context -> (CertificateChain -> IO ()) -> IO ()
contextHookSetCertificateRecv ctx callback = contextModifyHooks ctx install
  where
    install h = h { hookRecvCertificates = callback }
-- | Replace the 'hookLogging' field of the context's hooks with the given
-- 'Logging' callbacks. All other hooks are left untouched.
contextHookSetLogging :: Context -> Logging -> IO ()
contextHookSetLogging ctx logging = contextModifyHooks ctx install
  where
    install h = h { hookLogging = logging }