{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
module Database.Redis.IO.Connection
( Connection
, settings
, resolve
, connect
, close
, request
, sync
, send
, receive
) where
import Control.Applicative
import Control.Exception
import Control.Monad
import Data.Attoparsec.ByteString hiding (Result)
import Data.ByteString (ByteString)
import Data.ByteString.Lazy (toChunks)
import Data.Foldable (foldlM, toList)
import Data.IORef
import Data.Maybe (isJust)
import Data.Redis
import Data.Sequence (Seq, (|>))
import Data.Int
import Data.Word
import Foreign.C.Types (CInt (..))
import Database.Redis.IO.Settings
import Database.Redis.IO.Types
import Database.Redis.IO.Timeouts (TimeoutManager, withTimeout)
import Network.Socket hiding (connect, close, send, recv)
import Network.Socket.ByteString (recv, sendMany)
#if MIN_VERSION_network(3,0,0)
import System.IO.Unsafe (unsafePerformIO)
#endif
import System.Logger hiding (Settings, settings, close)
import System.Timeout
import Prelude
import qualified Data.ByteString.Lazy.Char8 as Char8
import qualified Data.Sequence as Seq
import qualified Network.Socket as S
-- | One open socket to a Redis server, plus the mutable state used to
-- pipeline requests over it.
data Connection = Connection
    { settings :: !Settings
      -- ^ Settings passed to 'connect'; 'sync' reads the send/receive
      -- timeout from here.
    , logger   :: !Logger
      -- ^ Logger used by 'abort' to report timeouts.
    , timeouts :: !TimeoutManager
      -- ^ Timeout manager handed to 'withTimeout' by 'sync'.
    , address  :: !InetAddr
      -- ^ Remote address the socket was connected to.
    , sock     :: !Socket
      -- ^ The underlying stream socket.
    , leftover :: !(IORef ByteString)
      -- ^ Bytes already received but not yet consumed by the reply parser.
    , buffer   :: !(IORef (Seq (Resp, IORef Resp)))
      -- ^ Requests queued by 'request', each paired with the ref that
      -- receives its reply; drained by 'sync'.
    , trxState :: !(IORef (Bool, [IORef Resp]))
      -- ^ Whether a MULTI is active, and the refs of the commands queued
      -- inside it (most recent first).
    }
-- | Renders the same text as the 'ToBytes' instance, as a 'String'.
instance Show Connection where
    show c = Char8.unpack (eval (bytes c))
-- | Renders a connection as @<address>#<socket fd>@.
instance ToBytes Connection where
    bytes Connection { address = remote, sock = s } =
        bytes remote +++ val "#" +++ fd s
-- | Look up the stream-socket addresses for a host and port via
-- 'getAddrInfo' (with 'AI_ADDRCONFIG'), in the order it returns them.
resolve :: String -> Word16 -> IO [InetAddr]
resolve host port = do
    infos <- getAddrInfo (Just hints) (Just host) (Just (show port))
    return [InetAddr (addrAddress info) | info <- infos]
  where
    hints = defaultHints
        { addrFlags      = [AI_ADDRCONFIG]
        , addrSocketType = Stream
        }
-- | Open a stream socket to the given address and wrap it in a fresh
-- 'Connection' with empty leftover bytes, an empty request queue and no
-- active transaction.
--
-- Throws 'ConnectTimeout' if the connect does not finish within
-- 'sConnectTimeout' (presumably milliseconds, since it is multiplied by 1000
-- for 'timeout'). The socket is closed if connecting throws.
--
-- NOTE(review): 'timeout' with 0 returns 'Nothing' immediately. A connect
-- timeout of 0 therefore always fails here, whereas 'sync' treats a
-- send/receive timeout of 0 as "no timeout". Confirm that Settings rules
-- out 0.
connect :: Settings -> Logger -> TimeoutManager -> InetAddr -> IO Connection
connect t g m a = bracketOnError openSocket S.close establish
  where
    addr = sockAddr a

    openSocket = socket (familyOf addr) Stream defaultProtocol

    connectMicros = ms (sConnectTimeout t) * 1000

    establish s = do
        connected <- isJust <$> timeout connectMicros (S.connect s addr)
        if connected
            then do
                lo  <- newIORef ""
                buf <- newIORef Seq.empty
                trx <- newIORef (False, [])
                return (Connection t g m a s lo buf trx)
            else throwIO ConnectTimeout

    -- Address family matching the constructor of the target address.
    familyOf sa = case sa of
        SockAddrInet{}  -> AF_INET
        SockAddrInet6{} -> AF_INET6
        SockAddrUnix{}  -> AF_UNIX
#if !MIN_VERSION_network(3,0,0)
        SockAddrCan{}   -> AF_CAN
#endif
-- | Close the connection's underlying socket.
close :: Connection -> IO ()
close c = S.close (sock c)
-- | Append a request, and the ref that will receive its reply, to the
-- connection's queue. Nothing is sent until 'sync' is called.
request :: Resp -> IORef Resp -> Connection -> IO ()
request x y c = modifyIORef' (buffer c) (\pending -> pending |> (x, y))
-- | Send every request queued with 'request', then read one reply per
-- request, in order, storing each reply in that request's ref.
--
-- Does nothing when the queue is empty. If the configured send/receive
-- timeout is non-zero, the whole round trip runs under 'withTimeout' with
-- 'abort' as the expiry action.
--
-- Transactions: a MULTI reply starts a transaction. Replies to later
-- commands must then be QUEUED, and their refs are remembered until EXEC,
-- whose array reply is distributed to them in order.
--
-- Throws 'TransactionAborted' when EXEC replies with a null array/bulk,
-- 'TransactionFailure' for other unexpected EXEC replies, and
-- 'TransactionDiscarded' after a successful DISCARD. It can also throw
-- whatever 'receiveWith' and 'expect' throw.
sync :: Connection -> IO ()
sync c = do
    buf <- readIORef (buffer c)
    unless (Seq.null buf) $ do
        writeIORef (buffer c) Seq.empty
        case sSendRecvTimeout (settings c) of
            0 -> go buf
            t -> withTimeout (timeouts c) t (abort c) (go buf)
  where
    go buf = do
        send c (fmap fst buf)
        bb <- readIORef (leftover c)
        foldlM fetchResult bb buf >>= writeIORef (leftover c)

    -- Read and handle one reply, starting from the bytes the previous
    -- parse left over, and return the bytes this parse leaves over.
    --
    -- EXEC and DISCARD end the server-side transaction whatever their
    -- reply. Clear the local transaction state for them even if reading
    -- or handling the reply throws (e.g. EXECABORT arrives as an Err reply,
    -- which 'receiveWith' rethrows). Otherwise a reused connection would
    -- keep expecting QUEUED replies.
    fetchResult b (cmd, ref)
        | endsTrx cmd = run `finally` resetTrx
        | otherwise   = run
      where
        run = do
            (b', x) <- receiveWith c b
            step cmd ref x
            return b'

    endsTrx (Array 1 [Bulk "EXEC"])    = True
    endsTrx (Array 1 [Bulk "DISCARD"]) = True
    endsTrx _                          = False

    resetTrx = writeIORef (trxState c) (False, [])

    step (Array 1 [Bulk "MULTI"]) ref res = do
        writeIORef ref res
        -- Start from an empty list so that refs left behind by an earlier
        -- transaction that never completed are not matched against this
        -- transaction's EXEC reply.
        writeIORef (trxState c) (True, [])
    step (Array 1 [Bulk "EXEC"]) ref res = case res of
        Array _ xs -> do
            -- Refs were pushed most-recent-first; reverse them to line up
            -- with the EXEC reply array.
            refs <- reverse . snd <$> readIORef (trxState c)
            mapM_ (uncurry writeIORef) (zip refs xs)
            resetTrx
            writeIORef ref (Int 0)
        NullArray -> throwIO TransactionAborted
        NullBulk  -> throwIO TransactionAborted
        -- NOTE(review): 'receiveWith' already throws 'RedisError' for Err
        -- replies, so this branch looks unreachable.
        Err e     -> throwIO (TransactionFailure $ show e)
        _         -> throwIO (TransactionFailure "invalid response for exec")
    step (Array 1 [Bulk "DISCARD"]) _ res = do
        expect "DISCARD" "OK" res
        throwIO TransactionDiscarded
    step _ ref res = do
        (inTrx, trxCmds) <- readIORef (trxState c)
        if inTrx
            then do
                -- Inside MULTI the server only acknowledges the command;
                -- the real reply arrives with EXEC.
                expect "*" "QUEUED" res
                writeIORef (trxState c) (inTrx, ref : trxCmds)
            else writeIORef ref res
-- | Timeout action used by 'sync'. It logs the connection at error level
-- under @connection.timeout@, closes its socket, and throws 'Timeout'.
abort :: Connection -> IO a
abort c = do
    let label = show c
    err (logger c) ("connection.timeout" .= label)
    close c
    throwIO (Timeout label)
-- | Encode the given requests, in order, and write them with one
-- 'sendMany' call on the connection's socket.
send :: Connection -> Seq Resp -> IO ()
send c reqs =
    sendMany (sock c) [chunk | r <- toList reqs, chunk <- toChunks (encode r)]
-- | Read a single reply. Parsing resumes from the connection's leftover
-- bytes, and any bytes not consumed are stored back as the new leftover.
receive :: Connection -> IO Resp
receive c = do
    (rest, reply) <- receiveWith c =<< readIORef (leftover c)
    writeIORef (leftover c) rest
    return reply
-- | Parse one reply with 'resp'. Parsing starts from the given
-- already-received bytes, and more input is read from the socket in chunks
-- of up to 4096 bytes as needed.
--
-- Returns the unconsumed input together with the reply. A parse failure
-- becomes 'InternalError', and an Err reply is rethrown by 'errorCheck' as
-- 'RedisError'.
receiveWith :: Connection -> ByteString -> IO (ByteString, Resp)
receiveWith c b = parseWith refill resp b >>= onResult
  where
    refill = recv (sock c) 4096

    onResult (Done rest x)  = (,) rest <$> errorCheck x
    onResult (Partial _)    = throwIO $ InternalError "partial result"
    onResult (Fail _ _ msg) = throwIO $ InternalError msg
-- | Throw 'RedisError' for an Err reply; return any other reply unchanged.
errorCheck :: Resp -> IO Resp
errorCheck r = case r of
    Err e -> throwIO (RedisError e)
    _     -> return r
-- | Numeric file descriptor of a socket. It is only used to label a
-- 'Connection' in its 'ToBytes' (and therefore 'Show') rendering.
--
-- Since network-3.0, 'fdSocket' has type @Socket -> IO CInt@, and
-- network-3.1 deprecates it in favour of 'withFdSocket'. The pure form
-- therefore does not type-check on network >= 3, even though 'connect'
-- already guards on @MIN_VERSION_network(3,0,0)@. This picks the accessor
-- that matches the network version.
--
-- On network >= 3 the descriptor is read through 'unsafePerformIO'. That
-- is tolerable here because the number is informational only and is never
-- used to do I/O on the socket.
-- NOTE(review): on network >= 3 the value reflects the socket at the time
-- the thunk is forced. A connection rendered after 'close' may show a
-- different number (presumably -1). Confirm this is acceptable.
fd :: Socket -> Int32
#if MIN_VERSION_network(3,1,0)
fd !s = let CInt !n = unsafePerformIO (withFdSocket s return) in n
#elif MIN_VERSION_network(3,0,0)
fd !s = let CInt !n = unsafePerformIO (fdSocket s) in n
#else
fd !s = let CInt !n = fdSocket s in n
#endif
-- | Run 'matchStr' with a label and an expected string against a reply,
-- and throw its 'Left' result as an exception. Presumably this checks that
-- the reply is that string (e.g. @OK@, @QUEUED@); confirm against
-- 'matchStr'.
expect :: String -> Char8.ByteString -> Resp -> IO ()
expect x y r = either throwIO (const (return ())) (matchStr x y r)