{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RecordWildCards #-}

-- | TCP control blocks: per-connection state ('Tcb'), listening sockets
-- ('ListenTcb'), connections parked in TimeWait ('TimeWaitTcb'), and the
-- timer bookkeeping ('TcpTimers') that goes with them.
--
-- The pragmas above are required by this module's own syntax: record
-- wildcards (@Tcb { .. }@), instances on @IORef Send.Window@ /
-- @IORef Recv.Window@, and @BaseM m IO@ contexts.
module Hans.Tcp.Tcb (
    -- * Timers
    SlowTicks,
    TcpTimers(..),
    emptyTcpTimers,
    resetRetransmit,
    retryRetransmit,
    stopRetransmit,
    reset2MSL,
    updateTimers,
    calibrateRTO,

    -- * Connection state
    State(..),
    GetState(..),
    whenState,
    setState,

    -- * Send-side queries
    CanSend(..),
    getSndNxt,
    getSndWnd,

    -- * Receive-side queries
    CanReceive(..),
    getRcvNxt,
    getRcvWnd,
    getRcvRight,

    -- * Listening sockets
    ListenTcb(..),
    newListenTcb,
    createChild,
    reserveSlot, releaseSlot,
    acceptTcb,

    -- * Active connections
    Tcb(..),
    newTcb,
    signalDelayedAck,
    setRcvNxt,
    finalizeTcb,
    getSndUna,
    resetIdleTimer,

    -- ** Per-connection configuration
    TcbConfig(..),
    usingTimestamps,
    disableTimestamp,

    -- ** Receive buffer
    queueBytes,
    haveBytesAvail,
    receiveBytes, tryReceiveBytes,

    -- * TimeWait connections
    TimeWaitTcb(..),
    mkTimeWaitTcb,
  ) where
import Hans.Addr (Addr)
import Hans.Buffer.Signal
import qualified Hans.Buffer.Stream as Stream
import Hans.Config (HasConfig(..),Config(..))
import Hans.Lens
import Hans.Network.Types (RouteInfo)
import Hans.Tcp.Packet
import qualified Hans.Tcp.RecvWindow as Recv
import qualified Hans.Tcp.SendWindow as Send
import Control.Monad (when)
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.IORef
(IORef,newIORef,atomicModifyIORef',readIORef
,atomicWriteIORef)
import Data.Int (Int64)
import qualified Data.Sequence as Seq
import Data.Time.Clock (NominalDiffTime,getCurrentTime)
import Data.Word (Word16,Word32)
import MonadLib (BaseM(..))
import System.CPUTime (getCPUTime)
-- | A count of slow-timer ticks; the unit of the integer countdown fields in
-- 'TcpTimers'.
-- NOTE(review): presumably two ticks per second (the default 'ttMaxIdle' is
-- written as @10 * 60 * 2@) -- confirm against the code that drives
-- 'updateTimers'.
type SlowTicks = Int
-- | Timer bookkeeping for one connection.  Every field is strict.
data TcpTimers = TcpTimers
  { ttDelayedAck      :: !Bool
    -- ^ Delayed-ACK flag; 'False' in 'emptyTcpTimers'.
  , tt2MSL            :: !SlowTicks
    -- ^ Remaining 2MSL countdown; loaded by 'reset2MSL' and decremented
    -- (never below zero) by 'updateTimers'.
  , ttRetransmitValid :: !Bool
    -- ^ Whether the retransmit countdown is armed.
  , ttRetransmit      :: !SlowTicks
    -- ^ Remaining retransmit countdown; only meaningful while
    -- 'ttRetransmitValid' is set.
  , ttRetries         :: !Int
    -- ^ Number of retransmit retries since the timer was last reset.
  , ttRTO             :: !SlowTicks
    -- ^ Current retransmission timeout, recomputed by 'calibrateRTO'.
  , ttSRTT            :: !NominalDiffTime
    -- ^ Smoothed round-trip time maintained by 'calibrateRTO'.
  , ttRTTVar          :: !NominalDiffTime
    -- ^ Round-trip time variance maintained by 'calibrateRTO'.
  , ttMaxIdle         :: !SlowTicks
    -- ^ Idle limit.  NOTE(review): presumably compared against 'ttIdle' by
    -- the timer driver -- that comparison is not in this module.
  , ttIdle            :: !SlowTicks
    -- ^ Ticks elapsed since 'resetIdleTimer' was last applied.
  }
-- | Initial timer state: nothing armed, no round-trip samples yet, an RTO of
-- two ticks, and an idle limit of @10 * 60 * 2@ ticks.
emptyTcpTimers :: TcpTimers
emptyTcpTimers =
  TcpTimers
    { -- countdowns that start out inactive
      ttDelayedAck      = False
    , tt2MSL            = 0
    , ttIdle            = 0
      -- retransmission state
    , ttRetransmitValid = False
    , ttRetransmit      = 0
    , ttRetries         = 0
    , ttRTO             = 2
      -- round-trip estimators; a zero SRTT marks "no sample seen"
    , ttSRTT            = 0
    , ttRTTVar          = 0
    , ttMaxIdle         = maxIdleTicks
    }
  where
    maxIdleTicks = 10 * 60 * 2
-- | Arm the retransmit timer from scratch: the countdown is loaded with the
-- current 'ttRTO' and the retry counter is cleared.  Shaped for use with
-- @atomicModifyIORef'@.
resetRetransmit :: TcpTimers -> (TcpTimers, ())
resetRetransmit tt@TcpTimers {} = (armed, ())
  where
    armed = tt { ttRetransmitValid = True
               , ttRetransmit      = ttRTO tt
               , ttRetries         = 0
               }
-- | Re-arm the retransmit timer after a timeout.  The retry counter is
-- bumped and the countdown is set to @ttRTO * 2 ^ retries@, so each
-- successive retry doubles the wait.
retryRetransmit :: TcpTimers -> (TcpTimers, ())
retryRetransmit tt@TcpTimers {} = (backedOff, ())
  where
    attempt   = ttRetries tt + 1
    backedOff = tt { ttRetransmitValid = True
                   , ttRetransmit      = ttRTO tt * 2 ^ attempt
                   , ttRetries         = attempt
                   }
-- | Disarm the retransmit timer and clear the retry counter.  The countdown
-- value itself is left untouched.
stopRetransmit :: TcpTimers -> (TcpTimers, ())
stopRetransmit tt@TcpTimers {} = (disarmed, ())
  where
    disarmed = tt { ttRetransmitValid = False
                  , ttRetries         = 0
                  }
-- | Load the 2MSL countdown with four times the configured 'cfgTcpMSL'.
-- NOTE(review): the factor of four presumably converts an MSL in seconds to
-- 2*MSL in half-second ticks -- confirm the unit of 'cfgTcpMSL'.
reset2MSL :: Config -> TcpTimers -> (TcpTimers, ())
reset2MSL cfg@Config {} tt = (rearmed, ())
  where
    rearmed = tt { tt2MSL = 4 * cfgTcpMSL cfg }
-- | Advance every timer by one slow tick.
--
-- The first component is the new timer state; the second is the state as it
-- was /before/ the tick, so a caller using @atomicModifyIORef'@ gets the
-- pre-tick values back.
--
-- * The retransmit countdown is decremented only while 'ttRetransmitValid'
--   is set; otherwise it is forced to zero.
-- * The 2MSL countdown is decremented but never drops below zero.
-- * The idle counter is incremented.
updateTimers :: TcpTimers -> (TcpTimers, TcpTimers)
updateTimers tt = (tt', tt)
  where
    tt' = tt { ttRetransmit = if ttRetransmitValid tt
                                then ttRetransmit tt - 1
                                else 0
             , tt2MSL       = max 0 (tt2MSL tt - 1)
             , ttIdle       = ttIdle tt + 1
             }
-- | Zero the idle counter, leaving every other timer field as it was.
resetIdleTimer :: TcpTimers -> (TcpTimers, ())
resetIdleTimer t = (t { ttIdle = 0 }, ())
-- | Fold a new round-trip-time sample @r@ into the estimators and recompute
-- 'ttRTO', following the update rules of RFC 6298 section 2:
--
-- * first sample (no positive 'ttSRTT' yet): @SRTT = r@, @RTTVAR = r / 2@;
-- * later samples: @RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - r|@, then
--   @SRTT = (1 - alpha) * SRTT + alpha * r@, with @alpha = 1/8@ and
--   @beta = 1/4@.  Both are computed from the /previous/ SRTT and RTTVAR.
--
-- The resulting RTO is rounded up and capped at 128.
calibrateRTO :: NominalDiffTime -> TcpTimers -> (TcpTimers, ())
calibrateRTO r tt
  | ttSRTT tt > 0 = (rolling, ())
  | otherwise     = (initial, ())
  where
    -- No sample has been recorded yet: seed the estimators from r.
    initial = updateRTO (tt { ttSRTT   = r
                            , ttRTTVar = r / 2
                            })

    alpha = 0.125
    beta  = 0.25

    -- The variance tracks the absolute deviation of the sample from the
    -- smoothed RTT, hence the difference (not the product) of the two.
    rttvar = (1 - beta)  * ttRTTVar tt + beta  * abs (ttSRTT tt - r)
    srtt   = (1 - alpha) * ttSRTT   tt + alpha * r

    rolling = updateRTO (tt { ttRTTVar = rttvar
                            , ttSRTT   = srtt
                            })

    -- NOTE(review): RFC 6298 uses K = 4 for the variance term and this code
    -- uses 2; the sum is also rounded directly into a tick count.  Both are
    -- kept as found -- confirm they are intentional.
    updateRTO tt' = tt'
      { ttRTO = min 128 (ceiling (ttSRTT tt' + max 0.5 (2 * ttRTTVar tt')))
      }
