{-# LANGUAGE CPP #-}
module Network.TLS.Backend
( HasBackend(..)
, Backend(..)
) where
import Network.TLS.Imports
import qualified Data.ByteString as B
import System.IO (Handle, hSetBuffering, BufferMode(..), hFlush, hClose)
#ifdef INCLUDE_NETWORK
import qualified Network.Socket as Network (Socket, close)
import qualified Network.Socket.ByteString as Network
#if defined(__GLASGOW_HASKELL__) && WINDOWS
import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
import qualified Control.Exception as E
#endif
#endif
#ifdef INCLUDE_HANS
import qualified Data.ByteString.Lazy as L
import qualified Hans.NetworkStack as Hans
#endif
-- | A record of the four I/O actions that make up a connection backend.
--
-- The instances in this module build it from sockets and handles, but any
-- set of actions with these types can be packed into a 'Backend' directly.
data Backend = Backend
    { backendFlush :: IO ()
      -- ^ Flush action. The socket instances in this module use a no-op;
      -- the 'Handle' instance uses 'hFlush'.
    , backendClose :: IO ()
      -- ^ Close the underlying connection.
    , backendSend  :: ByteString -> IO ()
      -- ^ Send the given bytes.
    , backendRecv  :: Int -> IO ByteString
      -- ^ Receive bytes; the argument is the number of bytes requested.
    }
-- | Types from which a 'Backend' can be obtained.
class HasBackend a where
    -- | Instance-specific setup action. Among the instances in this module
    -- only the 'Handle' one does any work (it switches the handle to
    -- 'NoBuffering'); the others return immediately.
    initializeBackend :: a -> IO ()

    -- | Build the 'Backend' record of I/O actions for the given value.
    getBackend :: a -> Backend
-- | A 'Backend' is trivially its own backend: there is nothing to set up
-- and the value is returned unchanged.
instance HasBackend Backend where
    initializeBackend _ = return ()
    getBackend = id
#if defined(__GLASGOW_HASKELL__) && WINDOWS
#define SOCKET_ACCEPT_RECV_WORKAROUND
#endif
#ifdef INCLUDE_NETWORK
-- | Receive at most the given number of bytes from a socket.
--
-- Without @SOCKET_ACCEPT_RECV_WORKAROUND@ this is exactly 'Network.recv'.
--
-- With the workaround enabled, the receive runs in a forked thread and the
-- caller waits for its result on an 'MVar'. An 'E.IOException' raised by
-- the receive is turned into an empty result, which callers treat like an
-- empty read.
-- NOTE(review): presumably this exists so that the calling thread stays
-- interruptible while a Windows @recv@ is blocked -- confirm against the
-- GHC issue tracker.
--
-- The whole definition is guarded by @INCLUDE_NETWORK@ because it is only
-- meaningful (and only type-checks) when the network imports are present.
safeRecv :: Network.Socket -> Int -> IO ByteString
#ifndef SOCKET_ACCEPT_RECV_WORKAROUND
safeRecv = Network.recv
#else
safeRecv sock maxBytes = do
    result <- newEmptyMVar
    _ <- forkIO $ (Network.recv sock maxBytes `E.catch` onIOError) >>= putMVar result
    takeMVar result
  where
    -- A local signature pins the exception type, so no ScopedTypeVariables
    -- pattern annotation is needed.
    onIOError :: E.IOException -> IO ByteString
    onIOError _ = return B.empty
#endif
#endif
#ifdef INCLUDE_NETWORK
-- | Backend over a network socket.
--
-- No initialization is performed and flushing is a no-op. Sending goes
-- through 'Network.sendAll'; closing goes through 'Network.close'.
--
-- Receiving keeps calling 'safeRecv' until the requested number of bytes
-- has been gathered. If a read comes back empty before that, the bytes
-- gathered so far are returned, so the result may be shorter than
-- requested.
instance HasBackend Network.Socket where
    initializeBackend _ = return ()
    getBackend sock = Backend
        { backendFlush = return ()
        , backendClose = Network.close sock
        , backendSend  = Network.sendAll sock
        , backendRecv  = recvAll
        }
      where
        recvAll n = B.concat <$> loop n

        -- Collect chunks until nothing is left to read or a read is empty.
        loop 0 = return []
        loop left = do
            r <- safeRecv sock left
            if B.null r
                then return []
                else (r :) <$> loop (left - B.length r)
#endif
#ifdef INCLUDE_HANS
-- | Backend over a HaNS socket.
--
-- No initialization is performed and flushing is a no-op. Sending retries
-- with the unsent remainder until everything has been written, and stops
-- early if a write reports zero bytes. Receiving accumulates chunks until
-- the requested amount is reached or an empty chunk is read.
instance HasBackend Hans.Socket where
    initializeBackend _ = return ()
    getBackend sock = Backend (return ()) (Hans.close sock) sendAll recvAll
      where
        sendAll bytes = do
            sent <- fromIntegral <$> Hans.sendBytes sock (L.fromStrict bytes)
            -- Keep going only after a partial, non-zero write.
            if sent /= 0 && sent /= B.length bytes
                then sendAll (B.drop sent bytes)
                else return ()

        recvAll n = L.toStrict <$> gather (fromIntegral n) L.empty

        gather remaining acc
            | remaining == 0 = return acc
            | otherwise = do
                chunk <- Hans.recvBytes sock remaining
                if L.null chunk
                    then return acc
                    else gather (remaining - L.length chunk) (L.append acc chunk)
#endif
-- | Backend over a 'Handle'.
--
-- Initialization switches the handle to 'NoBuffering'. The backend actions
-- map directly onto 'hFlush', 'hClose', 'B.hPut' and 'B.hGet'.
instance HasBackend Handle where
    initializeBackend handle = hSetBuffering handle NoBuffering
    getBackend handle = Backend
        { backendFlush = hFlush handle
        , backendClose = hClose handle
        , backendSend  = B.hPut handle
        , backendRecv  = B.hGet handle
        }