module Network.Transport.TCP
(
createTransport
, TCPParameters(..)
, defaultTCPParameters
, createTransportExposeInternals
, TransportInternals(..)
, EndPointId
, encodeEndPointAddress
, decodeEndPointAddress
, ControlHeader(..)
, ConnectionRequestResponse(..)
, firstNonReservedLightweightConnectionId
, firstNonReservedHeavyweightConnectionId
, socketToEndPoint
, LightweightConnectionId
) where
import Prelude hiding
( mapM_
#if ! MIN_VERSION_base(4,6,0)
, catch
#endif
)
import Network.Transport
import Network.Transport.TCP.Internal
( forkServer
, recvWithLength
, recvInt32
, tryCloseSocket
)
import Network.Transport.Internal
( encodeInt32
, prependLength
, mapIOException
, tryIO
, tryToEnum
, void
, timeoutMaybe
, asyncWhenCancelled
)
#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket as N
#else
import qualified Network.Socket as N
#endif
( HostName
, ServiceName
, Socket
, getAddrInfo
, socket
, addrFamily
, addrAddress
, SocketType(Stream)
, defaultProtocol
, setSocketOption
, SocketOption(ReuseAddr)
, connect
, sOMAXCONN
, AddrInfo
)
#ifdef USE_MOCK_NETWORK
import Network.Transport.TCP.Mock.Socket.ByteString (sendMany)
#else
import Network.Socket.ByteString (sendMany)
#endif
import Control.Concurrent (forkIO, ThreadId, killThread, myThreadId)
import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan)
import Control.Concurrent.MVar
( MVar
, newMVar
, modifyMVar
, modifyMVar_
, readMVar
, putMVar
, newEmptyMVar
, withMVar
)
import Control.Category ((>>>))
import Control.Applicative ((<$>))
import Control.Monad (when, unless, join)
import Control.Exception
( IOException
, SomeException
, AsyncException
, handle
, throw
, throwIO
, try
, bracketOnError
, fromException
, catch
)
import Data.IORef (IORef, newIORef, writeIORef, readIORef)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (concat)
import qualified Data.ByteString.Char8 as BSC (pack, unpack)
import Data.Bits (shiftL, (.|.))
import Data.Word (Word32)
import Data.Set (Set)
import qualified Data.Set as Set
( empty
, insert
, elems
, singleton
, null
, delete
, member
)
import Data.Map (Map)
import qualified Data.Map as Map (empty)
import Data.Accessor (Accessor, accessor, (^.), (^=), (^:))
import qualified Data.Accessor.Container as DAC (mapMaybe)
import Data.Foldable (forM_, mapM_)
-- | Everything the TCP transport needs to know about itself: the host and
-- port it is addressed by, its mutable state, and its configuration.
data TCPTransport = TCPTransport
{ transportHost :: !N.HostName
-- ^ Host name handed to 'forkServer' and embedded in endpoint addresses
, transportPort :: !N.ServiceName
-- ^ Port handed to 'forkServer' and embedded in endpoint addresses
, transportState :: !(MVar TransportState)
-- ^ Mutable state (valid with its endpoints, or closed)
, transportParams :: !TCPParameters
-- ^ Parameters the transport was created with
}
-- | State of the transport as a whole. 'apiCloseTransport' replaces the
-- state with 'TransportClosed'; once closed, new endpoints and incoming
-- connection requests are rejected.
data TransportState =
TransportValid !ValidTransportState
| TransportClosed
-- | Payload of 'TransportValid'.
data ValidTransportState = ValidTransportState
{ _localEndPoints :: !(Map EndPointAddress LocalEndPoint)
-- ^ Endpoints currently registered with the transport, keyed by address
, _nextEndPointId :: !EndPointId
-- ^ Identifier for the next endpoint to be created (starts at 0 and is
-- incremented by 'createLocalEndPoint')
}
-- | An endpoint owned by this transport.
data LocalEndPoint = LocalEndPoint
{ localAddress :: !EndPointAddress
-- ^ Address of this endpoint
, localChannel :: !(Chan Event)
-- ^ Queue of events for this endpoint; the public @receive@ reads from it
, localState :: !(MVar LocalEndPointState)
-- ^ Mutable state (valid, or closed)
}
-- | State of a local endpoint. 'apiCloseEndPoint' replaces the state with
-- 'LocalEndPointClosed'.
data LocalEndPointState =
LocalEndPointValid !ValidLocalEndPointState
| LocalEndPointClosed
-- | Payload of 'LocalEndPointValid'.
data ValidLocalEndPointState = ValidLocalEndPointState
{
_localNextConnOutId :: !LightweightConnectionId
-- ^ Next lightweight connection id handed out by 'getLocalNextConnOutId'
-- (used for connections of the endpoint to itself)
, _nextConnInId :: !HeavyweightConnectionId
-- ^ Identifier assigned to the next 'RemoteEndPoint' created by
-- 'findRemoteEndPoint'
, _localConnections :: !(Map EndPointAddress RemoteEndPoint)
-- ^ Known remote endpoints, keyed by their address
}
-- | Our view of an endpoint on the other side of a socket.
data RemoteEndPoint = RemoteEndPoint
{ remoteAddress :: !EndPointAddress
-- ^ Address of the remote endpoint
, remoteState :: !(MVar RemoteState)
-- ^ Mutable state of the connection to the remote endpoint
, remoteId :: !HeavyweightConnectionId
-- ^ Identifier allocated from '_nextConnInId'; combined with lightweight
-- ids to build the 'ConnectionId's reported in events
, remoteScheduled :: !(Chan (IO ()))
-- ^ Queue of actions enqueued by 'schedule' and executed by
-- 'runScheduledAction'
}
-- | Which side initiated a remote endpoint lookup: 'createConnectionTo'
-- passes 'RequestedByUs', 'handleConnectionRequest' passes 'RequestedByThem'.
data RequestedBy = RequestedByUs | RequestedByThem
deriving (Eq, Show)
-- | State machine for the connection to a remote endpoint.
data RemoteState =
-- | Connection setup failed; the error is rethrown to callers that look the
-- endpoint up
RemoteEndPointInvalid !(TransportError ConnectErrorCode)
-- | Connection setup in progress. The first 'MVar' (\"resolved\") is filled
-- when the state leaves init; the second (\"crossed\") is filled when the
-- remote side answers 'ConnectionRequestCrossed'; the last field records who
-- started the setup
| RemoteEndPointInit !(MVar ()) !(MVar ()) !RequestedBy
-- | Connection established
| RemoteEndPointValid !ValidRemoteEndPointState
-- | We sent 'CloseSocket' and are waiting for the peer; the 'MVar' is filled
-- when the closing state is resolved one way or the other
| RemoteEndPointClosing !(MVar ()) !ValidRemoteEndPointState
-- | Socket closed in an orderly fashion
| RemoteEndPointClosed
-- | Connection failed with the given I/O error
| RemoteEndPointFailed !IOException
-- | Payload of an established (or closing) connection to a remote endpoint.
data ValidRemoteEndPointState = ValidRemoteEndPointState
{ _remoteOutgoing :: !Int
-- ^ Counter of outgoing lightweight connections: bumped by
-- 'findRemoteEndPoint' for our own requests, lowered by 'apiClose'
, _remoteIncoming :: !(Set LightweightConnectionId)
-- ^ Incoming lightweight connections that are currently open
, _remoteMaxIncoming :: !LightweightConnectionId
-- ^ Id of the most recently announced incoming connection; sent back to the
-- peer alongside 'CloseSocket'
, _remoteNextConnOutId :: !LightweightConnectionId
-- ^ Id to use for the next outgoing lightweight connection
, remoteSocket :: !N.Socket
-- ^ Socket shared by all lightweight connections to this peer
, remoteSendLock :: !(MVar ())
-- ^ Lock held by 'sendOn' for the duration of a send on 'remoteSocket'
}
-- | Identifier of a local endpoint within one transport (the last component
-- of the address built by 'encodeEndPointAddress').
type EndPointId = Word32
-- | A local endpoint together with one of its remote endpoints.
type EndPointPair = (LocalEndPoint, RemoteEndPoint)
-- | Identifier of a lightweight connection multiplexed over one socket.
type LightweightConnectionId = Word32
-- | Identifier of a remote endpoint (one per socket) within a local endpoint.
type HeavyweightConnectionId = Word32
-- | Control messages exchanged on an established socket. They are written
-- with 'encodeInt32' and decoded with 'tryToEnum', so the constructor order
-- presumably fixes the wire value (TODO confirm against 'encodeInt32').
data ControlHeader =
-- | Announces a new lightweight connection; followed by its id
CreatedNewConnection
-- | Closes a lightweight connection; followed by its id
| CloseConnection
-- | Requests closing the socket; followed by the sender's
-- '_remoteMaxIncoming' value
| CloseSocket
deriving (Enum, Bounded, Show)
-- | Reply sent by 'handleConnectionRequest' to a connection request and
-- decoded by 'socketToEndPoint'.
data ConnectionRequestResponse =
-- | The request was accepted and the socket is now in use
ConnectionRequestAccepted
-- | The requested local endpoint does not exist
| ConnectionRequestInvalid
-- | A remote endpoint record for the requester already existed (for
-- instance because both sides connected at the same time)
| ConnectionRequestCrossed
deriving (Enum, Bounded, Show)
-- | Tunable parameters of the TCP transport; see 'defaultTCPParameters'.
data TCPParameters = TCPParameters {
tcpBacklog :: Int
-- ^ Backlog value handed to 'forkServer'
, tcpReuseServerAddr :: Bool
-- ^ Address-reuse flag handed to 'forkServer'
, tcpReuseClientAddr :: Bool
-- ^ When set, 'socketToEndPoint' enables @ReuseAddr@ on outgoing sockets
}
-- | Internals exposed by 'createTransportExposeInternals'.
data TransportInternals = TransportInternals
{
transportThread :: ThreadId
-- ^ Thread returned by 'forkServer', i.e. the one running the server
, socketBetween :: EndPointAddress
-> EndPointAddress
-> IO N.Socket
-- ^ Look up the socket between two endpoint addresses (implemented by
-- @internalSocketBetween@, which is defined elsewhere in this module)
}
-- | Create a TCP transport listening on the given host and port.
--
-- Thin wrapper around 'createTransportExposeInternals' that drops the
-- 'TransportInternals' component of a successful result.
createTransport :: N.HostName
                -> N.ServiceName
                -> TCPParameters
                -> IO (Either IOException Transport)
createTransport host port params = do
  result <- createTransportExposeInternals host port params
  return (fmap fst result)
-- | Create a TCP transport and also return its 'TransportInternals'.
--
-- A fresh, empty transport state is allocated, then the server is started
-- with 'forkServer'. 'bracketOnError' kills the server thread again if
-- building the result throws. The whole thing runs under 'tryIO', which
-- accounts for the 'Either' 'IOException' in the result type.
createTransportExposeInternals
  :: N.HostName
  -> N.ServiceName
  -> TCPParameters
  -> IO (Either IOException (Transport, TransportInternals))
createTransportExposeInternals host port params = do
    state <- newMVar (TransportValid emptyState)
    let transport = TCPTransport
          { transportHost   = host
          , transportPort   = port
          , transportState  = state
          , transportParams = params
          }
        startServer = forkServer
                        host
                        port
                        (tcpBacklog params)
                        (tcpReuseServerAddr params)
                        (terminationHandler transport)
                        (handleConnectionRequest transport)
    tryIO $ bracketOnError startServer killThread (mkTransport transport)
  where
    -- State of a transport that has no endpoints yet.
    emptyState :: ValidTransportState
    emptyState = ValidTransportState
      { _localEndPoints = Map.empty
      , _nextEndPointId = 0
      }

    -- Package the public API and the internals for the given server thread.
    mkTransport :: TCPTransport
                -> ThreadId
                -> IO (Transport, TransportInternals)
    mkTransport transport tid = return (api, internals)
      where
        -- The second element is an 'Event' that throws when it is forced.
        closeEvents = [ EndPointClosed
                      , throw $ userError "Transport closed"
                      ]
        api = Transport
          { newEndPoint    = apiNewEndPoint transport
          , closeTransport = apiCloseTransport transport (Just tid) closeEvents
          }
        internals = TransportInternals
          { transportThread = tid
          , socketBetween   = internalSocketBetween transport
          }

    -- Handed to 'forkServer': closes the transport (without killing any
    -- thread) and reports the exception to every endpoint.
    terminationHandler :: TCPTransport -> SomeException -> IO ()
    terminationHandler transport ex =
      apiCloseTransport transport Nothing
        [ ErrorEvent (TransportError EventTransportFailed (show ex))
        , throw $ userError "Transport closed"
        ]
-- | Default parameters: address reuse is enabled for both server and client
-- sockets, and the backlog is 'N.sOMAXCONN'.
defaultTCPParameters :: TCPParameters
defaultTCPParameters = TCPParameters
  { tcpReuseServerAddr = True
  , tcpReuseClientAddr = True
  , tcpBacklog         = N.sOMAXCONN
  }
-- | Close the transport.
--
-- The state is switched to 'TransportClosed'. If it was still valid, every
-- registered endpoint is closed with the given events. Finally the transport
-- thread, when one is supplied, is killed.
apiCloseTransport :: TCPTransport -> Maybe ThreadId -> [Event] -> IO ()
apiCloseTransport transport mTransportThread evs =
    asyncWhenCancelled return $ do
      mValid <- modifyMVar (transportState transport) takeValidState
      case mValid of
        Nothing  -> return ()
        Just vst -> mapM_ (apiCloseEndPoint transport evs) (vst ^. localEndPoints)
      mapM_ killThread mTransportThread
  where
    -- Mark the transport closed and hand back the previous valid state, if any.
    takeValidState (TransportValid vst) = return (TransportClosed, Just vst)
    takeValidState TransportClosed      = return (TransportClosed, Nothing)
-- | Create a new endpoint on the transport and wrap it in the public
-- 'EndPoint' record. Multicast is not supported: both multicast operations
-- return an error.
apiNewEndPoint :: TCPTransport
               -> IO (Either (TransportError NewEndPointErrorCode) EndPoint)
apiNewEndPoint transport =
    try . asyncWhenCancelled closeEndPoint $
      mkEndPoint <$> createLocalEndPoint transport
  where
    mkEndPoint ourEndPoint = EndPoint
      { receive               = readChan (localChannel ourEndPoint)
      , address               = localAddress ourEndPoint
      , connect               = apiConnect (transportParams transport) ourEndPoint
      , closeEndPoint         = apiCloseEndPoint transport closeEvents ourEndPoint
      , newMulticastGroup     = return (Left newMulticastGroupError)
      , resolveMulticastGroup = \_ -> return (Left resolveMulticastGroupError)
      }

    -- The second element is an 'Event' that throws when it is forced.
    closeEvents = [ EndPointClosed
                  , throw $ userError "Endpoint closed"
                  ]

    newMulticastGroupError =
      TransportError NewMulticastGroupUnsupported "Multicast not supported"

    resolveMulticastGroupError =
      TransportError ResolveMulticastGroupUnsupported "Multicast not supported"
-- | Open a lightweight connection from a local endpoint to an address.
--
-- Connecting to our own address is delegated to 'connectToSelf'. Otherwise
-- a broken remote endpoint record is discarded first ('resetIfBroken'), the
-- connection is set up with 'createConnectionTo', and a fresh liveness flag
-- is shared by the returned @send@ and @close@ operations.
apiConnect :: TCPParameters
           -> LocalEndPoint
           -> EndPointAddress
           -> Reliability
           -> ConnectHints
           -> IO (Either (TransportError ConnectErrorCode) Connection)
apiConnect params ourEndPoint theirAddress _reliability hints =
    try . asyncWhenCancelled close $ establish
  where
    establish
      | localAddress ourEndPoint == theirAddress = connectToSelf ourEndPoint
      | otherwise = do
          resetIfBroken ourEndPoint theirAddress
          (theirEndPoint, connId) <-
            createConnectionTo params ourEndPoint theirAddress hints
          connAlive <- newIORef True
          let endPoints = (ourEndPoint, theirEndPoint)
          return Connection
            { send  = apiSend  endPoints connId connAlive
            , close = apiClose endPoints connId connAlive
            }
-- | Close a lightweight connection.
--
-- If the remote endpoint is valid and the connection is still alive, the
-- connection is marked dead, a 'CloseConnection' message is scheduled and
-- the outgoing-connection counter is decremented. In every other case the
-- state is left untouched. Afterwards 'closeIfUnused' gets a chance to start
-- closing the socket. I/O errors are discarded ('tryIO' followed by 'void').
apiClose :: EndPointPair -> LightweightConnectionId -> IORef Bool -> IO ()
apiClose (ourEndPoint, theirEndPoint) connId connAlive =
  void . tryIO . asyncWhenCancelled return $ do
    mAct <- modifyMVar (remoteState theirEndPoint) $ \st -> case st of
      RemoteEndPointValid vst -> do
        alive <- readIORef connAlive
        if alive
          then do
            writeIORef connAlive False
            act <- schedule theirEndPoint $
              sendOn vst [encodeInt32 CloseConnection, encodeInt32 connId]
            -- One outgoing connection fewer. The counter has to go down by
            -- one here so that 'closeIfUnused' can see it reach zero.
            return ( RemoteEndPointValid
                   . (remoteOutgoing ^: (\x -> x - 1))
                   $ vst
                   , Just act
                   )
          else
            return (RemoteEndPointValid vst, Nothing)
      _ ->
        return (st, Nothing)
    -- Run the scheduled send outside of the state lock.
    forM_ mAct $ runScheduledAction (ourEndPoint, theirEndPoint)
    closeIfUnused (ourEndPoint, theirEndPoint)
-- | Send a payload on a lightweight connection.
--
-- What happens depends on the state of the remote endpoint:
--
-- * invalid or init: reported through 'relyViolation';
-- * valid: the message (connection id, then the length-prefixed payload) is
--   scheduled, provided the connection is still alive;
-- * closing or closed: 'relyViolation' if the connection is still marked
--   alive;
-- * failed: 'SendFailed' if the connection is still marked alive.
--
-- Whenever the connection is no longer alive the result is 'SendClosed'.
-- 'IOException's raised along the way are mapped to 'SendFailed'.
apiSend :: EndPointPair
        -> LightweightConnectionId
        -> IORef Bool
        -> [ByteString]
        -> IO (Either (TransportError SendErrorCode) ())
apiSend (ourEndPoint, theirEndPoint) connId connAlive payload =
    try . mapIOException sendFailed $ do
      act <- withMVar (remoteState theirEndPoint) dispatch
      runScheduledAction endPoints act
  where
    endPoints = (ourEndPoint, theirEndPoint)

    dispatch st = case st of
      RemoteEndPointInvalid _   -> violation
      RemoteEndPointInit _ _ _  -> violation
      RemoteEndPointValid vst   ->
        ifAlive . schedule theirEndPoint $
          sendOn vst (encodeInt32 connId : prependLength payload)
      RemoteEndPointClosing _ _ -> ifAlive violation
      RemoteEndPointClosed      -> ifAlive violation
      RemoteEndPointFailed err  ->
        ifAlive . throwIO $ TransportError SendFailed (show err)

    violation = relyViolation endPoints "apiSend"

    -- Run the action when the connection is alive, otherwise report that the
    -- connection was closed.
    ifAlive onAlive = do
      alive <- readIORef connAlive
      if alive
        then onAlive
        else throwIO $ TransportError SendClosed "Connection closed"

    sendFailed = TransportError SendFailed . show
-- | Close a local endpoint.
--
-- The endpoint is removed from the transport, its state is switched to
-- 'LocalEndPointClosed', and -- only for the caller that found it still
-- valid -- every known remote endpoint is shut down and the given events are
-- written to the endpoint's channel.
apiCloseEndPoint :: TCPTransport
-> [Event]
-> LocalEndPoint
-> IO ()
apiCloseEndPoint transport evs ourEndPoint =
asyncWhenCancelled return $ do
-- Unregister from the transport before touching our own state
removeLocalEndPoint transport ourEndPoint
-- Switch to closed; 'Just' the old state only if we were still valid
mOurState <- modifyMVar (localState ourEndPoint) $ \st ->
case st of
LocalEndPointValid vst ->
return (LocalEndPointClosed, Just vst)
LocalEndPointClosed ->
return (LocalEndPointClosed, Nothing)
forM_ mOurState $ \vst -> do
forM_ (vst ^. localConnections) tryCloseRemoteSocket
forM_ evs $ writeChan (localChannel ourEndPoint)
where
-- Move one remote endpoint to a failed state (unless it is already invalid,
-- closed or failed) and close its socket where there is one.
tryCloseRemoteSocket :: RemoteEndPoint -> IO ()
tryCloseRemoteSocket theirEndPoint = do
-- State that replaces init, valid and closing states
let closed = RemoteEndPointFailed . userError $ "apiCloseEndPoint"
mAct <- modifyMVar (remoteState theirEndPoint) $ \st ->
case st of
RemoteEndPointInvalid _ ->
return (st, Nothing)
RemoteEndPointInit resolved _ _ -> do
-- Fill the MVar so that threads blocked on it in 'findRemoteEndPoint'
-- can proceed
putMVar resolved ()
return (closed, Nothing)
RemoteEndPointValid vst -> do
-- Best effort: tell the peer we are closing (send errors are ignored
-- via 'tryIO'), then close the socket
act <- schedule theirEndPoint $ do
tryIO $ sendOn vst [ encodeInt32 CloseSocket
, encodeInt32 (vst ^. remoteMaxIncoming)
]
tryCloseSocket (remoteSocket vst)
return (closed, Just act)
RemoteEndPointClosing resolved vst -> do
-- No CloseSocket message is sent from this state; just close the socket
putMVar resolved ()
act <- schedule theirEndPoint $ tryCloseSocket (remoteSocket vst)
return (closed, Just act)
RemoteEndPointClosed ->
return (st, Nothing)
RemoteEndPointFailed err ->
return (RemoteEndPointFailed err, Nothing)
-- Run the scheduled action after the state lock has been released
forM_ mAct $ runScheduledAction (ourEndPoint, theirEndPoint)
-- | Handle a connection request arriving on a freshly accepted socket.
--
-- The request consists of the target endpoint id followed by the
-- length-prefixed address of the requester. If the target endpoint does not
-- exist we answer 'ConnectionRequestInvalid'; otherwise the remainder of the
-- handshake and the receive loop run in a new thread. On any exception the
-- socket is closed; only asynchronous exceptions are rethrown.
handleConnectionRequest :: TCPTransport -> N.Socket -> IO ()
handleConnectionRequest transport sock = handle handleException $ do
ourEndPointId <- recvInt32 sock
theirAddress <- EndPointAddress . BS.concat <$> recvWithLength sock
let ourAddress = encodeEndPointAddress (transportHost transport)
(transportPort transport)
ourEndPointId
-- Resolve the local endpoint the requester wants to talk to
ourEndPoint <- withMVar (transportState transport) $ \st -> case st of
TransportValid vst ->
case vst ^. localEndPointAt ourAddress of
Nothing -> do
sendMany sock [encodeInt32 ConnectionRequestInvalid]
throwIO $ userError "handleConnectionRequest: Invalid endpoint"
Just ourEndPoint ->
return ourEndPoint
TransportClosed ->
throwIO $ userError "Transport closed"
void . forkIO $ go ourEndPoint theirAddress
where
-- Finish the handshake, then (if it was accepted) process incoming messages
go :: LocalEndPoint -> EndPointAddress -> IO ()
go ourEndPoint theirAddress = do
-- Exceptions during the handshake are handled by 'handleException' and
-- turned into 'Nothing'
mEndPoint <- handle ((>> return Nothing) . handleException) $ do
resetIfBroken ourEndPoint theirAddress
(theirEndPoint, isNew) <-
findRemoteEndPoint ourEndPoint theirAddress RequestedByThem
if not isNew
then do
-- A record for this peer already exists: reject as crossed
tryIO $ sendMany sock [encodeInt32 ConnectionRequestCrossed]
tryCloseSocket sock
return Nothing
else do
sendLock <- newMVar ()
let vst = ValidRemoteEndPointState
{ remoteSocket = sock
, remoteSendLock = sendLock
, _remoteOutgoing = 0
, _remoteIncoming = Set.empty
, _remoteMaxIncoming = 0
, _remoteNextConnOutId = firstNonReservedLightweightConnectionId
}
sendMany sock [encodeInt32 ConnectionRequestAccepted]
resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointValid vst)
return (Just theirEndPoint)
forM_ mEndPoint $ handleIncomingMessages . (,) ourEndPoint
-- Close the socket; rethrow the exception only if it is an 'AsyncException'
handleException :: SomeException -> IO ()
handleException ex = do
tryCloseSocket sock
rethrowIfAsync (fromException ex)
rethrowIfAsync :: Maybe AsyncException -> IO ()
rethrowIfAsync = mapM_ throwIO
-- | Receive loop for the socket shared with a remote endpoint.
--
-- The socket is taken from the remote endpoint's state (valid or closing);
-- if the endpoint is already closed or failed there is nothing to do. Each
-- iteration reads a 32-bit identifier: values at or above
-- 'firstNonReservedLightweightConnectionId' denote a message on a
-- lightweight connection, lower values are decoded as a 'ControlHeader'.
-- An 'IOException' escaping the loop is handled by @prematureExit@.
handleIncomingMessages :: EndPointPair -> IO ()
handleIncomingMessages (ourEndPoint, theirEndPoint) = do
    mSock <- withMVar theirState $ \st ->
      case st of
        RemoteEndPointInvalid _ ->
          relyViolation (ourEndPoint, theirEndPoint)
            "handleIncomingMessages (invalid)"
        RemoteEndPointInit _ _ _ ->
          relyViolation (ourEndPoint, theirEndPoint)
            "handleIncomingMessages (init)"
        RemoteEndPointValid ep ->
          return . Just $ remoteSocket ep
        RemoteEndPointClosing _ ep ->
          return . Just $ remoteSocket ep
        RemoteEndPointClosed ->
          return Nothing
        RemoteEndPointFailed _ ->
          return Nothing

    forM_ mSock $ \sock ->
      tryIO (go sock) >>= either (prematureExit sock) return
  where
    -- Dispatch on the next identifier read from the socket.
    go :: N.Socket -> IO ()
    go sock = do
      lcid <- recvInt32 sock :: IO LightweightConnectionId
      if lcid >= firstNonReservedLightweightConnectionId
        then do
          readMessage sock lcid
          go sock
        else
          case tryToEnum (fromIntegral lcid) of
            Just CreatedNewConnection -> do
              recvInt32 sock >>= createdNewConnection
              go sock
            Just CloseConnection -> do
              recvInt32 sock >>= closeConnection
              go sock
            Just CloseSocket -> do
              -- Stop looping once the socket has actually been closed
              didClose <- recvInt32 sock >>= closeSocket sock
              unless didClose $ go sock
            Nothing ->
              throwIO $ userError "Invalid control request"

    -- The peer opened a new lightweight connection: record it and emit
    -- 'ConnectionOpened'.
    createdNewConnection :: LightweightConnectionId -> IO ()
    createdNewConnection lcid = do
      modifyMVar_ theirState $ \st -> do
        vst <- case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:createNewConnection (invalid)"
          RemoteEndPointInit _ _ _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:createNewConnection (init)"
          RemoteEndPointValid vst ->
            return ( (remoteIncoming ^: Set.insert lcid)
                   $ (remoteMaxIncoming ^= lcid)
                   vst
                   )
          RemoteEndPointClosing resolved vst -> do
            -- A new connection arrived while we were closing: the endpoint
            -- goes back to valid with this connection as its only incoming
            -- one, and the closing state is resolved
            putMVar resolved ()
            return ( (remoteIncoming ^= Set.singleton lcid)
                   . (remoteMaxIncoming ^= lcid)
                   $ vst
                   )
          RemoteEndPointFailed err ->
            throwIO err
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "createNewConnection (closed)"
        return (RemoteEndPointValid vst)
      writeChan ourChannel (ConnectionOpened (connId lcid) ReliableOrdered theirAddr)

    -- The peer closed a lightweight connection: forget it and emit
    -- 'ConnectionClosed'. Unknown ids are rejected.
    closeConnection :: LightweightConnectionId -> IO ()
    closeConnection lcid = do
      modifyMVar_ theirState $ \st -> case st of
        RemoteEndPointInvalid _ ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (invalid)"
        RemoteEndPointInit _ _ _ ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (init)"
        RemoteEndPointValid vst -> do
          unless (Set.member lcid (vst ^. remoteIncoming)) $
            throwIO $ userError "Invalid CloseConnection"
          return ( RemoteEndPointValid
                 . (remoteIncoming ^: Set.delete lcid)
                 $ vst
                 )
        RemoteEndPointClosing _ _ ->
          throwIO $ userError "Invalid CloseConnection request"
        RemoteEndPointFailed err ->
          throwIO err
        RemoteEndPointClosed ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (closed)"
      writeChan ourChannel (ConnectionClosed (connId lcid))

    -- The peer asked to close the socket. @lastReceivedId@ is the id of the
    -- last connection the peer has seen from us. Returns 'True' when the
    -- socket was really closed, 'False' when the request was ignored.
    closeSocket :: N.Socket -> LightweightConnectionId -> IO Bool
    closeSocket sock lastReceivedId = do
      mAct <- modifyMVar theirState $ \st -> do
        case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (invalid)"
          RemoteEndPointInit _ _ _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (init)"
          RemoteEndPointValid vst -> do
            -- All incoming connections are implicitly closed
            forM_ (Set.elems $ vst ^. remoteIncoming) $
              writeChan ourChannel . ConnectionClosed . connId
            let vst' = remoteIncoming ^= Set.empty $ vst
            -- Keep the socket if we still have outgoing connections, or if
            -- we opened connections the peer had not yet seen when it sent
            -- the request
            if vst' ^. remoteOutgoing > 0 || lastReceivedId < lastSentId vst
              then
                return (RemoteEndPointValid vst', Nothing)
              else do
                removeRemoteEndPoint (ourEndPoint, theirEndPoint)
                -- Confirm with our own CloseSocket (best effort), then close
                act <- schedule theirEndPoint $ do
                  tryIO $ sendOn vst' [ encodeInt32 CloseSocket
                                      , encodeInt32 (vst ^. remoteMaxIncoming)
                                      ]
                  tryCloseSocket sock
                return (RemoteEndPointClosed, Just act)
          RemoteEndPointClosing resolved vst -> do
            if lastReceivedId < lastSentId vst
              then do
                -- The peer has not yet seen our latest connection: ignore
                return (RemoteEndPointClosing resolved vst, Nothing)
              else do
                removeRemoteEndPoint (ourEndPoint, theirEndPoint)
                act <- schedule theirEndPoint $ tryCloseSocket sock
                putMVar resolved ()
                return (RemoteEndPointClosed, Just act)
          RemoteEndPointFailed err ->
            throwIO err
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (closed)"
      case mAct of
        Nothing -> return False
        Just act -> do
          runScheduledAction (ourEndPoint, theirEndPoint) act
          return True

    -- Read a length-prefixed message and deliver it as a 'Received' event.
    readMessage :: N.Socket -> LightweightConnectionId -> IO ()
    readMessage sock lcid =
      recvWithLength sock >>= writeChan ourChannel . Received (connId lcid)

    ourChannel = localChannel ourEndPoint
    theirState = remoteState theirEndPoint
    theirAddr  = remoteAddress theirEndPoint

    -- The receive loop ended with an I/O error: close the socket and mark
    -- the remote endpoint failed. A connection-lost event is emitted only
    -- when the endpoint was valid.
    prematureExit :: N.Socket -> IOException -> IO ()
    prematureExit sock err = do
      tryCloseSocket sock
      modifyMVar_ theirState $ \st ->
        case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:prematureExit"
          RemoteEndPointInit _ _ _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:prematureExit"
          RemoteEndPointValid _ -> do
            let code = EventConnectionLost (remoteAddress theirEndPoint)
            writeChan ourChannel . ErrorEvent $ TransportError code (show err)
            return (RemoteEndPointFailed err)
          RemoteEndPointClosing resolved _ -> do
            putMVar resolved ()
            return (RemoteEndPointFailed err)
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:prematureExit"
          RemoteEndPointFailed err' ->
            -- Keep the error that was recorded first
            return (RemoteEndPointFailed err')

    -- Build the externally visible connection id for a lightweight id.
    connId :: LightweightConnectionId -> ConnectionId
    connId = createConnectionId (remoteId theirEndPoint)

    -- Id of the last outgoing connection we created, or 0 if we have not
    -- created any yet. It is one less than the next id to hand out.
    lastSentId :: ValidRemoteEndPointState -> LightweightConnectionId
    lastSentId vst =
      if vst ^. remoteNextConnOutId == firstNonReservedLightweightConnectionId
        then 0
        else (vst ^. remoteNextConnOutId) - 1
-- | Create a new lightweight connection to a remote endpoint, setting up the
-- socket first if there is none yet.
--
-- Returns the remote endpoint together with the id allocated for the new
-- connection. 'IOException's are mapped to 'ConnectFailed'.
createConnectionTo :: TCPParameters
-> LocalEndPoint
-> EndPointAddress
-> ConnectHints
-> IO (RemoteEndPoint, LightweightConnectionId)
createConnectionTo params ourEndPoint theirAddress hints = go
where
go = do
(theirEndPoint, isNew) <- mapIOException connectFailed $
findRemoteEndPoint ourEndPoint theirAddress RequestedByUs
if isNew
then do
-- We created the record, so we own the socket setup. It runs in its own
-- thread (all exceptions are discarded there); looping makes the next
-- 'findRemoteEndPoint' call wait for the outcome.
forkIO . handle absorbAllExceptions $
setupRemoteEndPoint params (ourEndPoint, theirEndPoint) hints
go
else do
mapIOException connectFailed $ do
act <- modifyMVar (remoteState theirEndPoint) $ \st -> case st of
RemoteEndPointValid vst -> do
-- Allocate the next outgoing id and announce the connection
let connId = vst ^. remoteNextConnOutId
act <- schedule theirEndPoint $ do
sendOn vst [encodeInt32 CreatedNewConnection, encodeInt32 connId]
return connId
return ( RemoteEndPointValid
$ remoteNextConnOutId ^= connId + 1
$ vst
, act
)
RemoteEndPointInvalid err ->
throwIO err
RemoteEndPointFailed err ->
throwIO err
_ ->
relyViolation (ourEndPoint, theirEndPoint) "createConnectionTo"
connId <- runScheduledAction (ourEndPoint, theirEndPoint) act
return (theirEndPoint, connId)
connectFailed :: IOException -> TransportError ConnectErrorCode
connectFailed = TransportError ConnectFailed . show
-- Deliberately ignores every exception raised by the setup thread
absorbAllExceptions :: SomeException -> IO ()
absorbAllExceptions _ex =
return ()
-- | Establish the socket to a remote endpoint that is in the init state and
-- resolve that state according to the reply to our connection request. If
-- the request is accepted, this thread continues as the receive loop
-- ('handleIncomingMessages').
setupRemoteEndPoint :: TCPParameters -> EndPointPair -> ConnectHints -> IO ()
setupRemoteEndPoint params (ourEndPoint, theirEndPoint) hints = do
result <- socketToEndPoint ourAddress
theirAddress
(tcpReuseClientAddr params)
(connectTimeout hints)
didAccept <- case result of
Right (sock, ConnectionRequestAccepted) -> do
sendLock <- newMVar ()
let vst = ValidRemoteEndPointState
{ remoteSocket = sock
, remoteSendLock = sendLock
, _remoteOutgoing = 0
, _remoteIncoming = Set.empty
, _remoteMaxIncoming = 0
, _remoteNextConnOutId = firstNonReservedLightweightConnectionId
}
resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointValid vst)
return True
Right (sock, ConnectionRequestInvalid) -> do
-- The remote endpoint does not exist
let err = invalidAddress "setupRemoteEndPoint: Invalid endpoint"
resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointInvalid err)
tryCloseSocket sock
return False
Right (sock, ConnectionRequestCrossed) -> do
-- The state stays init here; only the \"crossed\" MVar is filled, which
-- 'findRemoteEndPoint' may be waiting on
withMVar (remoteState theirEndPoint) $ \st -> case st of
RemoteEndPointInit _ crossed _ ->
putMVar crossed ()
RemoteEndPointFailed ex ->
throwIO ex
_ ->
relyViolation (ourEndPoint, theirEndPoint) "setupRemoteEndPoint: Crossed"
tryCloseSocket sock
return False
Left err -> do
-- Connecting failed: record the error on the remote endpoint
resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointInvalid err)
return False
when didAccept $ handleIncomingMessages (ourEndPoint, theirEndPoint)
where
ourAddress = localAddress ourEndPoint
theirAddress = remoteAddress theirEndPoint
invalidAddress = TransportError ConnectNotFound
-- | Start closing the socket to a remote endpoint if nothing uses it.
--
-- When the endpoint is valid and has neither outgoing nor incoming
-- lightweight connections, a 'CloseSocket' message is scheduled and the
-- state becomes 'RemoteEndPointClosing'. Otherwise nothing changes.
closeIfUnused :: EndPointPair -> IO ()
closeIfUnused (ourEndPoint, theirEndPoint) = do
    mAct <- modifyMVar (remoteState theirEndPoint) decide
    forM_ mAct (runScheduledAction (ourEndPoint, theirEndPoint))
  where
    decide (RemoteEndPointValid vst)
      | isUnused vst = do
          resolved <- newEmptyMVar
          act <- schedule theirEndPoint $
            sendOn vst [ encodeInt32 CloseSocket
                       , encodeInt32 (vst ^. remoteMaxIncoming)
                       ]
          return (RemoteEndPointClosing resolved vst, Just act)
      | otherwise =
          return (RemoteEndPointValid vst, Nothing)
    decide st =
      return (st, Nothing)

    isUnused vst =
      vst ^. remoteOutgoing == 0 && Set.null (vst ^. remoteIncoming)