-- | Run the action only if the control block is currently in the given
-- state.  The state is read once, before the action runs.
whenState :: (BaseM m IO, GetState tcb) => tcb -> State -> m () -> m ()
whenState tcb expected action =
  inBase (getState tcb) >>= \current ->
    when (expected == current) action
-- | Atomically replace the connection state, then run the side effect tied
-- to the /new/ state:
--
-- * 'Established' invokes 'tcbEstablished' with the tcb and the old state;
-- * 'Closed' invokes 'tcbClosed' with the tcb and the old state;
-- * 'CloseWait' closes the receive buffer;
-- * every other state has no extra effect.
setState :: Tcb -> State -> IO ()
setState tcb new =
  atomicModifyIORef' (tcbState tcb) swap >>= notify
  where
    -- install the new state and hand back the one it replaced
    swap prev = (new, prev)

    notify prev =
      case new of
        Established -> tcbEstablished tcb tcb prev
        Closed      -> tcbClosed tcb tcb prev
        CloseWait   -> Stream.closeBuffer (tcbRecvBuffer tcb)
        _           -> return ()
-- | Control blocks whose connection state can be queried.
class GetState tcb where
  getState :: tcb -> IO State

-- | A listening socket is always in 'Listen'.
instance GetState ListenTcb where
  getState = const (return Listen)

