{-# LANGUAGE TypeSynonymInstances #-}
module Network.TCP
( Connection
, EndPoint(..)
, openTCPPort
, isConnectedTo
, openTCPConnection
, socketConnection
, isTCPConnectedTo
, HandleStream
, HStream(..)
, StreamHooks(..)
, nullHooks
, setStreamHooks
, getStreamHooks
, hstreamToConnection
) where
import Network.Socket
( Socket, SocketOption(KeepAlive)
, SocketType(Stream), connect
, shutdown, ShutdownCmd(..)
, setSocketOption, getPeerName
, socket, Family(AF_UNSPEC), defaultProtocol, getAddrInfo
, defaultHints, addrFamily, withSocketsDo
, addrSocketType, addrAddress
)
import qualified Network.Socket
( close )
import qualified Network.Stream as Stream
( Stream(readBlock, readLine, writeBlock, close, closeOnEnd) )
import Network.Stream
( ConnError(..)
, Result
, failWith
, failMisc
)
import Network.BufferType
import Network.HTTP.Base ( catchIO )
import Network.Socket ( socketToHandle )
import Data.Char ( toLower )
import Data.Word ( Word8 )
import Control.Concurrent
import Control.Exception ( onException )
import Control.Monad ( liftM, when )
import System.IO ( Handle, hFlush, IOMode(..), hClose )
import System.IO.Error ( isEOFError )
import qualified Data.ByteString as Strict
import qualified Data.ByteString.Lazy as Lazy
newtype Connection = Connection (HandleStream String)
newtype HandleStream a = HandleStream {getRef :: MVar (Conn a)}
data EndPoint = EndPoint { epHost :: String, epPort :: Int }
instance Eq EndPoint where
EndPoint host1 port1 == EndPoint host2 port2 =
map toLower host1 == map toLower host2 && port1 == port2
data Conn a
= MkConn { connSock :: ! Socket
, connHandle :: Handle
, connBuffer :: BufferOp a
, connInput :: Maybe a
, connEndPoint :: EndPoint
, connHooks :: Maybe (StreamHooks a)
, connCloseEOF :: Bool
}
| ConnClosed
deriving(Eq)
hstreamToConnection :: HandleStream String -> Connection
hstreamToConnection h = Connection h
connHooks' :: Conn a -> Maybe (StreamHooks a)
connHooks' ConnClosed{} = Nothing
connHooks' x = connHooks x
data StreamHooks ty
= StreamHooks
{ hook_readLine :: (ty -> String) -> Result ty -> IO ()
, hook_readBlock :: (ty -> String) -> Int -> Result ty -> IO ()
, hook_writeBlock :: (ty -> String) -> ty -> Result () -> IO ()
, hook_close :: IO ()
, hook_name :: String
}
instance Eq ty => Eq (StreamHooks ty) where
(==) _ _ = True
nullHooks :: StreamHooks ty
nullHooks = StreamHooks
{ hook_readLine = \ _ _ -> return ()
, hook_readBlock = \ _ _ _ -> return ()
, hook_writeBlock = \ _ _ _ -> return ()
, hook_close = return ()
, hook_name = ""
}
setStreamHooks :: HandleStream ty -> StreamHooks ty -> IO ()
setStreamHooks h sh = modifyMVar_ (getRef h) (\ c -> return c{connHooks=Just sh})
getStreamHooks :: HandleStream ty -> IO (Maybe (StreamHooks ty))
getStreamHooks h = readMVar (getRef h) >>= return.connHooks
class BufferType bufType => HStream bufType where
openStream :: String -> Int -> IO (HandleStream bufType)
openSocketStream :: String -> Int -> Socket -> IO (HandleStream bufType)
readLine :: HandleStream bufType -> IO (Result bufType)
readBlock :: HandleStream bufType -> Int -> IO (Result bufType)
writeBlock :: HandleStream bufType -> bufType -> IO (Result ())
close :: HandleStream bufType -> IO ()
closeQuick :: HandleStream bufType -> IO ()
closeOnEnd :: HandleStream bufType -> Bool -> IO ()
instance HStream Strict.ByteString where
openStream = openTCPConnection
openSocketStream = socketConnection
readBlock c n = readBlockBS c n
readLine c = readLineBS c
writeBlock c str = writeBlockBS c str
close c = closeIt c Strict.null True
closeQuick c = closeIt c Strict.null False
closeOnEnd c f = closeEOF c f
instance HStream Lazy.ByteString where
openStream = \ a b -> openTCPConnection_ a b True
openSocketStream = \ a b c -> socketConnection_ a b c True
readBlock c n = readBlockBS c n
readLine c = readLineBS c
writeBlock c str = writeBlockBS c str
close c = closeIt c Lazy.null True
closeQuick c = closeIt c Lazy.null False
closeOnEnd c f = closeEOF c f
instance Stream.Stream Connection where
readBlock (Connection c) = Network.TCP.readBlock c
readLine (Connection c) = Network.TCP.readLine c
writeBlock (Connection c) = Network.TCP.writeBlock c
close (Connection c) = Network.TCP.close c
closeOnEnd (Connection c) f = Network.TCP.closeEOF c f
instance HStream String where
openStream = openTCPConnection
openSocketStream = socketConnection
readBlock ref n = readBlockBS ref n
readLine ref = readLineBS ref
writeBlock ref str = writeBlockBS ref str
close c = closeIt c null True
closeQuick c = closeIt c null False
closeOnEnd c f = closeEOF c f
openTCPPort :: String -> Int -> IO Connection
openTCPPort uri port = openTCPConnection uri port >>= return.Connection
openTCPConnection :: BufferType ty => String -> Int -> IO (HandleStream ty)
openTCPConnection uri port = openTCPConnection_ uri port False
openTCPConnection_ :: BufferType ty => String -> Int -> Bool -> IO (HandleStream ty)
openTCPConnection_ uri port stashInput = do
let fixedUri =
case uri of
'[':(rest@(c:_)) | last rest == ']'
-> if c == 'v' || c == 'V'
then error $ "Unsupported post-IPv6 address " ++ uri
else init rest
_ -> uri
addrinfos <- withSocketsDo $ getAddrInfo (Just $ defaultHints { addrFamily = AF_UNSPEC, addrSocketType = Stream }) (Just fixedUri) (Just . show $ port)
case addrinfos of
[] -> fail "openTCPConnection: getAddrInfo returned no address information"
(a:_) -> do
s <- socket (addrFamily a) Stream defaultProtocol
onException (do
setSocketOption s KeepAlive 1
connect s (addrAddress a)
socketConnection_ fixedUri port s stashInput
) (Network.Socket.close s)
socketConnection :: BufferType ty
=> String
-> Int
-> Socket
-> IO (HandleStream ty)
socketConnection hst port sock = socketConnection_ hst port sock False
socketConnection_ :: BufferType ty
=> String
-> Int
-> Socket
-> Bool
-> IO (HandleStream ty)
socketConnection_ hst port sock stashInput = do
h <- socketToHandle sock ReadWriteMode
mb <- case stashInput of { True -> liftM Just $ buf_hGetContents bufferOps h; _ -> return Nothing }
let conn = MkConn
{ connSock = sock
, connHandle = h
, connBuffer = bufferOps
, connInput = mb
, connEndPoint = EndPoint hst port
, connHooks = Nothing
, connCloseEOF = False
}
v <- newMVar conn
return (HandleStream v)
closeConnection :: HStream a => HandleStream a -> IO Bool -> IO ()
closeConnection ref readL = do
c <- readMVar (getRef ref)
closeConn c `catchIO` (\_ -> return ())
modifyMVar_ (getRef ref) (\ _ -> return ConnClosed)
where
closeConn ConnClosed = return ()
closeConn conn = do
let sk = connSock conn
hFlush (connHandle conn)
shutdown sk ShutdownSend
suck readL
hClose (connHandle conn)
shutdown sk ShutdownReceive
Network.Socket.close sk
suck :: IO Bool -> IO ()
suck rd = do
f <- rd
if f then return () else suck rd
isConnectedTo :: Connection -> EndPoint -> IO Bool
isConnectedTo (Connection conn) endPoint = isTCPConnectedTo conn endPoint
isTCPConnectedTo :: HandleStream ty -> EndPoint -> IO Bool
isTCPConnectedTo conn endPoint = do
v <- readMVar (getRef conn)
case v of
ConnClosed -> return False
_
| connEndPoint v == endPoint ->
catchIO (getPeerName (connSock v) >> return True) (const $ return False)
| otherwise -> return False
readBlockBS :: HStream a => HandleStream a -> Int -> IO (Result a)
readBlockBS ref n = onNonClosedDo ref $ \ conn -> do
x <- bufferGetBlock ref n
maybe (return ())
(\ h -> hook_readBlock h (buf_toStr $ connBuffer conn) n x)
(connHooks' conn)
return x
readLineBS :: HStream a => HandleStream a -> IO (Result a)
readLineBS ref = onNonClosedDo ref $ \ conn -> do
x <- bufferReadLine ref
maybe (return ())
(\ h -> hook_readLine h (buf_toStr $ connBuffer conn) x)
(connHooks' conn)
return x
writeBlockBS :: HandleStream a -> a -> IO (Result ())
writeBlockBS ref b = onNonClosedDo ref $ \ conn -> do
x <- bufferPutBlock (connBuffer conn) (connHandle conn) b
maybe (return ())
(\ h -> hook_writeBlock h (buf_toStr $ connBuffer conn) b x)
(connHooks' conn)
return x
closeIt :: HStream ty => HandleStream ty -> (ty -> Bool) -> Bool -> IO ()
closeIt c p b = do
closeConnection c (if b
then readLineBS c >>= \ x -> case x of { Right xs -> return (p xs); _ -> return True}
else return True)
conn <- readMVar (getRef c)
maybe (return ())
(hook_close)
(connHooks' conn)
closeEOF :: HandleStream ty -> Bool -> IO ()
closeEOF c flg = modifyMVar_ (getRef c) (\ co -> return co{connCloseEOF=flg})
bufferGetBlock :: HStream a => HandleStream a -> Int -> IO (Result a)
bufferGetBlock ref n = onNonClosedDo ref $ \ conn -> do
case connInput conn of
Just c -> do
let (a,b) = buf_splitAt (connBuffer conn) n c
modifyMVar_ (getRef ref) (\ co -> return co{connInput=Just b})
return (return a)
_ -> do
catchIO (buf_hGet (connBuffer conn) (connHandle conn) n >>= return.return)
(\ e ->
if isEOFError e
then do
when (connCloseEOF conn) $ catchIO (closeQuick ref) (\ _ -> return ())
return (return (buf_empty (connBuffer conn)))
else return (failMisc (show e)))
bufferPutBlock :: BufferOp a -> Handle -> a -> IO (Result ())
bufferPutBlock ops h b =
catchIO (buf_hPut ops h b >> hFlush h >> return (return ()))
(\ e -> return (failMisc (show e)))
bufferReadLine :: HStream a => HandleStream a -> IO (Result a)
bufferReadLine ref = onNonClosedDo ref $ \ conn -> do
case connInput conn of
Just c -> do
let (a,b0) = buf_span (connBuffer conn) (/='\n') c
let (newl,b1) = buf_splitAt (connBuffer conn) 1 b0
modifyMVar_ (getRef ref) (\ co -> return co{connInput=Just b1})
return (return (buf_append (connBuffer conn) a newl))
_ -> catchIO
(buf_hGetLine (connBuffer conn) (connHandle conn) >>=
return . return . appendNL (connBuffer conn))
(\ e ->
if isEOFError e
then do
when (connCloseEOF conn) $ catchIO (closeQuick ref) (\ _ -> return ())
return (return (buf_empty (connBuffer conn)))
else return (failMisc (show e)))
where
appendNL ops b = buf_snoc ops b nl
nl :: Word8
nl = fromIntegral (fromEnum '\n')
onNonClosedDo :: HandleStream a -> (Conn a -> IO (Result b)) -> IO (Result b)
onNonClosedDo h act = do
x <- readMVar (getRef h)
case x of
ConnClosed{} -> return (failWith ErrorClosed)
_ -> act x