-- | Forget the remote endpoint stored for an address if it is in an invalid
-- or failed state, so that a later lookup starts from scratch. Throws
-- 'ConnectFailed' when the local endpoint is closed.
resetIfBroken :: LocalEndPoint -> EndPointAddress -> IO ()
resetIfBroken ourEndPoint theirAddress = do
    mTheirEndPoint <- withMVar (localState ourEndPoint) lookupRemote
    case mTheirEndPoint of
      Nothing ->
        return ()
      Just theirEndPoint ->
        withMVar (remoteState theirEndPoint) (dropIfBroken theirEndPoint)
  where
    lookupRemote (LocalEndPointValid vst) =
      return (vst ^. localConnectionTo theirAddress)
    lookupRemote LocalEndPointClosed =
      throwIO $ TransportError ConnectFailed "Endpoint closed"

    dropIfBroken theirEndPoint st = case st of
      RemoteEndPointInvalid _ -> removeRemoteEndPoint (ourEndPoint, theirEndPoint)
      RemoteEndPointFailed _  -> removeRemoteEndPoint (ourEndPoint, theirEndPoint)
      _                       -> return ()
-- | Open a connection from an endpoint to itself.
--
-- No socket is involved: a connection id is allocated, 'ConnectionOpened'
-- is written to the endpoint's own channel, and the returned @send@ and
-- @close@ write 'Received' and 'ConnectionClosed' events to that channel.
connectToSelf :: LocalEndPoint
              -> IO Connection