-- | An active connection reports whatever is stored in 'tcbState'.
instance GetState Tcb where
  getState = readIORef . tcbState

-- | A time-wait control block is always in 'TimeWait'.
instance GetState TimeWaitTcb where
  getState = const (return TimeWait)
-- | Connection states.
-- NOTE(review): the constructor names appear to mirror the RFC 793 state
-- machine -- confirm.
data State
  = Listen
  | SynSent
  | SynReceived
  | Established
  | FinWait1
  | FinWait2
  | CloseWait
  | Closing
  | LastAck
  | TimeWait
  | Closed
  deriving (Eq, Show)
-- | State kept for a listening socket.
data ListenTcb = ListenTcb
  { lSrc          :: !Addr
    -- ^ Address given to 'newListenTcb'.
  , lPort         :: !TcpPort
    -- ^ Port given to 'newListenTcb'.
  , lAccept       :: !(IORef AcceptQueue)
    -- ^ Free-slot count plus the queue of connections awaiting 'acceptTcb'.
  , lAcceptSignal :: !Signal
    -- ^ Signalled once for every connection pushed onto the accept queue.
  , lTSClock      :: !(IORef Send.TSClock)
    -- ^ Timestamp clock; advanced whenever 'createChild' makes a child.
  }
-- | Build a listening control block for the given address and port.
--
-- The third argument is the initial number of free accept-queue slots.  The
-- timestamp clock is seeded from the current CPU time and wall-clock time.
newListenTcb :: Addr -> TcpPort -> Int -> IO ListenTcb
newListenTcb src port freeSlots =
  do acceptRef <- newIORef AcceptQueue { aqFree = freeSlots
                                       , aqTcbs = Seq.empty
                                       }
     sig       <- newSignal
     now       <- getCurrentTime
     cpuTime   <- getCPUTime
     clockRef  <- newIORef (Send.initialTSClock (fromInteger cpuTime) now)
     return ListenTcb { lSrc          = src
                      , lPort         = port
                      , lAccept       = acceptRef
                      , lAcceptSignal = sig
                      , lTSClock      = clockRef
                      }
