module Hans.Socket.Tcp where
import Hans.Addr
import qualified Hans.Buffer.Stream as Stream
import qualified Hans.HashTable as HT
import Hans.Lens (Getting,view,to)
import Hans.Network
import Hans.Socket.Types
import Hans.Tcp.Tcb
import Hans.Tcp.Message
import Hans.Tcp.Output
import qualified Hans.Tcp.SendWindow as Send
import Hans.Types
import Control.Concurrent (newEmptyMVar,tryPutMVar,takeMVar,yield)
import Control.Exception (throwIO, handle)
import Control.Monad (unless,when)
import qualified Data.ByteString.Lazy as L
import Data.IORef (readIORef)
import Data.Time.Clock (getCurrentTime)
import Data.Typeable(Typeable)
import System.CPUTime (getCPUTime)
-- | A TCP socket for a single connection, as produced by 'sConnect' or
-- 'sAccept'.
--
-- The @addr@ type parameter does not occur in any field. It only selects the
-- address type that the address getters in this module convert to.
data TcpSocket addr = TcpSocket
  { tcpNS  :: !NetworkStack -- ^ Network stack the socket operates on.
  , tcpTcb :: !Tcb          -- ^ Control block of the connection.
  } deriving (Typeable)
-- | A TCP socket exposes the network stack stored in its 'tcpNS' field.
instance HasNetworkStack (TcpSocket addr) where
  networkStack = to (\ TcpSocket { tcpNS = stack } -> stack)
-- | Getter for the route of a TCP socket. It converts the source and
-- next-hop addresses stored in the 'Tcb' to the socket's address type with
-- 'fromAddr', and keeps every other route field unchanged.
--
-- Calls 'error' if either address cannot be converted.
tcpRoute :: NetworkAddr addr => Getting r (TcpSocket addr) (RouteInfo addr)
tcpRoute = to (narrow . tcbRouteInfo . tcpTcb)
  where
  narrow RouteInfo { .. } =
    maybe (error "tcpRoute: invalid address combination")
          (\ (src, next) -> RouteInfo { riSource = src, riNext = next, .. })
          ((,) <$> fromAddr riSource <*> fromAddr riNext)
-- | Getter for the local (source) address of a TCP socket, taken from
-- 'tcpRoute'. It therefore fails the same way 'tcpRoute' does when the
-- addresses cannot be converted.
tcpLocalAddr :: NetworkAddr addr => Getting r (TcpSocket addr) addr
tcpLocalAddr = tcpRoute . to (\ RouteInfo { riSource = src } -> src)
-- | Getter for the local port of a TCP socket's connection.
tcpLocalPort :: Getting r (TcpSocket addr) SockPort
tcpLocalPort = to (tcbLocalPort . tcpTcb)
-- | Getter for the remote address of a TCP socket. It converts the address
-- stored in the 'Tcb' to the socket's address type with 'fromAddr'.
--
-- Calls 'error' if the stored remote address cannot be converted.
tcpRemoteAddr :: NetworkAddr addr => Getting r (TcpSocket addr) addr
tcpRemoteAddr = to (\ TcpSocket { tcpTcb = Tcb { .. } } -> cast tcbRemote)
  where
  cast addr =
    case fromAddr addr of
      Just a  -> a
      -- The message names this getter, so failures can be traced back here.
      Nothing -> error "tcpRemoteAddr: invalid remote address"
-- | Getter for the remote port of a TCP socket's connection.
tcpRemotePort :: Getting r (TcpSocket addr) SockPort
tcpRemotePort = to (tcbRemotePort . tcpTcb)
-- | Active open: create a 'Tcb' for the given route and endpoints, register
-- it in the stack's table of active connections, send a SYN, and block until
-- one of the two callbacks handed to 'newTcb' reports the outcome.
--
-- Returns the 'Tcb' when the first callback fires first.
--
-- Throws 'AlreadyConnected' if an entry already exists for the connection
-- key; that entry is left untouched. Throws 'ConnectionRefused' if the
-- second callback fires first.
activeOpen :: Network addr
           => NetworkStack -> RouteInfo addr -> SockPort -> addr -> SockPort
           -> IO Tcb
