module Hans.Tcp.State (
HasTcpState(..), TcpState(),
newTcpState,
tcpQueue,
TcpResponderRequest(..),
incrSynBacklog,
decrSynBacklog,
registerListening,
lookupListening,
deleteListening,
Key(), tcbKey,
tcpActive,
lookupActive,
registerActive,
closeActive,
deleteActive,
registerTimeWait,
lookupTimeWait,
resetTimeWait,
deleteTimeWait,
nextTcpPort,
nextIss,
) where
import Hans.Addr (Addr,wildcardAddr,putAddr)
import Hans.Config (HasConfig(..),Config(..))
import qualified Hans.HashTable as HT
import Hans.Lens
import Hans.Network.Types (RouteInfo(..))
import Hans.Tcp.Packet
import Hans.Tcp.Tcb
import Hans.Threads (forkNamed)
import Hans.Time
import Control.Concurrent (threadDelay,MVar,newMVar,modifyMVar)
import qualified Control.Concurrent.BoundedChan as BC
import Control.Monad (guard)
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.Digest.Pure.SHA(sha1,integerDigest)
import qualified Data.Foldable as F
import Data.Hashable (Hashable)
import qualified Data.Heap as H
import Data.IORef (IORef,newIORef,atomicModifyIORef',readIORef)
import Data.Serialize (runPutLazy,putByteString)
import Data.Time.Clock (UTCTime,getCurrentTime,addUTCTime,diffUTCTime)
import Data.Word (Word32)
import GHC.Generics (Generic)
import System.Random (newStdGen,random,randoms)
-- | Key for the listen table: the address and port a listener is bound to.
-- The address may be a wildcard (see 'lookupListening', which falls back to
-- 'wildcardAddr').
data ListenKey = ListenKey !Addr !TcpPort
                 deriving (Show,Eq,Ord,Generic)

-- | Getter projecting the listen-table key out of a 'ListenTcb'.
listenKey :: Getting r ListenTcb ListenKey
listenKey = to keyOf
  where
  keyOf ListenTcb { lSrc = addr, lPort = port } = ListenKey addr port
-- | Key for the active-connection table.  Fields, in order: remote address,
-- remote port, local (route source) address, local port.
data Key = Key !Addr
               !TcpPort
               !Addr
               !TcpPort
           deriving (Show,Eq,Ord,Generic)

-- | Getter building the active-table 'Key' for a 'Tcb' from its remote
-- endpoint, the source address of its route, and its local port.
tcbKey :: Getting r Tcb Key
tcbKey = to keyOf
  where
  keyOf Tcb { tcbRemote     = remote
            , tcbRemotePort = remotePort
            , tcbLocalPort  = localPort
            , tcbRouteInfo  = RouteInfo { riSource = local }
            } = Key remote remotePort local localPort
-- Both key types are used as hash-table keys in 'TcpState'.  The instance
-- bodies are empty, so they rely on Hashable's default methods, which are
-- driven by the derived 'Generic' instances.
instance Hashable ListenKey
instance Hashable Key
-- | Heap of TIME_WAIT control blocks with expiry times.  Entries are added
-- with 'expireAt' and pruned with 'dropExpired'.
type TimeWaitHeap = ExpireHeap TimeWaitTcb

-- | The mutable state of the TCP layer.
data TcpState =
  TcpState { tcpListen_ :: !(HT.HashTable ListenKey ListenTcb)
             -- ^ Listening sockets, keyed by bound address and port.
           , tcpActive_ :: !(HT.HashTable Key Tcb)
             -- ^ Active connections, keyed by (remote address, remote port,
             -- local address, local port).
           , tcpTimeWait_ :: !(IORef TimeWaitHeap)
             -- ^ Connections in TIME_WAIT.  Pruned by the reaper thread
             -- that 'updateTimeWait' forks.
           , tcpSynBacklog_ :: !(IORef Int)
             -- ^ Remaining SYN backlog slots.  Initialised to
             -- 'cfgTcpMaxSynBacklog'.
           , tcpPorts :: !(MVar TcpPort)
             -- ^ Cursor for ephemeral-port allocation in 'nextTcpPort'.
             -- Holds the most recently handed-out port (initially 32767).
           , tcpISSTimer :: !(IORef Tcp4USTimer)
             -- ^ Counter and secret used by 'nextIss'.
           , tcpQueue_ :: !(BC.BoundedChan TcpResponderRequest)
             -- ^ Outgoing send requests (capacity 128).  Not read in this
             -- module; presumably drained by a responder thread elsewhere.
           }
-- | Requests placed on 'tcpQueue'.  Nothing in this module consumes them;
-- presumably a responder thread elsewhere transmits the segments.
data TcpResponderRequest = SendSegment !(RouteInfo Addr) !Addr !TcpHeader !L.ByteString
                           -- ^ Routing info, an address (presumably the
                           -- destination — TODO confirm), header and payload.
                         | SendWithTcb !Tcb !TcpHeader !L.ByteString
                           -- ^ A header and payload sent in the context of an
                           -- existing 'Tcb'.
-- | Lift a plain field accessor of 'TcpState' into a getter that works on any
-- type carrying a 'TcpState'.
viaTcpState :: HasTcpState state => (TcpState -> a) -> Getting r state a
viaTcpState field = tcpState . to field

-- | The queue of outgoing send requests.
tcpQueue :: HasTcpState state => Getting r state (BC.BoundedChan TcpResponderRequest)
tcpQueue = viaTcpState tcpQueue_

-- | The table of listening sockets.
tcpListen :: HasTcpState state => Getting r state (HT.HashTable ListenKey ListenTcb)
tcpListen = viaTcpState tcpListen_

-- | The table of active connections.
tcpActive :: HasTcpState state => Getting r state (HT.HashTable Key Tcb)
tcpActive = viaTcpState tcpActive_

-- | The reference holding the TIME_WAIT heap.
tcpTimeWait :: HasTcpState state => Getting r state (IORef TimeWaitHeap)
tcpTimeWait = viaTcpState tcpTimeWait_

-- | The reference holding the number of remaining SYN backlog slots.
tcpSynBacklog :: HasTcpState state => Getting r state (IORef Int)
tcpSynBacklog = viaTcpState tcpSynBacklog_
-- | Types from which the shared 'TcpState' can be viewed.
class HasTcpState state where
  tcpState :: Getting r state TcpState

-- | A 'TcpState' is trivially its own TCP state.
instance HasTcpState TcpState where
  tcpState getter = getter
-- | State used by 'nextIss' to generate initial sequence numbers.
data Tcp4USTimer = Tcp4USTimer { tcpTimer :: !Word32
                                 -- ^ Tick counter.  'nextIss' advances it at
                                 -- 250000 ticks per second (one tick per 4us).
                               , tcpSecret :: !S.ByteString
                                 -- ^ Random secret mixed into the ISS hash.
                               , tcpLastUpdate :: !UTCTime
                                 -- ^ Reference time the counter was last
                                 -- advanced to.
                               }
-- | Create a fresh ISS timer.  The tick counter starts at a random value,
-- the secret is 256 random bytes, and the reference time is the current
-- wall-clock time.
newTcp4USTimer :: IO Tcp4USTimer
newTcp4USTimer =
  do stamp <- getCurrentTime
     g0    <- newStdGen
     let (start, g1) = random g0
         secret      = S.pack (take 256 (randoms g1))
     return Tcp4USTimer { tcpTimer      = start
                        , tcpSecret     = secret
                        , tcpLastUpdate = stamp
                        }
-- | Allocate the TCP layer's tables and references.  Table sizes and the
-- initial SYN backlog come from the configuration.  The initial port cursor
-- (32767) and the send-queue capacity (128) are fixed constants.
newTcpState :: Config -> IO TcpState
newTcpState cfg@Config {} =
  do listenTable <- HT.newHashTable (cfgTcpListenTableSize cfg)
     activeTable <- HT.newHashTable (cfgTcpActiveTableSize cfg)
     timeWait    <- newIORef emptyHeap
     synSlots    <- newIORef (cfgTcpMaxSynBacklog cfg)
     portCursor  <- newMVar 32767
     issTimer    <- newTcp4USTimer >>= newIORef
     sendQueue   <- BC.newBoundedChan 128
     return TcpState { tcpListen_     = listenTable
                     , tcpActive_     = activeTable
                     , tcpTimeWait_   = timeWait
                     , tcpSynBacklog_ = synSlots
                     , tcpPorts       = portCursor
                     , tcpISSTimer    = issTimer
                     , tcpQueue_      = sendQueue
                     }
-- | Try to claim one slot of the SYN backlog.
--
-- The counter holds the number of remaining slots (initialised from
-- 'cfgTcpMaxSynBacklog').
--
-- * When it is positive, it is decremented and 'True' is returned.
-- * Otherwise it is left unchanged and 'False' is returned.
--
-- The check and the update happen in a single 'atomicModifyIORef''.
decrSynBacklog :: HasTcpState state => state -> IO Bool
decrSynBacklog state =
  atomicModifyIORef' (view tcpSynBacklog state) $ \ backlog ->
    if backlog > 0
       then (backlog - 1, True)
       else (backlog, False)
-- | Return one slot to the SYN backlog (the counterpart of 'decrSynBacklog').
-- No upper bound is enforced here.
incrSynBacklog :: HasTcpState state => state -> IO ()
incrSynBacklog state = atomicModifyIORef' slots release
  where
  slots     = view tcpSynBacklog state
  release n = (n + 1, ())
-- | Add a listener to the listen table.  Returns 'False' and keeps the
-- existing entry if a listener with the same key is already registered.
registerListening :: HasTcpState state
                  => state -> ListenTcb -> IO Bool
registerListening state tcb =
  HT.alter insertIfAbsent key table
  where
  key   = view listenKey tcb
  table = view tcpListen state
  insertIfAbsent existing =
    case existing of
      Nothing -> (Just tcb, True)
      Just _  -> (existing, False)
-- | Remove the entry for this listener's key from the listen table.
deleteListening :: HasTcpState state
                => state -> ListenTcb -> IO ()
deleteListening state tcb = HT.delete key table
  where
  key   = view listenKey tcb
  table = view tcpListen state
-- | Find the listener for a local address and port.  An exact address match
-- is preferred.  Otherwise a listener bound to the wildcard address of the
-- same family (as given by 'wildcardAddr') is returned, if any.
lookupListening :: HasTcpState state
                => state -> Addr -> TcpPort -> IO (Maybe ListenTcb)
lookupListening state src port =
  do exact <- probe src
     case exact of
       Nothing -> probe (wildcardAddr src)
       Just _  -> return exact
  where
  table      = view tcpListen state
  probe addr = HT.lookup (ListenKey addr port) table
-- | Add a connection to the TIME_WAIT heap, expiring
-- 'cfgTcpTimeoutTimeWait' after now.  When the heap already holds
-- 'cfgTcpTimeWaitSocketLimit' or more entries, its minimum entry (presumably
-- the one closest to expiry) is evicted first.
registerTimeWait :: (HasConfig state, HasTcpState state)
                 => state -> TimeWaitTcb -> IO ()
registerTimeWait state tcb = updateTimeWait state admit
  where
  cfg = view config state
  admit now heap =
    let deadline = addUTCTime (cfgTcpTimeoutTimeWait cfg) now
        roomy | H.size heap >= cfgTcpTimeWaitSocketLimit cfg = H.deleteMin heap
              | otherwise                                    = heap
     in fst (expireAt deadline tcb roomy)
-- | Restart the TIME_WAIT timeout for a connection.  Any existing entry equal
-- to @tcb@ is removed, and it is re-inserted with a deadline of
-- 'cfgTcpTimeoutTimeWait' after now.
resetTimeWait :: (HasConfig state, HasTcpState state)
              => state -> TimeWaitTcb -> IO ()
resetTimeWait state tcb =
  updateTimeWait state $ \ now ->
    fst . expireAt (addUTCTime timeout now) tcb . filterHeap (/= tcb)
  where
  timeout = cfgTcpTimeoutTimeWait (view config state)
-- | Atomically apply @update@ (given the current time) to the TIME_WAIT heap.
--
-- If the heap was empty before the update and 'nextEvent' reports a pending
-- expiry afterwards, a "TimeWait Reaper" thread is forked.  The reaper loops:
--
-- * sleep until the next expiry (but at least 0.5s);
-- * drop expired entries;
-- * exit once no events remain.
--
-- A non-empty heap before the update means a reaper is assumed to be
-- running already, so none is started.
--
-- NOTE(review): 'deleteTimeWait' can empty the heap while a reaper is
-- asleep.  A later update then sees an empty heap and forks a second
-- reaper.  Both remain correct, but the work is duplicated.
updateTimeWait :: (HasConfig state, HasTcpState state)
               => state -> (UTCTime -> TimeWaitHeap -> TimeWaitHeap) -> IO ()
updateTimeWait state update =
  do now <- getCurrentTime
     mbReap <-
       atomicModifyIORef' (view tcpTimeWait state) $ \ heap ->
         let heap' = update now heap
             -- Only start a reaper when the heap was previously empty.
             reaper = do guard (nullHeap heap)
                         future <- nextEvent heap'
                         return $ do delayDiff now future
                                     reapLoop
          in (heap', reaper)
     case mbReap of
       Just reaper -> do _ <- forkNamed "TimeWait Reaper" reaper
                         return ()
       Nothing -> return ()
  where
  -- Sleep until @future@, but never less than 500000us (0.5s).
  delayDiff now future =
    threadDelay (max 500000 (toUSeconds (diffUTCTime future now)))
  -- Drop expired entries, then sleep until the next event; stop when the
  -- heap reports no further events.
  reapLoop =
    do now <- getCurrentTime
       mbExpire <-
         atomicModifyIORef' (view tcpTimeWait state) $ \ heap ->
           let heap' = dropExpired now heap
            in (heap', nextEvent heap')
       case mbExpire of
         Just future -> do delayDiff now future
                           reapLoop
         Nothing -> return ()
-- | Find the TIME_WAIT block for a connection.  It matches on:
--
-- * remote address and port (@dst@, @dstPort@);
-- * local address (the route's source, @src@);
-- * local port (@srcPort@).
--
-- This is a linear scan over the heap's entries.
lookupTimeWait :: HasTcpState state
               => state -> Addr -> TcpPort -> Addr -> TcpPort
               -> IO (Maybe TimeWaitTcb)
lookupTimeWait state dst dstPort src srcPort =
  do entries <- readIORef (view tcpTimeWait state)
     return (fmap payload (F.find (matches . payload) entries))
  where
  matches TimeWaitTcb { .. } =
       twRemote == dst
    && twRemotePort == dstPort
    && riSource twRouteInfo == src
    && twLocalPort == srcPort
-- | Remove every heap entry equal to @tw@ from the TIME_WAIT heap.  No
-- reaper bookkeeping is done here.
deleteTimeWait :: HasTcpState state => state -> TimeWaitTcb -> IO ()
deleteTimeWait state tw = atomicModifyIORef' heapRef dropIt
  where
  heapRef     = view tcpTimeWait state
  dropIt heap = (filterHeap (/= tw) heap, ())
-- | Add a connection to the active table under its 'tcbKey'.  Returns 'False'
-- and keeps the existing entry if that key is already taken.
registerActive :: HasTcpState state => state -> Tcb -> IO Bool
registerActive state tcb =
  HT.alter claim (view tcbKey tcb) (view tcpActive state)
  where
  claim slot =
    case slot of
      Nothing -> (Just tcb, True)
      Just _  -> (slot, False)
-- | Look up an active connection by remote address/port and local
-- address/port, in the same order as 'Key'.
lookupActive :: HasTcpState state
             => state -> Addr -> TcpPort -> Addr -> TcpPort -> IO (Maybe Tcb)
lookupActive state dst dstPort src srcPort = HT.lookup connKey table
  where
  connKey = Key dst dstPort src srcPort
  table   = view tcpActive state
-- | Finalize a connection's control block, then drop it from the active table.
closeActive :: HasTcpState state => state -> Tcb -> IO ()
closeActive state tcb = finalizeTcb tcb >> deleteActive state tcb
-- | Remove a connection's entry (by 'tcbKey') from the active table.
deleteActive :: HasTcpState state => state -> Tcb -> IO ()
deleteActive state tcb = HT.delete key table
  where
  key   = view tcbKey tcb
  table = view tcpActive state
-- | Pick an ephemeral local port for a connection from @src@ to
-- @dst@:@dstPort@ whose connection key is not already in the active table.
-- Returns 'Nothing' if no such port is found.
--
-- Selection is serialised through the 'tcpPorts' cursor.  The port is not
-- registered here; presumably the caller follows up with 'registerActive'.
nextTcpPort :: HasTcpState state
            => state -> Addr -> Addr -> TcpPort -> IO (Maybe TcpPort)
nextTcpPort state src dst dstPort =
  modifyMVar (tcpPorts st) (pickFreshPort (tcpActive_ st) (Key dst dstPort src))
  where
  st = view tcpState state
-- | Search for a local port whose key (built by the key function) is absent
-- from the table.
--
-- * The search starts at the cursor value and steps upward.
-- * Port 0 is skipped by jumping to 1025.
-- * After 65536 steps it gives up, returning the cursor unchanged with
--   'Nothing'.
-- * On success the chosen port becomes the new cursor value.
--
-- NOTE(review): the cursor is set to the port just returned, not past it.
-- Two callers targeting the same remote endpoint could therefore receive
-- the same port if neither has registered its connection yet.
pickFreshPort :: HT.HashTable Key Tcb -> (TcpPort -> Key) -> TcpPort
              -> IO (TcpPort, Maybe TcpPort)
pickFreshPort table keyFor start = search 0 start
  where
  search :: Int -> TcpPort -> IO (TcpPort,Maybe TcpPort)
  search attempts candidate
    | attempts > 65535 = return (start, Nothing)
    | candidate == 0   = search (attempts + 1) 1025
    | otherwise        =
      do taken <- HT.hasKey (keyFor candidate) table
         if taken
            then search (attempts + 1) (candidate + 1)
            else return (candidate, Just candidate)
-- | Choose an initial send sequence number for a new connection, in the style
-- of RFC 6528:
--
-- > ISS = M + F(src, srcPort, dst, dstPort, secret)
--
-- * @M@ is a counter advanced at 250000 ticks per second (one tick per 4us).
-- * @F@ is a SHA-1 digest over the connection identifiers and the secret in
--   'tcpISSTimer'; it is truncated to 32 bits by the final addition.
--
-- The counter only advances by whole ticks, and 'tcpLastUpdate' moves
-- forward by exactly the time those ticks represent.  The sub-tick
-- remainder therefore carries over to the next call instead of being
-- discarded, which previously stalled the counter whenever calls came less
-- than 4us apart.
--
-- If the wall clock steps backwards, the counter is held and the reference
-- time is resynchronised.  Previously the counter was wrapped backwards.
nextIss :: HasTcpState state
        => state -> Addr -> TcpPort -> Addr -> TcpPort -> IO TcpSeqNum
nextIss state src srcPort dst dstPort =
  do let TcpState { .. } = view tcpState state
     now <- getCurrentTime
     (m,f_digest) <- atomicModifyIORef' tcpISSTimer $ \ Tcp4USTimer { .. } ->
       let elapsed = diffUTCTime now tcpLastUpdate

           -- Whole 4us ticks elapsed, clamped at zero so that a backwards
           -- clock step cannot wrap the Word32 counter.
           wholeTicks :: Integer
           wholeTicks = max 0 (truncate (elapsed * 250000))

           ticks = tcpTimer + fromIntegral wholeTicks

           -- Move the reference time forward by exactly the ticks consumed
           -- (1/250000 s is exact in NominalDiffTime's picosecond resolution),
           -- keeping the remainder for the next call.  Resync if the clock
           -- went backwards.
           lastUpdate'
             | elapsed < 0 = now
             | otherwise   = addUTCTime (fromIntegral wholeTicks / 250000)
                                        tcpLastUpdate

           timers' = Tcp4USTimer { tcpTimer      = ticks
                                 , tcpLastUpdate = lastUpdate'
                                 , .. }

           digest = integerDigest $ sha1 $ runPutLazy $
             do putAddr src
                putTcpPort srcPort
                putAddr dst
                putTcpPort dstPort
                putByteString tcpSecret
        in (timers', (ticks, digest))
     -- Word32 addition keeps only the low 32 bits of the digest.
     return (fromIntegral (m + fromIntegral f_digest))