{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}
module Database.Franz.Client.Reconnect
  ( Pool
  , poolLogFunc
  , poolRetryPolicy
  , withPool
  , withReconnection
  , Reconnect(..)
  , atomicallyReconnecting
  , fetchWithPool
  )
  where

import Control.Retry (recovering, RetryPolicyM)
import Control.Concurrent.MVar
import Control.Concurrent.STM
import Control.Exception (IOException)
import Control.Monad.Catch
import Database.Franz.Client

-- | Handle through which 'withReconnection' obtains a 'Connection',
-- opening one on demand and replacing it after a failure.
data Pool = Pool
  { poolPath :: FranzPath
    -- ^ Passed to 'connect' whenever a new connection has to be opened
  , poolRef :: MVar (Int {- Connection number -}, Maybe Connection)
    -- ^ Sequence number of the connection, paired with the connection
    -- itself when one is currently established
  , poolRetryPolicy :: RetryPolicyM IO
    -- ^ Retry policy handed to 'recovering' by 'withReconnection'
  , poolLogFunc :: String -> IO ()
    -- ^ Receives diagnostic messages (connection attempts and failures)
  }

-- | A wrapper of 'fetch' which calls 'withReconnection' internally.
--
-- A 'Reconnect' raised while the query is running is rethrown as
-- 'ReconnectInQuery' carrying the query. The handler sits inside the action
-- given to 'withReconnection', so the message logged there names the query
-- that was interrupted.
fetchWithPool
  :: Pool
  -> Query
  -> (STM Response -> IO r)
  -> IO r
fetchWithPool pool q cont = withReconnection pool $ \conn ->
  fetch conn q cont `catch` \case
    -- Avoid deeply nested ReconnectInQuery
    r@(ReconnectInQuery _ _) -> throwM r
    r -> throwM $ ReconnectInQuery q r

-- | Run an action which takes a 'Connection', reconnecting when it fails.
--
-- The connection stored in the pool is reused when there is one; otherwise a
-- new one is opened with 'connect'. A 'ClientError', an 'IOException' or a
-- 'Reconnect' raised by 'connect' or by the action is logged through
-- 'poolLogFunc' and rethrown as 'ReconnectByError'. 'recovering' then runs
-- the whole attempt again for as long as 'poolRetryPolicy' permits. Any
-- other exception propagates unchanged and is not retried.
withReconnection :: Pool -> (Connection -> IO a) -> IO a
withReconnection Pool{..} cont = recovering
  poolRetryPolicy
  [const $ Handler $ \(_ :: Reconnect) -> pure True]
  body
  where
    -- Select the exceptions that warrant a reconnection and render each as
    -- the message to log; 'Nothing' lets the exception through untouched.
    handler :: SomeException -> Maybe String
    handler ex
      | Just (ClientError err) <- fromException ex = Just err
      | Just e <- fromException ex = Just (show (e :: IOException))
      | Just (e :: Reconnect) <- fromException ex = Just
          $ "Reconnecting to " <> fromFranzPath poolPath <> " due to " <> show e
      | otherwise = Nothing

    -- A single attempt; the retry status supplied by 'recovering' is unused.
    body _ = do
      -- Fetch the current connection, opening one first if the pool is empty.
      -- An exception inside 'modifyMVar' leaves the stored state unchanged.
      (i, conn) <- modifyMVar poolRef $ \case
        (n, Nothing) -> do
          poolLogFunc $ unwords
            [ "Connecting to"
            , fromFranzPath poolPath
            ]
          newConn <- tryJust handler (connect poolPath)
            >>= either (\e -> poolLogFunc e >> throwM ReconnectByError) pure
          poolLogFunc $ "Connection #" <> show n <> " established"
          pure ((n, Just newConn), (n, newConn))
        v@(n, Just c) -> pure (v, (n, c))

      tryJust handler (cont conn) >>= \case
        Right a -> pure a
        Left msg -> do
          poolLogFunc msg
          modifyMVar_ poolRef $ \case
            -- Don't disconnect if the sequential number is different;
            -- another thread already established a new connection
            (j, Just _) | i == j -> (i + 1, Nothing) <$ disconnect conn
            x -> pure x
          -- Signal 'recovering' to start another attempt.
          throwM ReconnectByError

-- | Exception used to request that the connection be re-established.
-- 'withReconnection' retries its action whenever one of these is thrown.
data Reconnect
  = ReconnectByTimeout
    -- ^ Thrown by 'atomicallyReconnecting' when the transaction did not
    -- finish within the timeout
  | ReconnectByError
    -- ^ Thrown by 'withReconnection' after connecting or running the
    -- action failed
  | ReconnectInQuery !Query !Reconnect
    -- ^ Thrown by 'fetchWithPool': the original 'Reconnect' together with
    -- the query that was running
  deriving (Show, Eq)

instance Exception Reconnect

-- | Create a 'Pool' and run an action with it.
--
-- No connection is opened up front; the pool starts empty with connection
-- number 0. When the action finishes, normally or by exception, the
-- connection held by the pool (if any) is disconnected.
withPool :: RetryPolicyM IO
    -> (String -> IO ()) -- ^ diagnostic output
    -> FranzPath
    -> (Pool -> IO a)
    -> IO a
withPool poolRetryPolicy poolLogFunc poolPath cont = do
  poolRef <- newMVar (0, Nothing)
  cont Pool{..} `finally` do
    -- The MVar is taken and not put back, so the pool must not be used
    -- once 'withPool' has returned.
    (_, conn) <- takeMVar poolRef
    mapM_ disconnect conn

-- | Run an 'STM' action, throwing 'ReconnectByTimeout' when
-- 'atomicallyWithin' yields no result for the given timeout.
atomicallyReconnecting :: Int -- ^ timeout in microseconds
    -> STM a -> IO a
atomicallyReconnecting timeout m = atomicallyWithin timeout m
  >>= maybe (throwM ReconnectByTimeout) pure