activeOpen ns ri srcPort dst dstPort =
  do let ri'  = toAddr `fmap` ri
         dst' = toAddr dst

     -- Outcome slot: True from the first callback, False from the second
     -- (presumably "established" / "closed" -- confirm against 'newTcb').
     -- 'tryPutMVar' is used so a callback never blocks on a full MVar.
     done <- newEmptyMVar

     now   <- getCurrentTime
     tsval <- getCPUTime
     let tsc = Send.initialTSClock (fromInteger tsval) now

     iss <- nextIss (view tcpState ns) (riSource ri') srcPort dst' dstPort
     tcb <- newTcb ns Nothing iss ri' srcPort dst' dstPort Closed tsc
                (\_ _ -> tryPutMVar done True  >> return ())
                (\_ _ -> tryPutMVar done False >> return ())

     -- Insert the new Tcb only if its key is free. If the key is taken, hand
     -- the existing entry back unchanged. Returning 'Nothing' there would
     -- evict the live connection that owns the key. (Assumes 'HT.alter'
     -- stores the first component of the result.)
     let update Nothing  = (Just tcb, True)
         update existing = (existing, False)
     success <- HT.alter update (view tcbKey tcb) (view tcpActive ns)

     unless success (throwIO AlreadyConnected)

     syn <- mkSyn tcb
     _   <- sendWithTcb ns tcb syn L.empty
     -- NOTE(review): the state becomes SynSent only after the SYN has been
     -- handed to 'sendWithTcb'; confirm a reply cannot be processed while the
     -- Tcb is still Closed.
     setState tcb SynSent

     established <- takeMVar done
     if established
        then return tcb
        else throwIO ConnectionRefused
-- | Closing an active TCP socket sends a FIN with 'sendFin' and then moves
-- the connection to its next state: CloseWait goes to LastAck, and
-- Established goes to FinWait1. In every other state (including SynSent and
-- SynReceived) closing does nothing.
instance Socket TcpSocket where
  sClose TcpSocket { .. } =
    readIORef (tcbState tcpTcb) >>= \ st ->
      case st of
        CloseWait   -> finWith LastAck
        Established -> finWith FinWait1
        _           -> return ()
    where
    -- Send the FIN first, then record the new state.
    finWith next =
      do sendFin tcpNS tcpTcb
         setState tcpTcb next
-- | Run a send-side action only when the connection state allows sending.
--
-- The action runs in SynSent, SynReceived, Established and CloseWait.
-- Throws 'NoConnection' when Closed and 'ConnectionClosing' in the remaining
-- states. Calls 'error' for Listen, which must never be seen on an active
-- Tcb.
guardSend :: Tcb -> IO r -> IO r
guardSend tcb send = getState tcb >>= permitted
  where
  permitted st =
    case st of
      SynSent     -> send
      SynReceived -> send
      Established -> send
      CloseWait   -> send
      Closed      -> throwIO NoConnection
      Listen      -> error "guardSend: Listen state for active tcb"
      _           -> throwIO ConnectionClosing
-- | Run a receive-side action only when the connection state allows
-- receiving.
--
-- * Closed: throws 'NoConnection'.
-- * Listen, SynSent, SynReceived, Established, FinWait1, FinWait2: runs the
--   action.
-- * CloseWait: runs the action only while the receive buffer reports bytes
--   available; otherwise throws 'ConnectionClosing'.
-- * Closing, LastAck, TimeWait: throws 'ConnectionClosing'.
guardRecv :: Tcb -> IO r -> IO r
guardRecv tcb recv = getState tcb >>= permitted
  where
  permitted st =
    case st of
      Listen      -> recv
      SynSent     -> recv
      SynReceived -> recv
      Established -> recv
      FinWait1    -> recv
      FinWait2    -> recv
      CloseWait   -> drainBuffered
      Closed      -> throwIO NoConnection
      Closing     -> throwIO ConnectionClosing
      LastAck     -> throwIO ConnectionClosing
      TimeWait    -> throwIO ConnectionClosing

  -- Only data already buffered may be handed out in this state.
  drainBuffered =
    do pending <- Stream.bytesAvailable (tcbRecvBuffer tcb)
       if pending then recv else throwIO ConnectionClosing
-- | Active TCP sockets as data sockets. Writes are gated by 'guardSend' and
-- reads by 'guardRecv'.
instance DataSocket TcpSocket where

  -- Route to the destination and choose a source port: the caller's port,
  -- or one from 'nextTcpPort' ('NoPortAvailable' if none is available).
  -- Then perform the active open.
  sConnect ns SocketConfig { .. } mbDev src mbSrcPort dst dstPort =
    do let tcpNS = view networkStack ns
       ri <- route tcpNS mbDev src dst
       srcPort <-
         case mbSrcPort of
           Just port -> return port
           Nothing   ->
             nextTcpPort tcpNS (toAddr (riSource ri)) (toAddr dst) dstPort
               >>= maybe (throwIO NoPortAvailable) return
       tcpTcb <- activeOpen tcpNS ri srcPort dst dstPort
       return TcpSocket { .. }

  -- Reports 'canSend', or False when the state guard rejects the call with
  -- a 'ConnectionException'.
  sCanWrite TcpSocket { .. } =
    handle notWritable (guardSend tcpTcb (canSend tcpTcb))
    where
    notWritable :: ConnectionException -> IO Bool
    notWritable _ = return False

  -- Send what the connection accepts. After a partial send, yield the
  -- thread before returning the number of bytes taken.
  sWrite TcpSocket { .. } bytes =
    guardSend tcpTcb $
      do sent <- sendData tcpNS tcpTcb bytes
         unless (sent >= L.length bytes) yield
         return $! fromIntegral sent

  sCanRead TcpSocket { .. }     = guardRecv tcpTcb $ haveBytesAvail tcpTcb
  sRead TcpSocket { .. } len    = guardRecv tcpTcb $ receiveBytes len tcpTcb
  sTryRead TcpSocket { .. } len = guardRecv tcpTcb $ tryReceiveBytes len tcpTcb
-- | A listening TCP socket. Like 'TcpSocket', its @addr@ parameter does not
-- occur in any field.
data TcpListenSocket addr = TcpListenSocket
  { tlNS  :: !NetworkStack -- ^ Network stack the listener is registered with.
  , tlTcb :: !ListenTcb    -- ^ Listening control block.
  }
-- | Closing a listening socket passes its stack and listening control block
-- to 'deleteListening'.
instance Socket TcpListenSocket where
  sClose TcpListenSocket { tlNS = stack, tlTcb = listener } =
    deleteListening stack listener
-- | Listening TCP sockets. Accepted clients are 'TcpSocket's.
instance ListenSocket TcpListenSocket where
  type Client TcpListenSocket = TcpSocket

  -- Build a listening control block for the local address and port, then
  -- register it. Throws 'AlreadyListening' if registration fails.
  sListen ns SocketConfig { .. } src srcPort backlog =
    do let tlNS = view networkStack ns
       tlTcb      <- newListenTcb (toAddr src) srcPort backlog
       registered <- registerListening tlNS tlTcb
       if registered
          then return $! TcpListenSocket { .. }
          else throwIO AlreadyListening

  -- Take the next connection's Tcb from 'acceptTcb' and wrap it in a socket
  -- that shares the listener's network stack.
  sAccept TcpListenSocket { .. } =
    do child <- acceptTcb tlTcb
       return $! TcpSocket { tcpNS = tlNS, tcpTcb = child }