-- | A socket paired with an in-memory pushback buffer.
--
-- Bytes handed to 'pushback' are kept in an 'IORef' and are returned by
-- later 'socketRecv' calls before any new data is read from the socket.
-- The 'BufferedSocket' constructor is not exported; values are built with
-- 'makeBuffered'.
module Util.BufferedSocket ( BufferedSocket()
, makeBuffered
, socketPort
, socketRecv
, pushback
, socketSend
, socketClose
)
where
import Control.Monad (unless)
import qualified Data.ByteString as BS (ByteString, empty, append, null, length, splitAt)
import qualified Data.ByteString.Lazy as BL (ByteString)
import qualified Network.Socket as S (Socket, PortNumber, socketPort, close)
import qualified Network.Socket.ByteString as NBS (recv)
import qualified Network.Socket.ByteString.Lazy as NBL (sendAll)
import Data.IORef (IORef, newIORef, readIORef, writeIORef, modifyIORef')
import Util.IOx
-- | A network socket bundled with a mutable buffer of bytes.
--
-- The first component is the underlying socket. The second is an 'IORef'
-- holding bytes that were pushed back via 'pushback' and have not yet been
-- handed out again by 'socketRecv'; it starts out empty (see 'makeBuffered').
newtype BufferedSocket = BufferedSocket (S.Socket, IORef BS.ByteString)
-- | Wrap a socket in a 'BufferedSocket' whose pushback buffer is initially
-- empty. The socket itself is not touched; only a fresh 'IORef' is allocated.
makeBuffered :: S.Socket -> IOx BufferedSocket
makeBuffered sock = attach <$> liftIOx (newIORef BS.empty)
  where
    -- Pair the caller's socket with the freshly created buffer reference.
    attach ref = BufferedSocket (sock, ref)
-- | Query the port number of the underlying socket via 'S.socketPort'.
-- The pushback buffer is not consulted.
socketPort :: BufferedSocket -> IOx S.PortNumber
socketPort (BufferedSocket (sock, _)) = toIOx (S.socketPort sock)
-- | Receive at most @len@ bytes, preferring previously pushed-back data.
--
-- * @len < 0@: calls 'error' with a @"Bad length: ..."@ message.
-- * @len == 0@: returns an empty 'BS.ByteString' without touching the
--   buffer or the socket.
-- * Otherwise the pushback buffer is inspected:
--
--     * if it is empty, a single 'NBS.recv' of @len@ bytes is issued on the
--       socket and its result is returned unchanged;
--     * if it holds fewer than @len@ bytes, the whole buffer is returned and
--       the buffer is cleared (the socket is /not/ read to top up the result);
--     * otherwise the first @len@ bytes are returned and the remainder stays
--       buffered.
socketRecv :: BufferedSocket -> Int -> IOx BS.ByteString
socketRecv (BufferedSocket (sock, bufIO)) len
  | len < 0 = error $ "Bad length: " ++ show len
  | len == 0 = return BS.empty
  | otherwise = toIOx (readIORef bufIO >>= serve)
  where
    -- Decide where the bytes come from based on what is currently buffered.
    serve pending
      | BS.null pending = NBS.recv sock len
      | len > BS.length pending = drain pending BS.empty
      | otherwise = uncurry drain (BS.splitAt len pending)

    -- Store what is left in the buffer, then hand out the requested chunk.
    drain chunk rest = writeIORef bufIO rest >> return chunk
-- | Return bytes to the front of the pushback buffer so that the next
-- 'socketRecv' sees them before anything already buffered.
--
-- Pushing back an empty 'BS.ByteString' is a no-op: the buffer reference is
-- not modified at all in that case.
pushback :: BufferedSocket -> BS.ByteString -> IOx ()
pushback (BufferedSocket (_, bufIO)) buf0 =
  unless (BS.null buf0) (toIOx prepend)
  where
    -- Strictly replace the buffer with @buf0@ followed by its old contents.
    prepend = modifyIORef' bufIO (\pending -> BS.append buf0 pending)
-- | Send a lazy 'BL.ByteString' on the underlying socket using
-- 'NBL.sendAll'. The pushback buffer is not involved.
socketSend :: BufferedSocket -> BL.ByteString -> IOx ()
socketSend (BufferedSocket (sock, _)) bl = toIOx (NBL.sendAll sock bl)
-- | Discard any pushed-back bytes and then close the underlying socket
-- with 'S.close'. The buffer is cleared first, before the close is attempted.
socketClose :: BufferedSocket -> IOx ()
socketClose (BufferedSocket (sock, bufIO)) =
  toIOx (writeIORef bufIO BS.empty >> S.close sock)