-- | Create the control block for a connection spawned by a listening socket
-- in response to the given incoming header.
--
-- The child starts in 'SynReceived'.  Its local port is the header's
-- destination port and its remote port the header's source port.  IRS is the
-- header's sequence number and RCV.NXT is one past it; ISS and SND.NXT are
-- both the supplied initial send sequence number.
--
-- The two callbacks are wrapped before being installed on the child:
--
-- * on establishment the child is first queued on the parent's accept queue,
--   then the caller's callback runs;
-- * on close, if the state passed to the callback is 'SynReceived', the
--   parent's reserved slot is released first, then the caller's callback
--   runs.
createChild :: HasConfig cfg
            => cfg -> TcpSeqNum -> ListenTcb -> RouteInfo Addr -> Addr -> TcpHeader
            -> (Tcb -> State -> IO ())
            -> (Tcb -> State -> IO ())
            -> IO Tcb
createChild cxt iss parent ri remote hdr onEstablished onClosed =
  do let cfg = view config cxt
     now   <- getCurrentTime

     -- advance the parent's timestamp clock and take the advanced value
     tsc   <- atomicModifyIORef' (lTSClock parent) $ \ clock ->
                let clock' = Send.updateTSClock cfg now clock
                 in (clock', clock')

     child <- newTcb cfg (Just parent) iss ri (tcpDestPort hdr) remote
                (tcpSourcePort hdr) SynReceived tsc established closed

     atomicWriteIORef (tcbIrs child) peerSeq
     atomicWriteIORef (tcbIss child) iss
     _ <- setRcvNxt (peerSeq + 1) child
     _ <- setSndNxt iss child
     return child
  where
    peerSeq = tcpSeqNum hdr

    established c st = queueTcb parent c st >> onEstablished c st

    closed c st =
      when (st == SynReceived) (releaseSlot parent) >> onClosed c st
-- | Accept-queue state for a listening socket.
data AcceptQueue = AcceptQueue
  { aqFree :: !Int
    -- ^ Slots still available; taken by 'reserveSlot', returned by
    -- 'releaseSlot' and 'acceptTcb'.
  , aqTcbs :: Seq.Seq Tcb
    -- ^ Connections waiting to be accepted, oldest first.
  }
-- | Try to claim one accept-queue slot.
--
-- Returns 'True' and decrements the free-slot count when a slot is
-- available; returns 'False' and leaves the queue untouched otherwise.  The
-- check and the decrement happen in a single atomic update.
reserveSlot :: ListenTcb -> IO Bool
reserveSlot ListenTcb { .. } = atomicModifyIORef' lAccept claim
  where
    claim aq
      | aqFree aq > 0 = (aq { aqFree = aqFree aq - 1 }, True)
      | otherwise     = (aq, False)
