{-# LANGUAGE RecordWildCards, DeriveDataTypeable, OverloadedStrings #-}
module Database.Redis.ProtocolPipelining (
Connection,
connect, enableTLS, beginReceiving, disconnect, request, send, recv, flush,
ConnectionLostException(..),
HostName, PortNumber
) where
import Prelude
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (race)
import Control.Concurrent.MVar
import Control.Exception
import Control.Monad
import qualified Scanner
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.IORef
import Data.Typeable
import qualified Network.Socket as NS
import qualified Network.TLS as TLS
import System.IO
import System.IO.Error
import System.IO.Unsafe
import Database.Redis.Protocol
-- | Host name accepted by 'connect' (alias of the "Network.Socket" type).
type HostName = NS.HostName

-- | Port number accepted by 'connect' (alias of the "Network.Socket" type).
type PortNumber = NS.PortNumber
-- | Transport underneath a 'Connection'.
data ConnectionContext
  = NormalHandle Handle
    -- ^ Plain socket handle, as created by 'connect'.
  | TLSContext TLS.Context
    -- ^ TLS context created over the plain handle by 'enableTLS'.
-- | A pipelined connection: requests are written with 'send' and replies
-- are taken, in order, with 'recv'.
data Connection = Conn
  { connCtx        :: ConnectionContext
    -- ^ Transport used for writing requests and reading replies.
  , connReplies    :: IORef [Reply]
    -- ^ Lazy reply stream from which 'recv' takes the next reply.
  , connPending    :: IORef [Reply]
    -- ^ Same stream, trimmed as replies are parsed; 'send' forces its
    -- head when too many requests are outstanding.
  , connPendingCnt :: IORef Int
    -- ^ Requests sent whose replies have not yet been parsed.
  }
-- | Raised when the connection is unusable: an 'IOError' while sending or
-- reading, or input that cannot be parsed as a reply.
data ConnectionLostException = ConnectionLost
  deriving (Show, Typeable)

instance Exception ConnectionLostException
-- | Progress marker recorded by 'connect' and reported in a
-- 'ConnectTimeout'.
data ConnectPhase
  = PhaseUnknown     -- ^ No phase recorded yet.
  | PhaseResolve     -- ^ Host name resolution.
  | PhaseOpenSocket  -- ^ Opening the socket.
  deriving (Show)

-- | Thrown by 'connect' when its optional timeout elapses, carrying the
-- phase that was recorded at that moment.
data ConnectTimeout = ConnectTimeout ConnectPhase
  deriving (Show, Typeable)

instance Exception ConnectTimeout
-- | Resolve a host name and port into candidate addresses for a TCP
-- (stream) connection.
--
-- The hints request stream sockets only. With the default hints the
-- resolver may return one entry per socket type (stream, datagram, raw)
-- for every address. 'connectSocket' then always opens a stream socket,
-- so it would retry the same unreachable address several times.
getHostAddrInfo :: NS.HostName -> NS.PortNumber -> IO [NS.AddrInfo]
getHostAddrInfo hostname port =
  NS.getAddrInfo (Just hints) (Just hostname) (Just (show port))
  where
    hints = NS.defaultHints { NS.addrSocketType = NS.Stream }
-- | Try each address in order and return a socket connected to the first
-- one that accepts the connection.
--
-- Only an 'IOException' (from creating or connecting the socket) moves
-- on to the next address. Any other exception closes the socket and
-- propagates immediately. In particular, the asynchronous cancellation
-- that 'race' uses to enforce the connect timeout must not be swallowed
-- and retried. If every address fails, the last 'IOException' is
-- rethrown. An empty list raises an 'IOError'.
connectSocket :: [NS.AddrInfo] -> IO NS.Socket
connectSocket [] =
  ioError (userError "connectSocket: no addresses to connect to")
connectSocket (addrInfo : rest) = do
  result <- try attempt :: IO (Either IOException NS.Socket)
  case result of
    Right sock -> return sock
    Left err
      | null rest -> throwIO err
      | otherwise -> connectSocket rest
  where
    -- Open a socket for this address and connect it; the socket is closed
    -- if connecting fails for any reason (including async exceptions).
    attempt =
      bracketOnError
        (NS.socket (NS.addrFamily addrInfo) NS.Stream NS.defaultProtocol)
        NS.close
        (\sock -> NS.connect sock (NS.addrAddress addrInfo) >> return sock)
-- | Open a plain TCP connection to the given host and port.
--
-- With @Just micros@, connecting is raced against a delay of @micros@
-- microseconds. If the delay wins, a 'ConnectTimeout' is thrown carrying
-- the last 'ConnectPhase' recorded. The socket is closed if configuring
-- it fails, and the handle is closed if the remaining setup fails.
connect :: NS.HostName -> NS.PortNumber -> Maybe Int -> IO Connection
connect hostName portNumber timeoutOpt =
  bracketOnError hConnect hClose $ \h -> do
    hSetBinaryMode h True
    connReplies <- newIORef []
    connPending <- newIORef []
    connPendingCnt <- newIORef 0
    let connCtx = NormalHandle h
    return Conn{..}
  where
    hConnect = do
      phaseMVar <- newMVar PhaseUnknown
      let doConnect = hConnect' phaseMVar
      case timeoutOpt of
        Nothing -> doConnect
        Just micros -> do
          result <- race doConnect (threadDelay micros)
          case result of
            Left h -> return h
            Right () -> do
              phase <- readMVar phaseMVar
              errConnectTimeout phase
    -- Each phase is recorded *before* the step that may block. A timeout
    -- then reports the step that was actually in progress. (Previously
    -- both markers were set only after connecting had finished, so a
    -- timeout always reported 'PhaseUnknown'.)
    hConnect' mvar = do
      void $ swapMVar mvar PhaseResolve
      addrInfo <- getHostAddrInfo hostName portNumber
      void $ swapMVar mvar PhaseOpenSocket
      bracketOnError (connectSocket addrInfo) NS.close $ \sock -> do
        NS.setSocketOption sock NS.KeepAlive 1
        NS.socketToHandle sock ReadWriteMode
-- | Upgrade a plain connection to TLS: create a TLS context over its
-- handle with the given client parameters and run the handshake. A
-- connection that already uses TLS is returned unchanged.
enableTLS :: TLS.ClientParams -> Connection -> IO Connection
enableTLS tlsParams conn =
  case connCtx conn of
    TLSContext _ -> return conn
    NormalHandle h -> do
      ctx <- TLS.contextNew h tlsParams
      TLS.handshake ctx
      return conn { connCtx = TLSContext ctx }
-- | Create the reply stream for the connection and install it both as the
-- stream 'recv' reads from and as the pending stream 'send' inspects.
beginReceiving :: Connection -> IO ()
beginReceiving conn = do
  replies <- connGetReplies conn
  mapM_ (`writeIORef` replies) [connReplies conn, connPending conn]
-- | Close the connection.
--
-- For a plain handle, the handle is closed if it is still open. For TLS,
-- the session is ended with 'TLS.bye' and the context is then closed.
-- 'TLS.contextClose' runs even when 'TLS.bye' throws (for example
-- because the peer has already gone away), so the context is not leaked.
-- The exception from 'TLS.bye' is still propagated.
disconnect :: Connection -> IO ()
disconnect Conn{..} =
  case connCtx of
    NormalHandle h -> do
      open <- hIsOpen h
      when open $ hClose h
    TLSContext ctx ->
      TLS.bye ctx `finally` TLS.contextClose ctx
-- | Write a raw request to the connection; output is not flushed here.
--
-- An 'IOError' during the write becomes 'ConnectionLost'. The write also
-- bumps the outstanding-request counter. Once the counter is at least
-- 1000, the oldest pending reply is forced, which reads from the
-- connection until that reply has been parsed.
send :: Connection -> S.ByteString -> IO ()
send Conn{..} s = do
  ioErrorToConnLost $ case connCtx of
    NormalHandle h -> S.hPut h s
    TLSContext ctx -> TLS.sendData ctx (L.fromStrict s)
  outstanding <- atomicModifyIORef' connPendingCnt $ \n -> (n + 1, n + 1)
  when (outstanding >= 1000) $ do
    oldest : _ <- readIORef connPending
    void (evaluate oldest)
-- | Take the next reply from the stream installed by 'beginReceiving',
-- leaving the remainder for later calls.
recv :: Connection -> IO Reply
recv Conn{..} = do
  next : later <- readIORef connReplies
  writeIORef connReplies later
  pure next
-- | Push buffered output on the connection's transport to the server.
flush :: Connection -> IO ()
flush conn = case connCtx conn of
  TLSContext ctx -> TLS.contextFlush ctx
  NormalHandle h -> hFlush h
-- | Send one request and take the next reply ('send' followed by 'recv').
request :: Connection -> S.ByteString -> IO Reply
request conn req = do
  send conn req
  recv conn
-- | Build the lazy, unbounded stream of replies read from the connection.
--
-- Each reply is produced under 'unsafeInterleaveIO'. Reading and parsing
-- happen only when that reply is demanded, and demanding a reply first
-- forces the previous one, so replies are parsed strictly in order. When
-- a reply is parsed, the head of 'connPending' is dropped and
-- 'connPendingCnt' is decremented (never below zero).
--
-- When a reply is forced, 'ConnectionLost' is thrown if the input cannot
-- be parsed or if reading fails with an 'IOError'.
connGetReplies :: Connection -> IO [Reply]
connGetReplies conn@Conn{..} = go S.empty (SingleLine "previous of first")
  where
    -- rest: input left over after the previous parse.
    -- previous: the prior reply; a placeholder seeds the first call.
    go rest previous = do
      -- Lazy pattern + unsafeInterleaveIO: nothing runs until r or rest'
      -- is demanded.
      ~(r, rest') <- unsafeInterleaveIO $ do
        -- Force the previous reply first so the socket is consumed in
        -- order.
        previous `seq` return ()
        scanResult <- Scanner.scanWith readMore reply rest
        case scanResult of
          Scanner.Fail{} -> errConnClosed
          -- NOTE(review): the message still mentions parseWith/Partial from
          -- an older parser; this branch is presumably unreachable because
          -- scanWith keeps resupplying input via readMore — confirm.
          Scanner.More{} -> error "Hedis: parseWith returned Partial"
          Scanner.Done rest' r -> do
            atomicModifyIORef' connPending $ \(_:rs) -> (rs, ())
            atomicModifyIORef' connPendingCnt $ \n -> (max 0 (n-1), ())
            return (r, rest')
      rs <- unsafeInterleaveIO (go rest' r)
      return (r:rs)
    -- Flush buffered requests ('send' does not flush) before blocking on a
    -- read, then fetch the next chunk (at most 4096 bytes on a plain
    -- handle). IOErrors are turned into ConnectionLost.
    readMore = ioErrorToConnLost $ do
      flush conn
      case connCtx of
        NormalHandle h -> S.hGetSome h 4096
        TLSContext ctx -> TLS.recvData ctx
-- | Run an action, replacing any 'IOError' it raises with 'ConnectionLost'.
ioErrorToConnLost :: IO a -> IO a
ioErrorToConnLost action = catchIOError action (\_ -> errConnClosed)
-- | Abort the current action with 'ConnectionLost'.
errConnClosed :: IO a
errConnClosed = throwIO ConnectionLost
-- | Abort the current action with a 'ConnectTimeout' for the given phase.
errConnectTimeout :: ConnectPhase -> IO a
errConnectTimeout = throwIO . ConnectTimeout