{-# LANGUAGE CPP #-}
module Network.TLS.Context
(
TLSParams
, Context(..)
, Hooks(..)
, Established(..)
, ctxEOF
, ctxHasSSLv2ClientHello
, ctxDisableSSLv2ClientHello
, ctxEstablished
, withLog
, ctxWithHooks
, contextModifyHooks
, setEOF
, setEstablished
, contextFlush
, contextClose
, contextSend
, contextRecv
, updateMeasure
, withMeasure
, withReadLock
, withWriteLock
, withStateLock
, withRWLock
, Information(..)
, contextGetInformation
, contextNew
, contextNewOnHandle
#ifdef INCLUDE_NETWORK
, contextNewOnSocket
#endif
, contextHookSetHandshakeRecv
, contextHookSetHandshake13Recv
, contextHookSetCertificateRecv
, contextHookSetLogging
, throwCore
, usingState
, usingState_
, runTxState
, runRxState
, usingHState
, getHState
, getStateRNG
, tls13orLater
, getFinished
, getPeerFinished
) where
import Network.TLS.Backend
import Network.TLS.Context.Internal
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.State
import Network.TLS.Hooks
import Network.TLS.Record.State
import Network.TLS.Record.Layer
import Network.TLS.Record.Reading
import Network.TLS.Record.Writing
import Network.TLS.Parameters
import Network.TLS.Measurement
import Network.TLS.Types (Role(..))
import Network.TLS.Handshake (handshakeClient, handshakeClientWith, handshakeServer, handshakeServerWith)
import Network.TLS.PostHandshake (requestCertificateServer, postHandshakeAuthClientWith, postHandshakeAuthServerWith)
import Network.TLS.X509
import Network.TLS.RNG
import Control.Concurrent.MVar
import Control.Monad.State.Strict
import Data.IORef
#ifdef INCLUDE_NETWORK
import Network.Socket (Socket)
#endif
import System.IO (Handle)
-- | Parameter sets from which a 'Context' can be created.
--
-- 'contextNew' uses the methods of this class to obtain the common
-- parameters and the role, and to fill in the handshake-related callbacks
-- stored in the resulting 'Context'.
class TLSParams a where
    -- | Extract the common parameters (a triple that 'contextNew' destructures
    -- into supported, shared and debug parts).
    getTLSCommonParams :: a -> CommonParams
    -- | The 'Role' associated with this kind of parameter set.
    getTLSRole :: a -> Role
    -- | Action stored in the @ctxDoHandshake@ field of a new 'Context'.
    doHandshake :: a -> Context -> IO ()
    -- | Action stored in the @ctxDoHandshakeWith@ field of a new 'Context'.
    doHandshakeWith :: a -> Context -> Handshake -> IO ()
    -- | Action stored in the @ctxDoRequestCertificate@ field of a new 'Context'.
    doRequestCertificate :: a -> Context -> IO Bool
    -- | Action stored in the @ctxDoPostHandshakeAuthWith@ field of a new 'Context'.
    doPostHandshakeAuthWith :: a -> Context -> Handshake13 -> IO ()
-- | Client-side parameters: the common parameters are read from the
-- @client*@ fields and the handshake callbacks are the client variants.
instance TLSParams ClientParams where
    getTLSCommonParams cparams = ( clientSupported cparams
                                 , clientShared cparams
                                 , clientDebug cparams
                                 )
    getTLSRole _ = ClientRole
    doHandshake = handshakeClient
    doHandshakeWith = handshakeClientWith
    -- For client parameters this always yields False without touching the
    -- context.
    doRequestCertificate _ _ = return False
    doPostHandshakeAuthWith = postHandshakeAuthClientWith
-- | Server-side parameters: the common parameters are read from the
-- @server*@ fields and the handshake callbacks are the server variants.
instance TLSParams ServerParams where
    getTLSCommonParams sparams = ( serverSupported sparams
                                 , serverShared sparams
                                 , serverDebug sparams
                                 )
    getTLSRole _ = ServerRole
    doHandshake = handshakeServer
    doHandshakeWith = handshakeServerWith
    doRequestCertificate = requestCertificateServer
    doPostHandshakeAuthWith = postHandshakeAuthServerWith
-- | Create a new 'Context' from a backend and a parameter set.
--
-- The backend is initialised with 'initializeBackend', the common
-- parameters and role are taken from the 'TLSParams' instance, and every
-- mutable cell of the context ('MVar's and 'IORef's) is freshly allocated.
contextNew :: (MonadIO m, HasBackend backend, TLSParams params)
           => backend   -- ^ Value from which the connection 'Backend' is obtained via 'getBackend'.
           -> params    -- ^ Parameters of the context.
           -> m Context
contextNew backend params = liftIO $ do
    initializeBackend backend

    let (supported, shared, debug) = getTLSCommonParams params

    -- Use the seed supplied in the debug parameters when there is one;
    -- otherwise generate a fresh seed and hand it to 'debugPrintSeed'.
    seed <- case debugSeed debug of
                Nothing     -> do seed <- seedNew
                                  debugPrintSeed debug seed
                                  return seed
                Just determ -> return determ
    let rng = newStateRNG seed

    let role = getTLSRole params
        st   = newTLSState rng role

    stvar <- newMVar st
    eof   <- newIORef False
    established <- newIORef NotEstablished
    stats <- newIORef newMeasurement
    -- The SSLv2 ClientHello flag starts out True only for the server role.
    sslv2Compat <- newIORef (role == ServerRole)
    needEmptyPacket <- newIORef False
    hooks <- newIORef defaultHooks
    tx    <- newMVar newRecordState
    rx    <- newMVar newRecordState
    hs    <- newMVar Nothing
    as    <- newIORef []
    crs   <- newIORef []
    lockWrite <- newMVar ()
    lockRead  <- newMVar ()
    lockState <- newMVar ()
    finished <- newIORef Nothing
    peerFinished <- newIORef Nothing

    -- 'ctx' and 'recordLayer' are mutually recursive: the record layer
    -- closes over the context that contains it.
    let ctx = Context
            { ctxConnection   = getBackend backend
            , ctxShared       = shared
            , ctxSupported    = supported
            , ctxState        = stvar
            , ctxFragmentSize = Just 16384
            , ctxTxState      = tx
            , ctxRxState      = rx
            , ctxHandshake    = hs
            , ctxDoHandshake  = doHandshake params
            , ctxDoHandshakeWith  = doHandshakeWith params
            , ctxDoRequestCertificate = doRequestCertificate params
            , ctxDoPostHandshakeAuthWith = doPostHandshakeAuthWith params
            , ctxMeasurement  = stats
            , ctxEOF_         = eof
            , ctxEstablished_ = established
            , ctxSSLv2ClientHello = sslv2Compat
            , ctxNeedEmptyPacket  = needEmptyPacket
            , ctxHooks            = hooks
            , ctxLockWrite        = lockWrite
            , ctxLockRead         = lockRead
            , ctxLockState        = lockState
            , ctxPendingActions   = as
            , ctxCertRequests     = crs
            , ctxKeyLogger        = debugKeyLogger debug
            , ctxRecordLayer      = recordLayer
            , ctxHandshakeSync    = HandshakeSync syncNoOp syncNoOp
            , ctxQUICMode         = False
            , ctxFinished         = finished
            , ctxPeerFinished     = peerFinished
            }

        -- Synchronisation callback that ignores both arguments.
        syncNoOp _ _ = return ()

        recordLayer = RecordLayer
            { recordEncode    = encodeRecord ctx
            , recordEncode13  = encodeRecord13 ctx
            , recordSendBytes = sendBytes ctx
            , recordRecv      = recvRecord ctx
            , recordRecv13    = recvRecord13 ctx
            }

    return ctx
-- | Create a new 'Context' on a 'Handle'.
--
-- This is 'contextNew' with the backend type fixed to 'Handle'.
contextNewOnHandle :: (MonadIO m, TLSParams params)
                   => Handle -- ^ Handle used as the backend.
                   -> params -- ^ Parameters of the context.
                   -> m Context
contextNewOnHandle = contextNew
{-# DEPRECATED contextNewOnHandle "use contextNew" #-}
#ifdef INCLUDE_NETWORK
-- | Create a new 'Context' on a 'Socket'.
--
-- This is 'contextNew' with the backend type fixed to 'Socket'.
contextNewOnSocket :: (MonadIO m, TLSParams params)
                   => Socket -- ^ Socket used as the backend.
                   -> params -- ^ Parameters of the context.
                   -> m Context
contextNewOnSocket sock params = contextNew sock params
{-# DEPRECATED contextNewOnSocket "use contextNew" #-}
#endif
-- | Replace the 'hookRecvHandshake' field of the context's 'Hooks' with
-- the given function, leaving the other hooks unchanged.
contextHookSetHandshakeRecv :: Context -> (Handshake -> IO Handshake) -> IO ()
contextHookSetHandshakeRecv context f =
    contextModifyHooks context (\hooks -> hooks { hookRecvHandshake = f })
-- | Replace the 'hookRecvHandshake13' field of the context's 'Hooks' with
-- the given function, leaving the other hooks unchanged.
contextHookSetHandshake13Recv :: Context -> (Handshake13 -> IO Handshake13) -> IO ()
contextHookSetHandshake13Recv context f =
    contextModifyHooks context (\hooks -> hooks { hookRecvHandshake13 = f })
-- | Replace the 'hookRecvCertificates' field of the context's 'Hooks' with
-- the given action, leaving the other hooks unchanged.
contextHookSetCertificateRecv :: Context -> (CertificateChain -> IO ()) -> IO ()
contextHookSetCertificateRecv context f =
    contextModifyHooks context (\hooks -> hooks { hookRecvCertificates = f })
-- | Replace the 'hookLogging' field of the context's 'Hooks' with the
-- given logging callbacks, leaving the other hooks unchanged.
contextHookSetLogging :: Context -> Logging -> IO ()
contextHookSetLogging context loggingCallbacks =
    contextModifyHooks context (\hooks -> hooks { hookLogging = loggingCallbacks })
-- | Read the current contents of the context's 'ctxFinished' cell.
--
-- NOTE(review): presumably this is the Finished data sent by the local
-- side; confirm against the code that writes 'ctxFinished'.
getFinished :: Context -> IO (Maybe FinishedData)
getFinished = readIORef . ctxFinished
-- | Read the current contents of the context's 'ctxPeerFinished' cell.
--
-- NOTE(review): presumably this is the Finished data received from the
-- peer; confirm against the code that writes 'ctxPeerFinished'.
getPeerFinished :: Context -> IO (Maybe FinishedData)
getPeerFinished = readIORef . ctxPeerFinished