connectToSelf ourEndPoint = do
    connAlive <- newIORef True
    lconnId <- mapIOException connectFailed $ getLocalNextConnOutId ourEndPoint
    let connId = createConnectionId heavyweightSelfConnectionId lconnId
    writeChan ourChan
      (ConnectionOpened connId ReliableOrdered (localAddress ourEndPoint))
    return Connection
      { send  = selfSend connAlive connId
      , close = selfClose connAlive connId
      }
  where
    -- Deliver a message to ourselves, provided the endpoint is still valid
    -- and the connection has not been closed.
    selfSend :: IORef Bool
             -> ConnectionId
             -> [ByteString]
             -> IO (Either (TransportError SendErrorCode) ())
    selfSend connAlive connId msg =
      try $ withMVar ourState $ \st -> case st of
        LocalEndPointClosed ->
          throwIO $ TransportError SendFailed "Endpoint closed"
        LocalEndPointValid _ -> do
          alive <- readIORef connAlive
          unless alive . throwIO $
            TransportError SendClosed "Connection closed"
          writeChan ourChan (Received connId msg)

    -- Close the connection once; later calls, or calls on a closed endpoint,
    -- do nothing.
    selfClose :: IORef Bool -> ConnectionId -> IO ()
    selfClose connAlive connId =
      withMVar ourState $ \st -> case st of
        LocalEndPointClosed ->
          return ()
        LocalEndPointValid _ -> do
          alive <- readIORef connAlive
          when alive $ do
            writeChan ourChan (ConnectionClosed connId)
            writeIORef connAlive False

    ourChan  = localChannel ourEndPoint
    ourState = localState ourEndPoint
    connectFailed = TransportError ConnectFailed . show
