module Network.Transport.TCP.Internal
( forkServer
, recvWithLength
, recvExact
, recvInt32
, tryCloseSocket
) where
#if ! MIN_VERSION_base(4,6,0)
import Prelude hiding (catch)
#endif
import Network.Transport.Internal (decodeInt32, void, tryIO, forkIOWithUnmask)
#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket as N
#else
import qualified Network.Socket as N
#endif
( HostName
, ServiceName
, Socket
, SocketType(Stream)
, SocketOption(ReuseAddr)
, getAddrInfo
, defaultHints
, socket
, bindSocket
, listen
, addrFamily
, addrAddress
, defaultProtocol
, setSocketOption
, accept
, sClose
)
#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket.ByteString as NBS (recv)
#else
import qualified Network.Socket.ByteString as NBS (recv)
#endif
import Control.Concurrent (ThreadId)
import Control.Monad (forever, when)
import Control.Exception (SomeException, catch, bracketOnError, throwIO, mask_)
import Control.Applicative ((<$>))
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (length, concat, null)
import Data.Int (Int32)
import Data.ByteString.Lazy.Internal (smallChunkSize)
-- | Start a TCP server in a background thread.
--
-- Setup:
--
-- * The address is resolved with 'N.getAddrInfo' using 'N.defaultHints', and
--   the first result is used.
-- * A stream socket is created and optionally marked 'N.ReuseAddr'.
-- * The socket is bound and put into listening mode with the given backlog.
-- * If any setup step throws, the socket is closed and the exception
--   propagates to the caller.
--
-- The returned thread accepts connections in a loop and passes each accepted
-- socket to the request handler. The handler runs on the accept thread, so
-- the next 'N.accept' only happens after it returns.
--
-- If 'N.accept' or the handler throws, three things happen in order:
--
-- 1. The accepted socket (if any) is closed.
-- 2. The listening socket is closed.
-- 3. The termination handler is called with the exception.
--
-- Because the exception is caught at type 'SomeException', this also covers
-- asynchronous exceptions thrown to the accept thread.
forkServer :: N.HostName               -- ^ host to bind to
           -> N.ServiceName            -- ^ service (port) to bind to
           -> Int                      -- ^ listen backlog
           -> Bool                     -- ^ set 'N.ReuseAddr' before binding?
           -> (SomeException -> IO ()) -- ^ called when the accept loop dies
           -> (N.Socket -> IO ())      -- ^ called for every accepted socket
           -> IO ThreadId
forkServer host port backlog reuseAddr terminationHandler requestHandler = do
    serverAddr : _ <- N.getAddrInfo (Just N.defaultHints) (Just host) (Just port)
    let openListener =
          N.socket (N.addrFamily serverAddr) N.Stream N.defaultProtocol
    -- The listening socket is closed here only if setup fails; once the
    -- accept thread is forked, that thread is responsible for closing it.
    bracketOnError openListener tryCloseSocket (startListening serverAddr)
  where
    -- Configure, bind and listen on the socket, then fork the accept loop.
    startListening serverAddr listener = do
      when reuseAddr $ N.setSocketOption listener N.ReuseAddr 1
      N.bindSocket listener (N.addrAddress serverAddr)
      N.listen listener backlog
      -- Fork with asynchronous exceptions masked, so the handler is installed
      -- before any exception can be delivered. Only the accept loop itself
      -- runs unmasked.
      mask_ $ forkIOWithUnmask $ \restore ->
        restore (forever (serveOne listener)) `catch` \err -> do
          tryCloseSocket listener
          terminationHandler err

    -- Accept one connection and hand it to the request handler. The accepted
    -- socket is closed only if accepting or handling throws.
    serveOne :: N.Socket -> IO ()
    serveOne listener =
      bracketOnError (N.accept listener) (tryCloseSocket . fst) (requestHandler . fst)
-- | Receive a length-prefixed message.
--
-- A length header is read with 'recvInt32', then exactly that many bytes are
-- read with 'recvExact'. The payload is returned as a list of chunks.
recvWithLength :: N.Socket -> IO [ByteString]
recvWithLength sock = do
  payloadLen <- recvInt32 sock
  recvExact sock payloadLen
-- | Read exactly 4 bytes from the socket and decode them with 'decodeInt32'.
recvInt32 :: Num a => N.Socket -> IO a
recvInt32 sock = do
  header <- recvExact sock 4
  return (decodeInt32 (BS.concat header))
-- | Best-effort close of a socket.
--
-- 'N.sClose' is wrapped in 'tryIO' and the result is discarded, so any
-- failure that 'tryIO' catches is ignored.
tryCloseSocket :: N.Socket -> IO ()
tryCloseSocket = void . tryIO . N.sClose
-- | Receive exactly the given number of bytes from the socket.
--
-- The bytes are returned as the list of chunks in the order they arrived;
-- concatenate them to get the payload. Each underlying 'NBS.recv' asks for at
-- most 'smallChunkSize' bytes.
--
-- Throws a 'userError' in two cases:
--
-- * the requested length is negative;
-- * a receive returns no data before all bytes have arrived, which is
--   reported as "Socket closed".
recvExact :: N.Socket  -- ^ socket to read from
          -> Int32     -- ^ number of bytes to read
          -> IO [ByteString]
recvExact _ len | len < 0 = throwIO (userError "recvExact: Negative length")
recvExact sock len = go [] len
  where
    -- Chunks are accumulated in reverse; 'l' is the number of bytes still
    -- missing.
    go :: [ByteString] -> Int32 -> IO [ByteString]
    go acc 0 = return (reverse acc)
    go acc l = do
      bs <- NBS.recv sock (fromIntegral l `min` smallChunkSize)
      if BS.null bs
        then throwIO (userError "recvExact: Socket closed")
        -- Subtract what was just received from the remaining count.
        else go (bs : acc) (l - fromIntegral (BS.length bs))