-- | Hand one accept-queue slot back to the listening socket.
releaseSlot :: ListenTcb -> IO ()
releaseSlot listener = atomicModifyIORef' (lAccept listener) giveBack
  where
    giveBack q = (q { aqFree = aqFree q + 1 }, ())
-- | Append a connection to the back of the accept queue, then raise the
-- accept signal.  The state argument is ignored; it is there so the function
-- fits the callback shape @Tcb -> State -> IO ()@.
queueTcb :: ListenTcb -> Tcb -> State -> IO ()
queueTcb listener tcb _ =
  do atomicModifyIORef' (lAccept listener) enqueue
     signal (lAcceptSignal listener)
  where
    enqueue q = (q { aqTcbs = aqTcbs q Seq.|> tcb }, ())
-- | Wait on the accept signal, then pop the oldest queued connection.
--
-- Popping a connection also returns its slot to the free count.  Finding the
-- queue empty after the signal fired is treated as an internal error.
acceptTcb :: ListenTcb -> IO Tcb
acceptTcb listener =
  do waitSignal (lAcceptSignal listener)
     atomicModifyIORef' (lAccept listener) dequeue
  where
    dequeue q =
      case Seq.viewl (aqTcbs q) of
        next Seq.:< rest ->
          (AcceptQueue { aqTcbs = rest, aqFree = aqFree q + 1 }, next)
        Seq.EmptyL ->
          error "Accept queue signaled with an empty queue"
-- | A mutable cell holding a TCP sequence number.
type SeqNumVar = IORef TcpSeqNum
-- | Per-connection options.
data TcbConfig = TcbConfig
  { tcUseTimestamp :: !Bool
    -- ^ Whether timestamps are in use; cleared by 'disableTimestamp'.
  }
-- | Options a fresh connection starts with: timestamps switched on.
defaultTcbConfig :: TcbConfig
defaultTcbConfig = TcbConfig { tcUseTimestamp = True }
-- | Read the connection's current 'tcUseTimestamp' setting.
usingTimestamps :: Tcb -> IO Bool
usingTimestamps tcb =
  readIORef (tcbConfig tcb) >>= \ cfg ->
    return (tcUseTimestamp cfg)
-- | Atomically switch 'tcUseTimestamp' off for this connection.
disableTimestamp :: Tcb -> IO ()
disableTimestamp tcb = atomicModifyIORef' (tcbConfig tcb) switchOff
  where
    switchOff cfg@TcbConfig {} = (cfg { tcUseTimestamp = False }, ())
-- | Control block for an active connection.
data Tcb = Tcb
  { tcbParent          :: Maybe ListenTcb
    -- ^ The listening socket that spawned this connection, if any
    -- ('createChild' passes @Just parent@).
  , tcbConfig          :: !(IORef TcbConfig)
    -- ^ Per-connection options.
  , tcbState           :: !(IORef State)
    -- ^ Current connection state; written by 'setState'.
  , tcbEstablished     :: Tcb -> State -> IO ()
    -- ^ Run by 'setState' on entering 'Established'; receives the tcb and
    -- the state it was in before.
  , tcbClosed          :: Tcb -> State -> IO ()
    -- ^ Run by 'setState' on entering 'Closed'; receives the tcb and the
    -- state it was in before.

    -- Send side.
    -- NOTE(review): the SndUp/SndWl1/SndWl2 names suggest the RFC 793
    -- SND.UP, SND.WL1 and SND.WL2 variables -- confirm.
  , tcbSndUp           :: !SeqNumVar
  , tcbSndWl1          :: !SeqNumVar
  , tcbSndWl2          :: !SeqNumVar
  , tcbIss             :: !SeqNumVar
    -- ^ Initial send sequence number.
  , tcbSendWindow      :: !(IORef Send.Window)
    -- ^ Send window; source of SND.NXT, SND.WND and SND.UNA.

    -- Receive side.
  , tcbRcvUp           :: !SeqNumVar
  , tcbIrs             :: !SeqNumVar
    -- ^ Set by 'createChild' to the sequence number of the incoming header.
  , tcbNeedsDelayedAck :: !(IORef Bool)
    -- ^ Set to 'True' by 'signalDelayedAck'.
  , tcbRecvWindow      :: !(IORef Recv.Window)
    -- ^ Receive window; source of RCV.NXT and the window's right edge.
  , tcbRecvBuffer      :: !Stream.Buffer
    -- ^ Bytes queued by 'queueBytes' until 'receiveBytes' takes them.

    -- Addressing.
  , tcbLocalPort       :: !TcpPort
  , tcbRemotePort      :: !TcpPort
  , tcbRouteInfo       :: !(RouteInfo Addr)
  , tcbRemote          :: !Addr

  , tcbMss             :: !(IORef Int64)
    -- ^ Initialised from 'cfgTcpInitialMSS'.
  , tcbTimers          :: !(IORef TcpTimers)
    -- ^ Timer state; starts as 'emptyTcpTimers'.

    -- NOTE(review): presumably the timestamp-option TS.Recent and
    -- Last.ACK.sent values -- confirm.  Both start at zero.
  , tcbTSRecent        :: !(IORef Word32)
  , tcbLastAckSent     :: !(IORef TcpSeqNum)
  }