-- | Move a remote endpoint out of the init state into the given new state.
--
-- The \"resolved\" 'MVar' of the init state is filled; if the new state is
-- 'RemoteEndPointClosed' the remote endpoint is also removed from the local
-- endpoint. If the endpoint has failed in the meantime that error is
-- rethrown; any other current state is reported through 'relyViolation'.
resolveInit :: EndPointPair -> RemoteState -> IO ()
resolveInit (ourEndPoint, theirEndPoint) newState =
    modifyMVar_ (remoteState theirEndPoint) transition
  where
    endPoints = (ourEndPoint, theirEndPoint)

    transition (RemoteEndPointInit resolved _ _) = do
      putMVar resolved ()
      case newState of
        RemoteEndPointClosed -> removeRemoteEndPoint endPoints
        _                    -> return ()
      return newState
    transition (RemoteEndPointFailed ex) =
      throwIO ex
    transition _ =
      relyViolation endPoints "resolveInit"
-- | Hand out the endpoint's next outgoing lightweight connection id and
-- advance the counter. Throws a user error if the endpoint is closed.
getLocalNextConnOutId :: LocalEndPoint -> IO LightweightConnectionId
getLocalNextConnOutId ourEndpoint =
    modifyMVar (localState ourEndpoint) allocate
  where
    allocate LocalEndPointClosed =
      throwIO $ userError "Local endpoint closed"
    allocate (LocalEndPointValid vst) =
      let connId = vst ^. localNextConnOutId
          vst'   = (localNextConnOutId ^= connId + 1) vst
      in return (LocalEndPointValid vst', connId)
-- | Create a new local endpoint and register it with the transport.
--
-- The endpoint's address is built from the transport's host and port plus
-- the next free endpoint id. Throws 'NewEndPointFailed' if the transport
-- has been closed.
createLocalEndPoint :: TCPTransport -> IO LocalEndPoint
createLocalEndPoint transport = do
    chan  <- newChan
    state <- newMVar (LocalEndPointValid freshState)
    modifyMVar (transportState transport) (register chan state)
  where
    -- State of an endpoint without any connections.
    freshState = ValidLocalEndPointState
      { _localNextConnOutId = firstNonReservedLightweightConnectionId
      , _localConnections   = Map.empty
      , _nextConnInId       = firstNonReservedHeavyweightConnectionId
      }

    register chan state st = case st of
      TransportClosed ->
        throwIO (TransportError NewEndPointFailed "Transport closed")
      TransportValid vst ->
        let ix   = vst ^. nextEndPointId
            addr = encodeEndPointAddress (transportHost transport)
                                         (transportPort transport)
                                         ix
            localEndPoint = LocalEndPoint
              { localAddress = addr
              , localChannel = chan
              , localState   = state
              }
            vst' = (localEndPointAt addr ^= Just localEndPoint)
                 . (nextEndPointId ^= ix + 1)
                 $ vst
        in return (TransportValid vst', localEndPoint)
-- | Remove a remote endpoint from its local endpoint's connection map.
--
-- The entry is only removed if the endpoint currently stored under that
-- address has the same 'remoteId'; a different record that has taken its
-- place is left alone. Nothing happens if the local endpoint is closed.
removeRemoteEndPoint :: EndPointPair -> IO ()
removeRemoteEndPoint (ourEndPoint, theirEndPoint) =
    modifyMVar_ (localState ourEndPoint) forget
  where
    theirAddress = remoteAddress theirEndPoint

    forget LocalEndPointClosed =
      return LocalEndPointClosed
    forget st@(LocalEndPointValid vst) =
      case vst ^. localConnectionTo theirAddress of
        Just current
          | remoteId current == remoteId theirEndPoint ->
              return . LocalEndPointValid $
                (localConnectionTo theirAddress ^= Nothing) vst
        _ ->
          return st
-- | Remove a local endpoint from the transport's endpoint map. Nothing
-- happens if the transport is already closed.
removeLocalEndPoint :: TCPTransport -> LocalEndPoint -> IO ()
removeLocalEndPoint transport ourEndPoint =
    modifyMVar_ (transportState transport) unregister
  where
    unregister TransportClosed =
      return TransportClosed
    unregister (TransportValid vst) =
      return . TransportValid $
        (localEndPointAt (localAddress ourEndPoint) ^= Nothing) vst
-- | Look up the remote endpoint for an address, creating a record in the
-- init state if there is none.
--
-- The 'Bool' in the result tells the caller whether it is responsible for
-- setting the connection up ('True') or whether an existing record is being
-- reused ('False'). For an existing record in a transitional state the call
-- blocks until that state is resolved and then retries. Throws if the local
-- endpoint is closed, or if the remote endpoint is invalid or failed.
findRemoteEndPoint
:: LocalEndPoint
-> EndPointAddress
-> RequestedBy
-> IO (RemoteEndPoint, Bool)
findRemoteEndPoint ourEndPoint theirAddress findOrigin = go
where
go = do
(theirEndPoint, isNew) <- modifyMVar ourState $ \st -> case st of
LocalEndPointValid vst -> case vst ^. localConnectionTo theirAddress of
Just theirEndPoint ->
return (st, (theirEndPoint, False))
Nothing -> do
-- No record yet: create one in the init state and register it
resolved <- newEmptyMVar
crossed <- newEmptyMVar
theirState <- newMVar (RemoteEndPointInit resolved crossed findOrigin)
scheduled <- newChan
let theirEndPoint = RemoteEndPoint
{ remoteAddress = theirAddress
, remoteState = theirState
, remoteId = vst ^. nextConnInId
, remoteScheduled = scheduled
}
return ( LocalEndPointValid
. (localConnectionTo theirAddress ^= Just theirEndPoint)
. (nextConnInId ^: (+ 1))
$ vst
, (theirEndPoint, True)
)
LocalEndPointClosed ->
throwIO $ userError "Local endpoint closed"
if isNew
then
return (theirEndPoint, True)
else do
let theirState = remoteState theirEndPoint
-- Take a snapshot of the state. For our own requests on a valid
-- endpoint the outgoing counter is bumped in the same step.
snapshot <- modifyMVar theirState $ \st -> case st of
RemoteEndPointValid vst ->
case findOrigin of
RequestedByUs -> do
let st' = RemoteEndPointValid
. (remoteOutgoing ^: (+ 1))
$ vst
return (st', st')
RequestedByThem ->
return (st, st)
_ ->
return (st, st)
case snapshot of
RemoteEndPointInvalid err ->
throwIO err
RemoteEndPointInit resolved crossed initOrigin ->
case (findOrigin, initOrigin) of
(RequestedByUs, RequestedByUs) ->
readMVar resolved >> go
(RequestedByUs, RequestedByThem) ->
readMVar resolved >> go
(RequestedByThem, RequestedByUs) ->
-- Both sides are connecting at the same time. The address ordering
-- decides: if ours is greater we wait for our own request to be
-- reported as crossed and then accept theirs; otherwise theirs is
-- treated as the existing one.
if ourAddress > theirAddress
then do
readMVar crossed
return (theirEndPoint, True)
else
return (theirEndPoint, False)
(RequestedByThem, RequestedByThem) ->
throwIO $ userError "Already connected"
RemoteEndPointValid _ ->
return (theirEndPoint, False)
RemoteEndPointClosing resolved _ ->
-- Wait until the closing state is resolved, then retry
readMVar resolved >> go
RemoteEndPointClosed ->
go
RemoteEndPointFailed err ->
throwIO err
ourState = localState ourEndPoint
ourAddress = localAddress ourEndPoint
-- | Send the given chunks on the remote endpoint's socket while holding its
-- send lock.
sendOn :: ValidRemoteEndPointState -> [ByteString] -> IO ()
sendOn vst bs =
    withMVar sendLock $ \() -> sendMany sock bs
  where
    sendLock = remoteSendLock vst
    sock     = remoteSocket vst
-- | Handle for a scheduled action: an 'MVar' that 'schedule' arranges to be
-- filled with the action's result, or with the exception it threw.
type Action a = MVar (Either SomeException a)
-- | Enqueue an action on the remote endpoint's queue of scheduled actions.
--
-- The action is not run here. When it is run later (see
-- 'runScheduledAction') its result, or the exception it raised, is stored in
-- the returned 'Action'.
schedule :: RemoteEndPoint -> IO a -> IO (Action a)
schedule theirEndPoint act = do
    outcome <- newEmptyMVar
    let job = (act >>= putMVar outcome . Right)
                `catch` (putMVar outcome . Left)
    writeChan (remoteScheduled theirEndPoint) job
    return outcome
-- | Run the action at the head of the remote endpoint's queue, then wait for
-- the outcome of the given 'Action'.
--
-- If the outcome is an exception it is rethrown. If that exception is an
-- 'IOException', the remote endpoint is first moved to the failed state;
-- when it was valid, its socket is also closed and a connection-lost event
-- is written to the local endpoint's channel.
runScheduledAction :: EndPointPair -> Action a -> IO a
runScheduledAction (ourEndPoint, theirEndPoint) mvar = do
-- Execute whichever action is next in the queue (it need not be ours)
join $ readChan (remoteScheduled theirEndPoint)
-- Blocks until our own action has been executed by some thread
ma <- readMVar mvar
case ma of
Right a -> return a
Left e -> do
-- Only taken when the exception is an IOException
forM_ (fromException e) $ \ioe ->
modifyMVar_ (remoteState theirEndPoint) $ \st ->
case st of
RemoteEndPointValid vst -> handleIOException ioe vst
_ -> return (RemoteEndPointFailed ioe)
throwIO e
where
-- Close the socket, report the lost connection and return the failed state
handleIOException :: IOException
-> ValidRemoteEndPointState
-> IO RemoteState
handleIOException ex vst = do
tryCloseSocket (remoteSocket vst)
let code = EventConnectionLost (remoteAddress theirEndPoint)
err = TransportError code (show ex)
writeChan (localChannel ourEndPoint) $ ErrorEvent err
return (RemoteEndPointFailed ex)
-- | Open a socket to a remote endpoint and perform the initial handshake.
--
-- Steps, in order:
--
--   1. decode the remote address into host, port and endpoint id;
--   2. resolve host and port with 'N.getAddrInfo' and create a stream socket
--      for the first result;
--   3. optionally set @ReuseAddr@, then connect (bounded by the optional
--      timeout);
--   4. send the remote endpoint id followed by our length-prefixed address,
--      and read back a 32-bit response.
--
-- On success the connected socket and the decoded response are returned.
-- Failures are reported as 'Left':
--
--   * 'ConnectFailed' when the address cannot be parsed, when setting the
--     socket option or the handshake I/O fails, or when the response is not a
--     valid 'ConnectionRequestResponse';
--   * 'ConnectNotFound' when resolution or connecting fails, or when
--     resolution yields no address at all;
--   * 'ConnectInsufficientResources' when the socket cannot be created;
--   * 'ConnectTimeout' when the connect step times out.
--
-- If anything fails after the socket has been created, the socket is closed.
socketToEndPoint :: EndPointAddress -- ^ Our address
                 -> EndPointAddress -- ^ Their address
                 -> Bool            -- ^ Set @ReuseAddr@ on the socket?
                 -> Maybe Int       -- ^ Timeout for the connect step
                                    --   (unit is whatever 'timeoutMaybe'
                                    --   expects; presumably microseconds)
                 -> IO (Either (TransportError ConnectErrorCode)
                               (N.Socket, ConnectionRequestResponse))
socketToEndPoint (EndPointAddress ourAddress) theirAddress reuseAddr timeout =
  try $ do
    (host, port, theirEndPointId) <- case decodeEndPointAddress theirAddress of
      Nothing  -> throwIO (failed . userError $ "Could not parse")
      Just dec -> return dec
    addrs <- mapIOException invalidAddress $
      N.getAddrInfo Nothing (Just host) (Just port)
    -- Match explicitly instead of binding @addr:_@: a failed pattern bind
    -- would raise a plain IO error that the surrounding 'try' does not catch.
    addr <- case addrs of
      firstAddr : _ -> return firstAddr
      []            -> throwIO (invalidAddress . userError $ "No address found")
    bracketOnError (createSocket addr) tryCloseSocket $ \sock -> do
      when reuseAddr $
        mapIOException failed $ N.setSocketOption sock N.ReuseAddr 1
      mapIOException invalidAddress $
        timeoutMaybe timeout timeoutError $
          N.connect sock (N.addrAddress addr)
      response <- mapIOException failed $ do
        sendMany sock (encodeInt32 theirEndPointId : prependLength [ourAddress])
        recvInt32 sock
      case tryToEnum response of
        Nothing -> throwIO (failed . userError $ "Unexpected response")
        Just r  -> return (sock, r)
  where
    createSocket :: N.AddrInfo -> IO N.Socket
    createSocket addr = mapIOException insufficientResources $
      N.socket (N.addrFamily addr) N.Stream N.defaultProtocol

    -- Wrappers turning an I/O exception into the matching transport error.
    invalidAddress        = TransportError ConnectNotFound . show
    insufficientResources = TransportError ConnectInsufficientResources . show
    failed                = TransportError ConnectFailed . show
    timeoutError          = TransportError ConnectTimeout "Timed out"
-- | Build an endpoint address of the form @host:port:endPointId@.
encodeEndPointAddress :: N.HostName
                      -> N.ServiceName
                      -> EndPointId
                      -> EndPointAddress
encodeEndPointAddress host port ix =
    EndPointAddress (BSC.pack rendered)
  where
    rendered = concat [host, ":", port, ":", show ix]
-- | Split an endpoint address of the form @host:port:endPointId@ into its
-- three components.
--
-- The address is split at its last two colons, so any earlier colons stay in
-- the host part. Returns 'Nothing' if there are fewer than two colons or if
-- the final component does not parse completely as an endpoint id.
decodeEndPointAddress :: EndPointAddress
                      -> Maybe (N.HostName, N.ServiceName, EndPointId)
decodeEndPointAddress (EndPointAddress bs)
  | [host, port, endPointIdStr] <- components
  , [(endPointId, "")] <- reads endPointIdStr
  = Just (host, port, endPointId)
  | otherwise
  = Nothing
  where
    components = splitMaxFromEnd (== ':') 2 (BSC.unpack bs)
-- | Combine a heavyweight and a lightweight connection id into a single
-- 'ConnectionId': the heavyweight id is shifted left by 32 bits and the
-- lightweight id is OR-ed into the result.
createConnectionId :: HeavyweightConnectionId
                   -> LightweightConnectionId
                   -> ConnectionId
createConnectionId hcid lcid = upperPart .|. lowerPart
  where
    upperPart, lowerPart :: ConnectionId
    upperPart = fromIntegral hcid `shiftL` 32
    lowerPart = fromIntegral lcid
-- | @splitMaxFromEnd p n xs@ splits @xs@ at elements satisfying @p@, working
-- from the end of the list and performing at most @n@ splits. The separator
-- elements at which a split happens are dropped; once @n@ splits have been
-- made, everything that remains (separators included) forms the first
-- segment of the result.
--
-- > splitMaxFromEnd (== ':') 2 "ab:cd:ef:gh" == ["ab:cd", "ef", "gh"]
-- > splitMaxFromEnd (== ':') 2 "a:b"         == ["a", "b"]
splitMaxFromEnd :: (a -> Bool) -> Int -> [a] -> [[a]]
splitMaxFromEnd p = \n -> go [[]] n . reverse
  where
    -- The input is consumed in reverse. The head of the accumulator is the
    -- segment currently being built; consing onto it restores forward order.
    go accs         _ []     = accs
    -- No splits left and the current segment is still empty: the rest of the
    -- (reversed) input becomes one segment, replacing the empty one.
    go ([]  : accs) 0 xs     = reverse xs : accs
    go (acc : accs) n (x:xs) =
      if p x then go ([] : acc : accs) (n - 1) xs  -- one split used up
             else go ((x : acc) : accs) n xs
    -- Unreachable: the accumulator always has at least one segment.
    go _ _ _ = error "Bug in splitMaxFromEnd"
-- | Find the socket currently associated with the connection between one of
-- our local endpoints and a remote endpoint.
--
-- Throws a user error if the transport or the local endpoint is closed, if
-- either endpoint is unknown, or if the remote endpoint is still initialising
-- or already closed. If the remote endpoint is in the invalid or failed
-- state, the error stored in that state is thrown instead.
internalSocketBetween :: TCPTransport
                      -> EndPointAddress
                      -> EndPointAddress
                      -> IO N.Socket
internalSocketBetween transport ourAddress theirAddress = do
    ourEndPoint   <- withMVar (transportState transport) findLocal
    theirEndPoint <- withMVar (localState ourEndPoint) findRemote
    withMVar (remoteState theirEndPoint) currentSocket
  where
    -- Resolve our own endpoint from the transport state.
    findLocal TransportClosed      = unavailable "Transport closed"
    findLocal (TransportValid vst) =
      maybe (unavailable "Local endpoint not found") return $
        vst ^. localEndPointAt ourAddress

    -- Resolve the remote endpoint from our endpoint's state.
    findRemote LocalEndPointClosed      = unavailable "Local endpoint closed"
    findRemote (LocalEndPointValid vst) =
      maybe (unavailable "Remote endpoint not found") return $
        vst ^. localConnectionTo theirAddress

    -- Only the valid and closing states carry a socket.
    currentSocket (RemoteEndPointValid vst)     = return (remoteSocket vst)
    currentSocket (RemoteEndPointClosing _ vst) = return (remoteSocket vst)
    currentSocket (RemoteEndPointInit _ _ _)    =
      unavailable "Remote endpoint not yet initialized"
    currentSocket RemoteEndPointClosed          =
      unavailable "Remote endpoint closed"
    currentSocket (RemoteEndPointInvalid err)   = throwIO err
    currentSocket (RemoteEndPointFailed err)    = throwIO err

    unavailable :: String -> IO b
    unavailable = throwIO . userError
-- | First lightweight connection id outside the reserved range.
--
-- NOTE(review): ids below 1024 are presumably reserved for control purposes;
-- confirm against the places where this constant is used.
firstNonReservedLightweightConnectionId :: LightweightConnectionId
firstNonReservedLightweightConnectionId = 1024
-- | Heavyweight connection id with the fixed value 0.
--
-- NOTE(review): judging by the name, this is the id used for connections an
-- endpoint makes to itself; confirm at the use sites.
heavyweightSelfConnectionId :: HeavyweightConnectionId
heavyweightSelfConnectionId = 0
-- | First heavyweight connection id outside the reserved range.
--
-- NOTE(review): presumably only id 0 is reserved, which would match
-- 'heavyweightSelfConnectionId' being 0; confirm at the use sites.
firstNonReservedHeavyweightConnectionId :: HeavyweightConnectionId
firstNonReservedHeavyweightConnectionId = 1
-- | Accessor for the '_localEndPoints' field of a 'ValidTransportState'.
localEndPoints :: Accessor ValidTransportState (Map EndPointAddress LocalEndPoint)
localEndPoints = accessor _localEndPoints replace
  where
    replace endPoints state = state { _localEndPoints = endPoints }
-- | Accessor for the '_nextEndPointId' field of a 'ValidTransportState'.
nextEndPointId :: Accessor ValidTransportState EndPointId
nextEndPointId = accessor _nextEndPointId replace
  where
    replace endPointId state = state { _nextEndPointId = endPointId }
-- | Accessor for the '_localNextConnOutId' field of a
-- 'ValidLocalEndPointState'.
localNextConnOutId :: Accessor ValidLocalEndPointState LightweightConnectionId
localNextConnOutId = accessor _localNextConnOutId replace
  where
    replace connId state = state { _localNextConnOutId = connId }
-- | Accessor for the '_localConnections' field of a
-- 'ValidLocalEndPointState'.
localConnections :: Accessor ValidLocalEndPointState (Map EndPointAddress RemoteEndPoint)
localConnections = accessor _localConnections replace
  where
    replace connections state = state { _localConnections = connections }
-- | Accessor for the '_nextConnInId' field of a 'ValidLocalEndPointState'.
nextConnInId :: Accessor ValidLocalEndPointState HeavyweightConnectionId
nextConnInId = accessor _nextConnInId replace
  where
    replace connId state = state { _nextConnInId = connId }
-- | Accessor for the '_remoteOutgoing' field of a
-- 'ValidRemoteEndPointState'.
remoteOutgoing :: Accessor ValidRemoteEndPointState Int
remoteOutgoing = accessor _remoteOutgoing replace
  where
    replace outgoing state = state { _remoteOutgoing = outgoing }
-- | Accessor for the '_remoteIncoming' field of a
-- 'ValidRemoteEndPointState'.
remoteIncoming :: Accessor ValidRemoteEndPointState (Set LightweightConnectionId)
remoteIncoming = accessor _remoteIncoming replace
  where
    replace incoming state = state { _remoteIncoming = incoming }
-- | Accessor for the '_remoteMaxIncoming' field of a
-- 'ValidRemoteEndPointState'.
remoteMaxIncoming :: Accessor ValidRemoteEndPointState LightweightConnectionId
remoteMaxIncoming = accessor _remoteMaxIncoming replace
  where
    replace maxIncoming state = state { _remoteMaxIncoming = maxIncoming }
-- | Accessor for the '_remoteNextConnOutId' field of a
-- 'ValidRemoteEndPointState'.
remoteNextConnOutId :: Accessor ValidRemoteEndPointState LightweightConnectionId
remoteNextConnOutId = accessor _remoteNextConnOutId replace
  where
    replace connId state = state { _remoteNextConnOutId = connId }
-- | Accessor for the entry of the local-endpoint map stored under the given
-- address, composed from 'localEndPoints' and a map-entry accessor.
localEndPointAt :: EndPointAddress -> Accessor ValidTransportState (Maybe LocalEndPoint)
localEndPointAt addr = localEndPoints >>> entryAtAddr
  where
    entryAtAddr = DAC.mapMaybe addr
-- | Accessor for the entry of the connection map stored under the given
-- remote address, composed from 'localConnections' and a map-entry accessor.
localConnectionTo :: EndPointAddress -> Accessor ValidLocalEndPointState (Maybe RemoteEndPoint)
localConnectionTo addr = localConnections >>> entryAtAddr
  where
    entryAtAddr = DAC.mapMaybe addr
-- | Report a violated assumption: log the message (suffixed with
-- @" RELY violation"@) via 'elog' and then fail with that same message.
relyViolation :: EndPointPair -> String -> IO a
relyViolation endPoints@(_, _) str = do
    elog endPoints message
    fail message
  where
    message = str ++ " RELY violation"
-- | Print a log line to standard output, prefixed with the local address,
-- the remote address, the remote endpoint's id and the current thread id.
elog :: EndPointPair -> String -> IO ()
elog (ourEndPoint, theirEndPoint) msg = do
    tid <- myThreadId
    putStrLn . concat $
      [ show (localAddress ourEndPoint)
      , "/", show (remoteAddress theirEndPoint)
      , "(", show (remoteId theirEndPoint), ")"
      , "/", show tid
      , ": ", msg
      ]