module Network.Nats.Client (
defaultConnectionSettings
, defaultNatsHost
, defaultNatsPort
, connect
, publish
, subscribe
, unsubscribe
, withNats
, createSubject
, ConnectionSettings(..)
, MessageHandler
, NatsClient
, Subject
, Message (..)
) where
import Control.Concurrent (forkFinally)
import Control.Concurrent.MVar
import Control.Exception hiding (catch, bracket)
import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Monad.Trans.Control (MonadBaseControl)
import Data.IORef
import Data.Pool
import Data.Typeable
import Network
import Network.Nats.Protocol
import System.IO (Handle, BufferMode(LineBuffering), hClose, hSetBuffering)
import System.Log.Logger
import System.Random
import System.Timeout
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.ByteString.Char8 as BS
data NatsClient = NatsClient { connections :: Pool NatsServerConnection
, settings :: ConnectionSettings
, subscriptions :: MVar (M.Map SubscriptionId NatsServerConnection)
, servers :: MVar (S.Set (HostName, PortNumber))
}
data NatsServerConnection = NatsServerConnection { natsHandle :: Handle
, natsInfo :: NatsServerInfo
, maxMessages :: IORef (Maybe Int)
}
data ConnectionSettings = ConnectionSettings { host :: HostName
, port :: PortNumber
} deriving (Show)
type MessageHandler = (Message -> IO ())
data NatsError = ConnectionFailure
| ConnectionTimeout
| InvalidServerBanner String
| PayloadTooLarge String
deriving (Show, Typeable)
instance Exception NatsError
defaultConnectionSettings :: ConnectionSettings
defaultConnectionSettings = ConnectionSettings defaultNatsHost defaultNatsPort
defaultNatsHost :: HostName
defaultNatsHost = "127.0.0.1"
defaultNatsPort :: PortNumber
defaultNatsPort = 4222 :: PortNumber
makeNatsServerConnection :: (MonadThrow m, Connection m) => ConnectionSettings -> MVar (S.Set (HostName, PortNumber)) -> m NatsServerConnection
makeNatsServerConnection (ConnectionSettings ho po) srvs = do
mh <- liftIO $ timeout defaultTimeout $ connectTo ho $ PortNumber po
case mh of
Nothing -> do
liftIO $ warningM "Network.Nats.Client" $ "Timed out connecting to server: " ++ (show ho) ++ ":" ++ (show po)
throwM ConnectionTimeout
Just h -> do
liftIO $ hSetBuffering h LineBuffering
nInfo <- liftIO $ receiveServerBanner h
liftIO $ infoM "Network.Nats.Client" $ "Received server info " ++ show nInfo
case nInfo of
Right info -> do
sendConnect h defaultConnectionOptions
maxMsgs <- liftIO $ newIORef Nothing
_subsc <- liftIO $ newIORef Nothing
liftIO $ modifyMVarMasked_ srvs $ \ss -> return $ S.insert (ho, po) ss
return $ NatsServerConnection { natsHandle = h, natsInfo = info, maxMessages = maxMsgs }
Left err -> throwM $ InvalidServerBanner err
destroyNatsServerConnection :: Connection m => NatsServerConnection -> m ()
destroyNatsServerConnection conn = liftIO $ hClose (natsHandle conn)
connect :: (MonadThrow m, MonadIO m) => ConnectionSettings -> Int -> m NatsClient
connect s max_connections = do
srvs <- liftIO $ newMVar S.empty
connpool <- liftIO $ createPool (makeNatsServerConnection s srvs) destroyNatsServerConnection 1 300 max_connections
subs <- liftIO $ newMVar M.empty
return $ NatsClient { connections = connpool, settings = s, subscriptions = subs, servers = srvs }
disconnect :: (MonadIO m) => NatsClient -> m ()
disconnect conn = liftIO $ destroyAllResources (connections conn)
withNats :: (MonadMask m, MonadIO m) => ConnectionSettings -> (NatsClient -> m b) -> m b
withNats connectionSettings f = bracket (connect connectionSettings 10) disconnect f
publish :: (MonadThrow m, MonadIO m, MonadBaseControl IO m) => NatsClient -> Subject -> BS.ByteString -> m ()
publish conn subj msg = withResource (connections conn) $ doPublish subj msg
doPublish :: (MonadThrow m, MonadIO m) => Subject -> BS.ByteString -> NatsServerConnection -> m ()
doPublish subj msg conn = do
case payload_length > (maxPayloadSize (natsInfo conn)) of
True -> throwM $ PayloadTooLarge $ "Size: " ++ show payload_length ++ ", Max: " ++ show (maxPayloadSize (natsInfo conn))
False -> liftIO $ sendPub sock subj payload_length msg
where sock = (natsHandle conn)
payload_length = BS.length msg
subscribe :: MonadIO m => NatsClient -> Subject -> MessageHandler -> Maybe QueueGroup -> m SubscriptionId
subscribe conn subj callback _qgroup = do
(c, pool) <- liftIO $ takeResource (connections conn)
subId <- liftIO $ generateSubscriptionId 5
let sock = (natsHandle c)
max_payload = (maxPayloadSize (natsInfo c))
maxMsgs = (maxMessages c)
liftIO $ sendSub sock subj subId Nothing
_ <- liftIO $ forkFinally (connectionLoop sock max_payload callback maxMsgs) $ handleCompletion sock (connections conn) pool c
liftIO $ modifyMVarMasked_ (subscriptions conn) $ \m -> return $ M.insert subId c m
return subId
unsubscribe :: (MonadIO m, MonadBaseControl IO m) => NatsClient -> SubscriptionId -> Maybe Int -> m ()
unsubscribe conn subId msgs@(Just maxMsgs) = do
liftIO $ withMVarMasked (subscriptions conn) $ \m -> doUnsubscribe m subId maxMsgs
withResource (connections conn) $ \s -> liftIO $ sendUnsub (natsHandle s) subId msgs
unsubscribe conn subId msgs@Nothing = do
liftIO $ withMVarMasked (subscriptions conn) $ \m -> doUnsubscribe m subId 0
withResource (connections conn) $ \s -> liftIO $ sendUnsub (natsHandle s) subId msgs
doUnsubscribe :: M.Map SubscriptionId NatsServerConnection -> SubscriptionId -> Int -> IO ()
doUnsubscribe m subId maxMsgs = do
case M.lookup subId m of
Nothing ->
warningM "Network.Nats.Client" $ "Could not find subscription " ++ (show subId)
Just c -> do
atomicWriteIORef (maxMessages c) (Just maxMsgs)
handleCompletion :: Handle -> Pool NatsServerConnection -> LocalPool NatsServerConnection -> NatsServerConnection -> Either SomeException b -> IO ()
handleCompletion h pool lpool conn (Left exn) = do
warningM "Network.Nats.Client" $ "Connection closed: " ++ (show exn)
hClose h
destroyResource pool lpool conn
handleCompletion _ _ lpool conn _ = do
debugM "Network.Nats.Client" "Subscription finished"
atomicWriteIORef (maxMessages conn) Nothing
putResource lpool conn
connectionLoop :: Handle -> Int -> (Message -> IO ()) -> IORef (Maybe Int) -> IO ()
connectionLoop h max_payload f maxMsgsRef = do
maxMsgs <- readIORef maxMsgsRef
case maxMsgs of
Just 0 -> return ()
_ ->
receiveMessage h max_payload >>= (\m -> handleMessage h f m maxMsgs) >>= atomicWriteIORef maxMsgsRef >> connectionLoop h max_payload f maxMsgsRef
createSubject :: BS.ByteString -> Either String Subject
createSubject = parseSubject
generateSubscriptionId :: Int -> IO SubscriptionId
generateSubscriptionId idLength = do
gen <- getStdGen
return $ SubscriptionId $ BS.pack $ take idLength $ (randoms gen :: [Char])
handleMessage :: Handle -> (Message -> IO ()) -> Message -> Maybe Int -> IO (Maybe Int)
handleMessage h _ Ping m = do
sendPong h
return m
handleMessage _ f msg@(Message _) maxMsgs = do
f msg
case maxMsgs of
Nothing -> return Nothing
(Just n) -> return $ Just (n 1)
handleMessage _ _ msg m = do
warningM "Network.Nats.Client" $ "Received " ++ (show msg)
return m