-- | Allocate a control block for an active connection.
--
-- Arguments, in order: configuration context, optional parent listener,
-- initial send sequence number, route, local port, remote address, remote
-- port, initial state, timestamp clock, and the established / closed
-- callbacks.
--
-- The send window starts at the given ISS and the receive window at sequence
-- number zero; both are sized from 'cfgTcpInitialWindow'.  The MSS comes
-- from 'cfgTcpInitialMSS', timers start as 'emptyTcpTimers', and every other
-- sequence-number cell starts at zero.
newTcb :: HasConfig state
       => state
       -> Maybe ListenTcb
       -> TcpSeqNum
       -> RouteInfo Addr -> TcpPort -> Addr -> TcpPort
       -> State
       -> Send.TSClock
       -> (Tcb -> State -> IO ())
       -> (Tcb -> State -> IO ())
       -> IO Tcb
newTcb cxt tcbParent iss tcbRouteInfo tcbLocalPort tcbRemote tcbRemotePort
       state tsc tcbEstablished tcbClosed =
  do let cfg = view config cxt

     -- connection-wide cells
     tcbConfig          <- newIORef defaultTcbConfig
     tcbState           <- newIORef state
     tcbMss             <- newIORef (fromIntegral (cfgTcpInitialMSS cfg))
     tcbTimers          <- newIORef emptyTcpTimers

     -- send side
     tcbIss             <- newIORef iss
     tcbSndUp           <- newIORef 0
     tcbSndWl1          <- newIORef 0
     tcbSndWl2          <- newIORef 0
     tcbSendWindow      <-
       newIORef (Send.emptyWindow iss (fromIntegral (cfgTcpInitialWindow cfg)) tsc)

     -- receive side
     tcbIrs             <- newIORef 0
     tcbRcvUp           <- newIORef 0
     tcbNeedsDelayedAck <- newIORef False
     tcbRecvWindow      <-
       newIORef (Recv.emptyWindow 0 (fromIntegral (cfgTcpInitialWindow cfg)))
     tcbRecvBuffer      <- Stream.newBuffer

     -- timestamp bookkeeping
     tcbTSRecent        <- newIORef 0
     tcbLastAckSent     <- newIORef 0

     -- every field is bound above under its own field name
     return Tcb { .. }
-- | Flag the connection as needing a delayed ACK.
signalDelayedAck :: Tcb -> IO ()
signalDelayedAck tcb = atomicWriteIORef (tcbNeedsDelayedAck tcb) True
-- | Atomically apply 'Recv.setRcvNxt' to the connection's receive window
-- and return its 'Bool' result.
setRcvNxt :: TcpSeqNum -> Tcb -> IO Bool
setRcvNxt rcvNxt tcb =
  atomicModifyIORef' (tcbRecvWindow tcb) (Recv.setRcvNxt rcvNxt)
