{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
module Database.Redis.IO.Connection
( Connection
, settings
, resolve
, connect
, close
, request
, sync
, send
, receive
) where
import Control.Applicative
import Control.Exception
import Control.Monad
import Data.Attoparsec.ByteString hiding (Result)
import Data.ByteString (ByteString)
import Data.ByteString.Lazy (toChunks)
import Data.Foldable (foldlM, toList)
import Data.IORef
import Data.Maybe (isJust)
import Data.Redis
import Data.Sequence (Seq, (|>))
import Data.Word
import Database.Redis.IO.Settings
import Database.Redis.IO.Types
import Database.Redis.IO.Timeouts (TimeoutManager, withTimeout)
import Network.Socket hiding (connect, close)
import Network.Socket.ByteString (recv, sendMany)
import System.Logger hiding (Settings, settings, close)
import System.Timeout
import Prelude
import qualified Data.ByteString.Lazy.Char8 as Char8
import qualified Data.Sequence as Seq
import qualified Network.Socket as S
-- | A single connection to a Redis server, together with the mutable state
-- used for pipelining ('request' / 'sync') and for reading replies.
data Connection = Connection
    { settings :: !Settings
      -- ^ Connection settings; 'sync' reads its send/receive timeout here.
    , logger   :: !Logger
      -- ^ Logger used when a send/receive timeout aborts the connection.
    , timeouts :: !TimeoutManager
      -- ^ Timeout manager used by 'sync' to bound a request/reply exchange.
    , address  :: !InetAddr
      -- ^ Remote address; also the connection's textual representation.
    , sock     :: !Socket
      -- ^ The underlying socket.
    , leftover :: !(IORef ByteString)
      -- ^ Input already read from the socket but not yet consumed by the
      -- reply parser.
    , buffer   :: !(IORef (Seq (Resp, IORef Resp)))
      -- ^ Requests queued by 'request', each paired with the 'IORef' that
      -- receives its reply; flushed by 'sync'.
    , trxState :: !(IORef (Bool, [IORef Resp]))
      -- ^ Whether a MULTI block is open, and the result refs of the commands
      -- queued inside it (most recent first).
    }
-- | Shows a connection as the rendering of its 'ToBytes' builder.
instance Show Connection where
    show c = Char8.unpack (eval (bytes c))
-- | A connection is rendered as its remote address.
instance ToBytes Connection where
    bytes = bytes . address
-- | Resolve a host name and port into the stream-socket addresses reported by
-- 'getAddrInfo' (queried with @AI_ADDRCONFIG@). Lookup failures raised by
-- 'getAddrInfo' are not caught here.
resolve :: String -> Word16 -> IO [InetAddr]
resolve host port = do
    infos <- getAddrInfo (Just hints) (Just host) (Just (show port))
    return [ InetAddr (addrAddress info) | info <- infos ]
  where
    hints :: AddrInfo
    hints = defaultHints
        { addrFlags      = [AI_ADDRCONFIG]
        , addrSocketType = Stream
        }
-- | Open a stream connection to the given address.
--
-- The socket family is chosen from the address. If connecting throws, or does
-- not finish within 'sConnectTimeout' (then 'ConnectTimeout' is thrown), the
-- socket is closed again. A new connection starts with no leftover input, an
-- empty request buffer and no open transaction.
connect :: Settings -> Logger -> TimeoutManager -> InetAddr -> IO Connection
connect t g m a = bracketOnError mkSock S.close $ \s -> do
    connected <- isJust <$> timeout usecs (S.connect s addr)
    unless connected $
        throwIO ConnectTimeout
    leftoverRef <- newIORef ""
    bufferRef   <- newIORef Seq.empty
    trxRef      <- newIORef (False, [])
    return (Connection t g m a s leftoverRef bufferRef trxRef)
  where
    addr :: SockAddr
    addr = sockAddr a

    -- 'timeout' takes microseconds; the setting is in milliseconds.
    usecs :: Int
    usecs = ms (sConnectTimeout t) * 1000

    mkSock :: IO Socket
    mkSock = socket (familyOf addr) Stream defaultProtocol

    familyOf :: SockAddr -> Family
    familyOf sa = case sa of
        SockAddrInet  {} -> AF_INET
        SockAddrInet6 {} -> AF_INET6
        SockAddrUnix  {} -> AF_UNIX
#if !MIN_VERSION_network(3,0,0)
        SockAddrCan   {} -> AF_CAN
#endif
-- | Close the connection's socket.
close :: Connection -> IO ()
close c = S.close (sock c)
-- | Queue a command for the next 'sync' without sending anything. The 'IORef'
-- is handed to 'sync', which uses it to deliver the command's reply.
request :: Resp -> IORef Resp -> Connection -> IO ()
request x y c = modifyIORef' (buffer c) (\pending -> pending |> (x, y))
-- | Flush every command queued by 'request' and read back their replies.
--
-- All queued commands are written in one batch. Then one reply per command is
-- parsed, in order, and stored in that command's result 'IORef'. MULTI, EXEC
-- and DISCARD are handled specially (see @step@). Input read past the last
-- reply is kept in 'leftover'.
--
-- When 'sSendRecvTimeout' is non-zero, the whole exchange runs under that
-- timeout, and 'abort' is invoked on expiry. A value of 0 disables it.
--
-- Throws 'TransactionFailure' for failed or discarded transactions, and
-- whatever 'receiveWith' throws for unparsable or error replies.
sync :: Connection -> IO ()
sync c = do
    buf <- readIORef (buffer c)
    unless (Seq.null buf) $ do
        writeIORef (buffer c) Seq.empty
        case sSendRecvTimeout (settings c) of
            0 -> go buf
            t -> withTimeout (timeouts c) t (abort c) (go buf)
  where
    go :: Seq (Resp, IORef Resp) -> IO ()
    go buf = do
        send c (fmap fst buf)
        bb <- readIORef (leftover c)
        foldlM fetchResult bb buf >>= writeIORef (leftover c)

    -- Read the next reply (starting from already-buffered input), dispatch it
    -- to 'step' and return the input that remains after it.
    fetchResult :: ByteString -> (Resp, IORef Resp) -> IO ByteString
    fetchResult b (cmd, ref) = do
        (b', x) <- receiveWith c b
        step cmd ref x
        return b'

    -- Leave transaction mode. Used on every path where the server has
    -- finished the MULTI block (EXEC answered, DISCARD confirmed). Otherwise
    -- a connection reused after the error would keep expecting "QUEUED"
    -- replies for ordinary commands.
    endTrx :: IO ()
    endTrx = writeIORef (trxState c) (False, [])

    step :: Resp -> IORef Resp -> Resp -> IO ()
    -- MULTI: store its reply and start collecting the refs of queued commands.
    step (Array 1 [Bulk "MULTI"]) ref res = do
        writeIORef ref res
        modifyIORef' (trxState c) ((True, ) . snd)

    -- EXEC: distribute the array of results to the queued commands, in the
    -- order they were queued (the stored list is most-recent-first).
    step (Array 1 [Bulk "EXEC"]) ref res = case res of
        Array _ xs -> do
            refs <- reverse . snd <$> readIORef (trxState c)
            mapM_ (uncurry writeIORef) (zip refs xs)
            endTrx
            writeIORef ref (Int 0)
        NullArray -> endTrx >> throwIO TransactionAborted
        NullBulk  -> endTrx >> throwIO TransactionAborted
        -- Defensive only: 'receiveWith' already throws 'RedisError' for
        -- error replies, so 'Err' normally never reaches this point.
        Err     e -> endTrx >> throwIO (TransactionFailure $ show e)
        _         -> endTrx >> throwIO (TransactionFailure "invalid response for exec")

    -- DISCARD: the transaction is dropped; report it to the caller.
    step (Array 1 [Bulk "DISCARD"]) _ res = do
        expect "DISCARD" "OK" res
        endTrx
        throwIO TransactionDiscarded

    -- Any other command: inside MULTI the server only acknowledges it with
    -- "QUEUED", and the real result arrives with EXEC. Outside MULTI the reply
    -- is the result.
    step _ ref res = do
        (inTrx, trxCmds) <- readIORef (trxState c)
        if inTrx
            then do
                expect "*" "QUEUED" res
                writeIORef (trxState c) (inTrx, ref : trxCmds)
            else writeIORef ref res
-- | Handler run by 'sync' when its send/receive timeout expires.
--
-- Logs a @connection.timeout@ error, closes the socket and throws 'Timeout'.
-- It never returns normally.
abort :: Connection -> IO a
abort c = do
    let name = show c
    err (logger c) $ "connection.timeout" .= name
    close c
    throwIO (Timeout name)
-- | Encode the given commands and write all resulting chunks to the socket
-- with a single 'sendMany' call.
send :: Connection -> Seq Resp -> IO ()
send c reqs = sendMany (sock c) chunks
  where
    chunks = [ chunk | r <- toList reqs, chunk <- toChunks (encode r) ]
-- | Read one reply. Parsing starts with any input left over from a previous
-- read, and whatever follows the reply is saved for the next read. Error
-- replies are thrown (see 'receiveWith').
receive :: Connection -> IO Resp
receive c = do
    (rest, reply) <- receiveWith c =<< readIORef (leftover c)
    writeIORef (leftover c) rest
    return reply
-- | Parse one reply. The parser is fed the given input first, then reads more
-- from the socket (at most 4096 bytes per read) as needed. Returns the
-- unconsumed input together with the reply.
--
-- Throws 'InternalError' when parsing fails, and 'RedisError' (via
-- 'errorCheck') when the reply is an error reply.
receiveWith :: Connection -> ByteString -> IO (ByteString, Resp)
receiveWith c b = parseWith (recv (sock c) 4096) resp b >>= \result -> case result of
    Done rest x  -> (,) rest <$> errorCheck x
    Partial _    -> throwIO $ InternalError "partial result"
    Fail _ _ msg -> throwIO $ InternalError msg
-- | Throw an error reply as 'RedisError'; return any other reply unchanged.
errorCheck :: Resp -> IO Resp
errorCheck r = case r of
    Err e -> throwIO (RedisError e)
    _     -> return r
-- | Check a reply against an expected string using 'matchStr', and throw the
-- 'RedisError' it produces on mismatch.
--
-- Callers pass a label first (a command name such as @"DISCARD"@, or @"*"@)
-- and the expected reply text second. Presumably the label only appears in
-- the error message; confirm against 'matchStr'.
expect :: String -> Char8.ByteString -> Resp -> IO ()
expect label expected reply = case matchStr label expected reply of
    Left e  -> throwIO e
    Right _ -> return ()