module Hans.Layer.Tcp.Types where
import Hans.Address.IP4
import Hans.Layer.Tcp.Window
import Hans.Message.Tcp
import Hans.Ports
import Control.Exception (Exception,SomeException,toException)
import Control.Monad ( guard )
import Data.Time.Clock.POSIX (POSIXTime)
import Data.Int (Int64)
import Data.Maybe (fromMaybe)
import Data.Word (Word16,Word32)
import qualified Data.Map.Strict as Map
import qualified Data.Sequence as Seq
data Host = Host
{ hostConnections :: Connections
, hostTimeWaits :: TimeWaitConnections
, hostInitialSeqNum :: !TcpSeqNum
, hostPorts :: !(PortManager TcpPort)
, hostLastUpdate :: POSIXTime
}
emptyHost :: POSIXTime -> Host
emptyHost start = Host
{ hostConnections = Map.empty
, hostTimeWaits = Map.empty
, hostInitialSeqNum = 0
, hostPorts = emptyPortManager [32768 .. 61000]
, hostLastUpdate = start
}
takePort :: Host -> Maybe (TcpPort,Host)
takePort host = do
(p,ps) <- nextPort (hostPorts host)
return (p, host { hostPorts = ps })
releasePort :: TcpPort -> Host -> Host
releasePort p host = fromMaybe host $ do
ps <- unreserve p (hostPorts host)
return host { hostPorts = ps }
type Connections = Map.Map SocketId TcpSocket
removeClosed :: Connections -> Connections
removeClosed =
Map.filter (\tcp -> tcpState tcp /= Closed || not (tcpUserClosed tcp))
data SocketId = SocketId
{ sidLocalPort :: !TcpPort
, sidRemotePort :: !TcpPort
, sidRemoteHost :: !IP4
} deriving (Eq,Show,Ord)
emptySocketId :: SocketId
emptySocketId = SocketId
{ sidLocalPort = TcpPort 0
, sidRemotePort = TcpPort 0
, sidRemoteHost = IP4 0 0 0 0
}
listenSocketId :: TcpPort -> SocketId
listenSocketId port = emptySocketId { sidLocalPort = port }
incomingSocketId :: IP4 -> TcpHeader -> SocketId
incomingSocketId remote hdr = SocketId
{ sidLocalPort = tcpDestPort hdr
, sidRemotePort = tcpSourcePort hdr
, sidRemoteHost = remote
}
data SocketResult a
= SocketResult a
| SocketError SomeException
deriving (Show)
socketError :: Exception e => e -> SocketResult a
socketError = SocketError . toException
type TimeWaitConnections = Map.Map SocketId TimeWaitSock
data TimeWaitSock = TimeWaitSock { tw2MSL :: !SlowTicks
, twInit2MSL :: !SlowTicks
, twSeqNum :: !TcpSeqNum
, twTimestamp :: Maybe Timestamp
} deriving (Show)
twReset2MSL :: TimeWaitSock -> TimeWaitSock
twReset2MSL tw = tw { tw2MSL = twInit2MSL tw }
mkTimeWait :: TcpSocket -> TimeWaitSock
mkTimeWait TcpSocket { .. } =
TimeWaitSock { tw2MSL = timeout
, twInit2MSL = timeout
, twSeqNum = tcpSndNxt
, twTimestamp = tcpTimestamp
}
where
timeout | tt2MSL tcpTimers <= 0 = 2 * mslTimeout
| otherwise = tt2MSL tcpTimers
addTimeWait :: TcpSocket -> TimeWaitConnections -> TimeWaitConnections
addTimeWait tcp socks = Map.insert (tcpSocketId tcp) (mkTimeWait tcp) socks
stepTimeWaitConnections :: TimeWaitConnections -> TimeWaitConnections
stepTimeWaitConnections = Map.mapMaybe $ \ tw ->
do guard (tw2MSL tw > 0)
return tw { tw2MSL = tw2MSL tw 1 }
type SlowTicks = Int
mslTimeout :: SlowTicks
mslTimeout = 2 * 60
data TcpTimers = TcpTimers
{ ttDelayedAck :: !Bool
, tt2MSL :: !SlowTicks
, ttRTO :: !SlowTicks
, ttSRTT :: !POSIXTime
, ttRTTVar :: !POSIXTime
, ttMaxIdle :: !SlowTicks
, ttIdle :: !SlowTicks
}
emptyTcpTimers :: TcpTimers
emptyTcpTimers = TcpTimers
{ ttDelayedAck = False
, tt2MSL = 0
, ttRTO = 2
, ttSRTT = 0
, ttRTTVar = 0
, ttMaxIdle = 10 * 60 * 2
, ttIdle = 0
}
data Timestamp = Timestamp
{ tsTimestamp :: !Word32
, tsLastTimestamp :: !Word32
, tsGranularity :: !POSIXTime
, tsLastUpdate :: !POSIXTime
} deriving (Show)
emptyTimestamp :: POSIXTime -> Timestamp
emptyTimestamp start = Timestamp
{ tsTimestamp = 0
, tsLastTimestamp = 0
, tsGranularity = 200
, tsLastUpdate = start
}
stepTimestamp :: POSIXTime -> Timestamp -> Timestamp
stepTimestamp now ts
| diff == 0 = ts
| otherwise = ts
{ tsLastUpdate = now
, tsTimestamp = tsTimestamp ts + diff
}
where
diff = ceiling (tsGranularity ts * (now tsLastUpdate ts))
mkTimestamp :: Timestamp -> TcpOption
mkTimestamp ts = OptTimestamp (tsTimestamp ts) (tsLastTimestamp ts)
type Acceptor = SocketId -> IO ()
type Notify = Bool -> IO ()
type Close = IO ()
data TcpSocket = TcpSocket
{ tcpParent :: Maybe SocketId
, tcpSocketId :: !SocketId
, tcpState :: !ConnState
, tcpAcceptors :: Seq.Seq Acceptor
, tcpNotify :: Maybe Notify
, tcpIss :: !TcpSeqNum
, tcpSndNxt :: !TcpSeqNum
, tcpSndUna :: !TcpSeqNum
, tcpUserClosed :: Bool
, tcpOut :: RemoteWindow
, tcpOutBuffer :: Buffer Outgoing
, tcpOutMSS :: !Int64
, tcpIn :: LocalWindow
, tcpInBuffer :: Buffer Incoming
, tcpInMSS :: !Int64
, tcpTimers :: !TcpTimers
, tcpTimestamp :: Maybe Timestamp
, tcpSack :: Bool
, tcpWindowScale :: Bool
}
emptyTcpSocket :: Word16 -> Int -> TcpSocket
emptyTcpSocket sendWindow sendScale = TcpSocket
{ tcpParent = Nothing
, tcpSocketId = emptySocketId
, tcpState = Closed
, tcpAcceptors = Seq.empty
, tcpNotify = Nothing
, tcpIss = 0
, tcpSndNxt = 0
, tcpSndUna = 0
, tcpUserClosed = False
, tcpOut = emptyRemoteWindow sendWindow sendScale
, tcpOutBuffer = emptyBuffer 16384
, tcpOutMSS = defaultMSS
, tcpIn = emptyLocalWindow 0 14600 0
, tcpInBuffer = emptyBuffer 16384
, tcpInMSS = defaultMSS
, tcpTimers = emptyTcpTimers
, tcpTimestamp = Nothing
, tcpSack = True
, tcpWindowScale = True
}
defaultMSS :: Int64
defaultMSS = 1460
nothingOutstanding :: TcpSocket -> Bool
nothingOutstanding TcpSocket { .. } = tcpSndUna == tcpSndNxt
tcpRcvNxt :: TcpSocket -> TcpSeqNum
tcpRcvNxt = lwRcvNxt . tcpIn
inRcvWnd :: TcpSeqNum -> TcpSocket -> Bool
inRcvWnd (TcpSeqNum sn) TcpSocket { .. } =
rcvNxt <= sn && sn < rcvNxt + fromIntegral (lwRcvWind tcpIn)
where
TcpSeqNum rcvNxt = lwRcvNxt tcpIn
nextSegSize :: TcpSocket -> Int64
nextSegSize tcp =
min (fromIntegral (rwAvailable (tcpOut tcp))) (tcpOutMSS tcp)
isAccepting :: TcpSocket -> Bool
isAccepting = not . Seq.null . tcpAcceptors
needsDelayedAck :: TcpSocket -> Bool
needsDelayedAck = ttDelayedAck . tcpTimers
mkMSS :: TcpSocket -> TcpOption
mkMSS tcp = OptMaxSegmentSize (fromIntegral (tcpInMSS tcp))
data ConnState
= Closed
| Listen
| SynSent
| SynReceived
| Established
| CloseWait
| FinWait1
| FinWait2
| Closing
| LastAck
| TimeWait
deriving (Show,Eq,Ord)