-- This Source Code Form is subject to the terms of the Mozilla Public
-- License, v. 2.0. If a copy of the MPL was not distributed with this
-- file, You can obtain one at http://mozilla.org/MPL/2.0/.

{-# LANGUAGE BangPatterns      #-}
{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections     #-}

module Database.Redis.IO.Connection
    ( Connection
    , settings
    , resolve
    , connect
    , close
    , request
    , sync
    , send
    , receive
    ) where

import Control.Applicative
import Control.Exception
import Control.Monad
import Data.Attoparsec.ByteString hiding (Result)
import Data.ByteString (ByteString)
import Data.ByteString.Lazy (toChunks)
import Data.Foldable (foldlM, toList)
import Data.IORef
import Data.Maybe (isJust)
import Data.Redis
import Data.Sequence (Seq, (|>))
import Data.Word
import Database.Redis.IO.Settings
import Database.Redis.IO.Types
import Database.Redis.IO.Timeouts (TimeoutManager, withTimeout)
import Network.Socket hiding (connect, close)
import Network.Socket.ByteString (recv, sendMany)
import System.Logger hiding (Settings, settings, close)
import System.Timeout
import Prelude

import qualified Data.ByteString.Lazy.Char8 as Char8
import qualified Data.Sequence as Seq
import qualified Network.Socket as S

-- | A single connection to a Redis server, together with the per-connection
-- state used to pipeline requests and parse their replies.
data Connection = Connection
    { settings :: !Settings
      -- ^ Settings the connection was created with (timeouts are read from here).
    , logger   :: !Logger
      -- ^ Logger used to report problems such as send/receive timeouts.
    , timeouts :: !TimeoutManager
      -- ^ Manager used to bound a 'sync' by the send/receive timeout.
    , address  :: !InetAddr
      -- ^ Server address; also what 'show' renders for this connection.
    , sock     :: !Socket
      -- ^ The underlying socket.
    , leftover :: !(IORef ByteString)
      -- ^ Bytes already received but not yet consumed by the reply parser.
    , buffer   :: !(IORef (Seq (Resp, IORef Resp)))
      -- ^ Requests queued by 'request' and not yet sent by 'sync', each
      -- paired with the reference that will receive its reply.
    , trxState :: !(IORef (Bool, [IORef Resp]))
      -- ^ Whether a MULTI transaction is open, plus the result references of
      -- the commands queued inside it (most recently queued first).
    }

-- | A connection is shown as its rendered 'ToBytes' form, i.e. its address.
instance Show Connection where
    show c = Char8.unpack (eval (bytes c))

-- | Renders a connection by rendering the server address it points to.
instance ToBytes Connection where
    bytes = bytes . address

-- | Look up all stream-socket addresses for the given host name and port.
--
-- The lookup is done with 'getAddrInfo' using the 'AI_ADDRCONFIG' flag, which
-- asks the resolver to return only address families configured on this host.
resolve :: String -> Word16 -> IO [InetAddr]
resolve host port = do
    infos <- getAddrInfo (Just hints) (Just host) (Just (show port))
    return [InetAddr (addrAddress info) | info <- infos]
  where
    hints = defaultHints
        { addrFlags      = [AI_ADDRCONFIG]
        , addrSocketType = Stream
        }

-- | Open a connection to the given server address.
--
-- A socket of the family matching the address is created and connected. If
-- anything fails before the 'Connection' is returned, the socket is closed
-- again ('bracketOnError').
--
-- The connect attempt is bounded by 'sConnectTimeout'. A value of 0 means
-- "no bound", mirroring how 'sync' treats a 0 'sSendRecvTimeout'. Throws
-- 'ConnectTimeout' when the bound is exceeded.
connect :: Settings -> Logger -> TimeoutManager -> InetAddr -> IO Connection
connect t g m a = bracketOnError mkSock S.close $ \s -> do
    let open = S.connect s (sockAddr a)
    ok <- case ms (sConnectTimeout t) of
        -- 'timeout 0' returns Nothing without running the action, so a zero
        -- setting would otherwise make every connection attempt fail.
        0 -> Just <$> open
        n -> timeout (n * 1000) open
    unless (isJust ok) $
        throwIO ConnectTimeout
    Connection t g m a s
        <$> newIORef ""
        <*> newIORef Seq.empty
        <*> newIORef (False, [])
  where
    mkSock = socket (familyOf $ sockAddr a) Stream defaultProtocol

    familyOf (SockAddrInet  _ _    ) = AF_INET
    familyOf (SockAddrInet6 _ _ _ _) = AF_INET6
    familyOf (SockAddrUnix  _      ) = AF_UNIX
#if !MIN_VERSION_network(3,0,0)
    familyOf (SockAddrCan   _      ) = AF_CAN
#endif

-- | Close the connection's underlying socket.
close :: Connection -> IO ()
close c = S.close (sock c)

-- | Queue a request for the next 'sync'. The reply to @x@ will be written
-- into @y@ once it has been received. Nothing is sent here.
request :: Resp -> IORef Resp -> Connection -> IO ()
request x y c = modifyIORef' (buffer c) enqueue
  where
    enqueue pending = pending |> (x, y)

-- | Send all queued requests in one batch and read back their replies.
--
-- The request buffer is emptied and its requests are written with 'send'.
-- Replies are then parsed in order and stored into the 'IORef' paired with
-- each request. Parsing starts from any bytes left over from an earlier read,
-- and the unconsumed remainder is saved back into 'leftover'.
--
-- If 'sSendRecvTimeout' is non-zero, the whole exchange is bounded by it, and
-- 'abort' (which closes the connection and throws 'Timeout') runs on expiry.
--
-- MULTI/EXEC/DISCARD are tracked in 'trxState'. While a transaction is open,
-- every command must be answered with @QUEUED@, and its result reference is
-- remembered so the array returned by EXEC can be distributed to them.
-- Throws 'TransactionAborted' when EXEC returns a null reply,
-- 'TransactionDiscarded' on DISCARD, and 'TransactionFailure' for an
-- unexpected EXEC reply. In every one of these cases the transaction state is
-- cleared first, because the server has ended the transaction.
sync :: Connection -> IO ()
sync c = do
    buf <- readIORef (buffer c)
    unless (Seq.null buf) $ do
        writeIORef (buffer c) Seq.empty
        case sSendRecvTimeout (settings c) of
            0 -> go buf
            t -> withTimeout (timeouts c) t (abort c) (go buf)
  where
    go :: Seq (Resp, IORef Resp) -> IO ()
    go buf = do
        send c (fmap fst buf)
        bb <- readIORef (leftover c)
        foldlM fetchResult bb buf >>= writeIORef (leftover c)

    -- Read one reply (threading the unconsumed input through) and dispatch it.
    fetchResult :: ByteString -> (Resp, IORef Resp) -> IO ByteString
    fetchResult b (cmd, ref) = do
        (b', x) <- receiveWith c b
        step cmd ref x
        return b'

    step :: Resp -> IORef Resp -> Resp -> IO ()
    step (Array 1 [Bulk "MULTI"]) ref res = do
        writeIORef ref res
        modifyIORef' (trxState c) ((True, ) . snd)

    step (Array 1 [Bulk "EXEC"]) ref res = do
        -- References were prepended while queueing; restore command order.
        refs <- reverse . snd <$> readIORef (trxState c)
        -- EXEC always ends the transaction server-side, so clear local state
        -- before any of the branches below can throw.
        endTransaction
        case res of
            Array _ xs -> do
                mapM_ (uncurry writeIORef) (zip refs xs)
                writeIORef ref (Int 0)
            NullArray -> throwIO TransactionAborted
            NullBulk  -> throwIO TransactionAborted
            -- Normally unreachable: 'receiveWith' already throws on 'Err'.
            Err e     -> throwIO (TransactionFailure $ show e)
            _         -> throwIO (TransactionFailure "invalid response for exec")

    step (Array 1 [Bulk "DISCARD"]) _ res = do
        endTransaction
        expect "DISCARD" "OK" res
        throwIO TransactionDiscarded

    step _ ref res = do
        (inTrx, trxCmds) <- readIORef (trxState c)
        if inTrx then do
            expect "*" "QUEUED" res
            writeIORef (trxState c) (inTrx, ref : trxCmds)
        else
            writeIORef ref res

    -- Forget any open transaction and its queued result references.
    endTransaction :: IO ()
    endTransaction = writeIORef (trxState c) (False, [])

-- | Timeout handler for 'sync'. It logs the event under
-- @connection.timeout@, closes the connection and throws 'Timeout' carrying
-- the rendered connection.
abort :: Connection -> IO a
abort c = do
    let name = show c
    err (logger c) $ "connection.timeout" .= name
    close c
    throwIO (Timeout name)

-- | Encode the given requests and write all their chunks, in order, to the
-- socket with a single 'sendMany' call.
send :: Connection -> Seq Resp -> IO ()
send c reqs =
    sendMany (sock c) [chunk | r <- toList reqs, chunk <- toChunks (encode r)]

-- | Read and parse a single reply. Parsing starts with any bytes left over
-- from an earlier read, and the unconsumed remainder is stored back for the
-- next read. Error replies are thrown by 'receiveWith'.
receive :: Connection -> IO Resp
receive c = do
    pending   <- readIORef (leftover c)
    (rest, r) <- receiveWith c pending
    writeIORef (leftover c) rest
    pure r

-- | Parse one reply, first from the given bytes and then from further socket
-- reads of up to 4096 bytes as the parser asks for more input.
--
-- Returns the unconsumed input together with the reply. A parse failure is
-- thrown as 'InternalError'. A server error reply is thrown as 'RedisError'
-- (via 'errorCheck').
receiveWith :: Connection -> ByteString -> IO (ByteString, Resp)
receiveWith c b = parseWith readMore resp b >>= finish
  where
    readMore = recv (sock c) 4096

    finish (Done rest x)  = (,) rest <$> errorCheck x
    finish (Fail _ _ msg) = throwIO (InternalError msg)
    finish (Partial _)    = throwIO (InternalError "partial result")

-- | Throw a server error reply as 'RedisError'; return any other reply
-- unchanged.
errorCheck :: Resp -> IO Resp
errorCheck r = case r of
    Err e -> throwIO (RedisError e)
    _     -> return r

-- Helpers:

-- | Check a reply with 'matchStr' against the expected string @y@, and throw
-- the 'RedisError' it produces when the check fails.
expect :: String -> Char8.ByteString -> Resp -> IO ()
expect x y r = case matchStr x y r of
    Left  e -> throwIO e
    Right _ -> return ()