-- | Atomically apply 'Send.setSndNxt' to the connection's send window and
-- return its 'Bool' result.
setSndNxt :: TcpSeqNum -> Tcb -> IO Bool
setSndNxt sndNxt tcb =
  atomicModifyIORef' (tcbSendWindow tcb) (Send.setSndNxt sndNxt)
-- | Tear down a connection's resources, in this order: close the receive
-- buffer, stop the retransmit timer, flush the send window.
finalizeTcb :: Tcb -> IO ()
finalizeTcb tcb =
  Stream.closeBuffer (tcbRecvBuffer tcb)
    >> atomicModifyIORef' (tcbTimers tcb) stopRetransmit
    >> atomicModifyIORef' (tcbSendWindow tcb) Send.flushWindow
-- | Put a chunk of bytes into the connection's receive buffer.
queueBytes :: S.ByteString -> Tcb -> IO ()
queueBytes bytes tcb = Stream.putBytes bytes (tcbRecvBuffer tcb)
-- | Report the result of 'Stream.bytesAvailable' on the connection's
-- receive buffer.
haveBytesAvail :: Tcb -> IO Bool
haveBytesAvail tcb = Stream.bytesAvailable (tcbRecvBuffer tcb)
-- | Take bytes from the receive buffer via 'Stream.takeBytes', passing @len@
-- through, then move the receive window's right edge by the number of bytes
-- actually taken.
-- NOTE(review): whether 'Stream.takeBytes' blocks while the buffer is empty
-- is not visible here -- confirm.
receiveBytes :: Int -> Tcb -> IO L.ByteString
receiveBytes len tcb =
  do chunk <- Stream.takeBytes len (tcbRecvBuffer tcb)
     let consumed = fromIntegral (L.length chunk)
     _ <- atomicModifyIORef' (tcbRecvWindow tcb) (Recv.moveRcvRight consumed)
     return chunk
-- | Variant of 'receiveBytes' built on 'Stream.tryTakeBytes'.
--
-- When the buffer yields bytes, the receive window's right edge is moved by
-- their length and they are returned in 'Just'.  When it yields 'Nothing',
-- the window is left alone and 'Nothing' is returned.
tryReceiveBytes :: Int -> Tcb -> IO (Maybe L.ByteString)
tryReceiveBytes len tcb =
  do taken <- Stream.tryTakeBytes len (tcbRecvBuffer tcb)
     case taken of
       Nothing ->
         return Nothing

       Just chunk ->
         do let consumed = fromIntegral (L.length chunk)
            _ <- atomicModifyIORef' (tcbRecvWindow tcb)
                   (Recv.moveRcvRight consumed)
            return (Just chunk)
-- | The state retained for a connection once it is in 'TimeWait'; built from
-- a full 'Tcb' by 'mkTimeWaitTcb'.
data TimeWaitTcb = TimeWaitTcb
  { twSndNxt     :: !TcpSeqNum
    -- ^ SND.NXT captured when the block was created.
  , twRcvNxt     :: !SeqNumVar
    -- ^ RCV.NXT; held in a mutable cell.
  , twRcvWnd     :: !Word16
    -- ^ Receive window size captured when the block was created.
  , twLocalPort  :: !TcpPort
  , twRemotePort :: !TcpPort
  , twRouteInfo  :: !(RouteInfo Addr)
  , twRemote     :: !Addr
  } deriving (Eq)
-- | Snapshot a connection into a 'TimeWaitTcb'.
--
-- SND.NXT and the receive window size are copied as plain values; RCV.NXT is
-- copied into a fresh mutable cell.  Ports, route and remote address are
-- carried over unchanged.  The result is forced before being returned.
mkTimeWaitTcb :: Tcb -> IO TimeWaitTcb
mkTimeWaitTcb tcb =
  do sendWin   <- readIORef (tcbSendWindow tcb)
     recvWin   <- readIORef (tcbRecvWindow tcb)
     rcvNxtRef <- newIORef (view Recv.rcvNxt recvWin)
     return $! TimeWaitTcb
       { twSndNxt     = view Send.sndNxt sendWin
       , twRcvNxt     = rcvNxtRef
       , twRcvWnd     = view Recv.rcvWnd recvWin
       , twLocalPort  = tcbLocalPort tcb
       , twRemotePort = tcbRemotePort tcb
       , twRouteInfo  = tcbRouteInfo tcb
       , twRemote     = tcbRemote tcb
       }
-- | The first component of 'getSendWindow': the next sequence number to
-- send.
getSndNxt :: (BaseM io IO, CanSend sock) => sock -> io TcpSeqNum
getSndNxt sock =
  getSendWindow sock >>= \ (nxt, _) -> return nxt
-- | The second component of 'getSendWindow': the send window.
getSndWnd :: (BaseM io IO, CanSend sock) => sock -> io TcpSeqNum
getSndWnd sock =
  getSendWindow sock >>= \ (_, wnd) -> return wnd
-- | Things that expose a send window as the pair @(SND.NXT, SND.WND)@.
class CanSend sock where
  getSendWindow :: BaseM io IO => sock -> io (TcpSeqNum,TcpSeqNum)

-- | Reads the window out of the reference.
instance CanSend (IORef Send.Window) where
  getSendWindow ref =
    inBase (readIORef ref) >>= \ win ->
      return (view Send.sndNxt win, view Send.sndWnd win)

-- | Delegates to the connection's 'tcbSendWindow'.
instance CanSend Tcb where
  getSendWindow tcb = getSendWindow (tcbSendWindow tcb)
-- | Read SND.UNA from the connection's send window.  The value is forced
-- before it is returned.
getSndUna :: BaseM io IO => Tcb -> io TcpSeqNum
getSndUna tcb =
  inBase (readIORef (tcbSendWindow tcb)) >>= \ win ->
    return $! view Send.sndUna win
-- | The first component of 'getRecvWindow': the next sequence number
-- expected from the peer.
getRcvNxt :: (BaseM io IO, CanReceive sock) => sock -> io TcpSeqNum
getRcvNxt sock =
  getRecvWindow sock >>= \ (nxt, _) -> return nxt
-- | The current receive window size: the distance from RCV.NXT to the right
-- edge of the receive window, converted with 'fromTcpSeqNum'.
getRcvWnd :: (BaseM io IO, CanReceive sock) => sock -> io Word16
getRcvWnd sock =
  do (nxt,right) <- getRecvWindow sock
     -- size = right edge minus left edge
     return (fromTcpSeqNum (right - nxt))
-- | The second component of 'getRecvWindow': the right edge of the receive
-- window.
getRcvRight :: (BaseM io IO, CanReceive sock) => sock -> io TcpSeqNum
getRcvRight sock =
  getRecvWindow sock >>= \ (_, right) -> return right
-- | Things that expose a receive window as the pair
-- @(RCV.NXT, right edge)@.
class CanReceive sock where
  getRecvWindow :: BaseM io IO => sock -> io (TcpSeqNum,TcpSeqNum)

-- | Reads both edges out of the referenced window.
instance CanReceive (IORef Recv.Window) where
  getRecvWindow ref = inBase (fmap edges (readIORef ref))
    where
      edges win = (view Recv.rcvNxt win, view Recv.rcvRight win)

-- | Delegates to the connection's 'tcbRecvWindow'.
instance CanReceive Tcb where
  getRecvWindow tcb = getRecvWindow (tcbRecvWindow tcb)

-- | The right edge is the stored RCV.NXT plus the fixed 'twRcvWnd'.
instance CanReceive TimeWaitTcb where
  getRecvWindow tw = inBase (fmap edges (readIORef (twRcvNxt tw)))
    where
      edges nxt = (nxt, nxt + fromIntegral (twRcvWnd tw))