{-# LANGUAGE BangPatterns, CPP, OverloadedStrings, ScopedTypeVariables, DeriveDataTypeable #-}
module Network.AMQP.Internal where
import Paths_RabbitMQ(version)
import Data.Version(showVersion)
import Control.Concurrent
import Control.Concurrent.STM
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put as BPut
import Network.Socket (PortNumber, withSocketsDo)
import System.IO (hPutStrLn, stderr)
import qualified Control.Exception as CE
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Lazy as BL
import qualified Data.Map as M
import qualified Data.Foldable as F
import qualified Data.IntMap as IM
import qualified Data.IntSet as IntSet
import qualified Data.Sequence as Seq
import qualified Data.Text as T
import qualified Data.Text.Encoding as E
import Network.AMQP.Prelude
import Network.AMQP.Protocol
import Network.AMQP.Types
import Network.AMQP.Helpers
import Network.AMQP.Generated
import Network.AMQP.ChannelAllocator
import qualified Network.AMQP.Connection as Conn
-- | Outcome of a publisher confirm as reported by the broker:
-- 'BasicAck' for a @basic.ack@ frame, 'BasicNack' for a @basic.nack@ frame.
data AckType
  = BasicAck
  | BasicNack
  deriving Show
-- | AMQP delivery mode of a message.
--
-- Encoded on the wire as 2 ('Persistent') and 1 ('NonPersistent'); see
-- 'deliveryModeToInt'. Constructor order matters: it defines the derived
-- 'Ord' instance.
data DeliveryMode
  = Persistent
  | NonPersistent
  deriving (Eq, Ord, Read, Show)
-- | Wire encoding of a 'DeliveryMode': 1 for non-persistent, 2 for persistent.
deliveryModeToInt :: DeliveryMode -> Octet
deliveryModeToInt mode = case mode of
  NonPersistent -> 1
  Persistent -> 2
-- | Inverse of 'deliveryModeToInt'.
--
-- Partial: any octet other than 1 or 2 raises an 'error'.
intToDeliveryMode :: Octet -> DeliveryMode
intToDeliveryMode n = case n of
  1 -> NonPersistent
  2 -> Persistent
  _ -> error ("Unknown delivery mode int: " ++ show n)
-- | An AMQP message: the body plus the optional basic-class properties.
--
-- Built from incoming content frames by 'msgFromContentHeaderProperties'.
-- The constructor is also used positionally, so field order is significant.
data Message = Message
  { msgBody :: BL.ByteString
    -- ^ message payload
  , msgDeliveryMode :: Maybe DeliveryMode
  , msgTimestamp :: Maybe Timestamp
  , msgID :: Maybe Text
  , msgType :: Maybe Text
  , msgUserID :: Maybe Text
  , msgApplicationID :: Maybe Text
  , msgClusterID :: Maybe Text
  , msgContentType :: Maybe Text
  , msgContentEncoding :: Maybe Text
  , msgReplyTo :: Maybe Text
  , msgPriority :: Maybe Octet
  , msgCorrelationID :: Maybe Text
  , msgExpiration :: Maybe Text
  , msgHeaders :: Maybe FieldTable
    -- ^ application headers table
  }
  deriving (Eq, Ord, Read, Show)
-- | Delivery metadata from a @basic.deliver@ method, handed to consumer
-- callbacks together with the 'Message'.
data Envelope = Envelope
  { envDeliveryTag :: LongLongInt
  , envRedelivered :: Bool
  , envExchangeName :: Text
  , envRoutingKey :: Text
  , envChannel :: Channel
    -- ^ the channel the message arrived on
  }
-- | Why the broker returned a published message (built from @basic.return@).
data PublishError = PublishError
  { errReplyCode :: ReturnReplyCode
  , errExchange :: Maybe Text
  , errRoutingKey :: Text
  }
  deriving (Eq, Read, Show)
-- | Reply code of a @basic.return@, carrying the server's reply text.
-- Codes 312, 313 and 404 map to these constructors respectively.
data ReturnReplyCode
  = Unroutable Text
  | NoConsumers Text
  | NotFound Text
  deriving (Eq, Read, Show)
-- | One logical AMQP unit on a channel: either a bare method, or a
-- content-bearing method with its header properties and the full body.
data Assembly
  = SimpleMethod MethodPayload
  | ContentMethod MethodPayload ContentHeaderProperties BL.ByteString
  deriving Show
-- | Read the next complete 'Assembly' from a channel's frame queue.
--
-- A method frame for which 'hasContent' holds is followed by its content
-- header and body frames, which are gathered via 'collectContent'. Any
-- non-method frame in this position is an 'error'.
readAssembly :: Chan FramePayload -> IO Assembly
readAssembly chan = readChan chan >>= assemble
  where
    assemble frame@(MethodPayload method)
      | hasContent frame = do
          (props, body) <- collectContent chan
          pure (ContentMethod method props body)
      | otherwise = pure (SimpleMethod method)
    assemble other = error $ "didn't expect frame: " ++ show other
-- | Read a content header frame, then as many body frames as needed to cover
-- the declared body size, and return the properties with the joined body.
collectContent :: Chan FramePayload -> IO (ContentHeaderProperties, BL.ByteString)
collectContent chan = do
  ContentHeaderPayload _ _ bodySize props <- readChan chan
  chunks <- gather (fromIntegral bodySize)
  pure (props, BL.concat chunks)
  where
    -- Keep reading body frames until the remaining byte count is used up.
    gather remaining
      | remaining <= 0 = pure []
      | otherwise = do
          ContentBodyPayload chunk <- readChan chan
          rest <- gather (remaining - BL.length chunk)
          pure (chunk : rest)
-- | An open AMQP connection and the state shared between its threads.
--
-- The constructor is applied positionally in 'openConnection''', so field
-- order is significant.
data Connection = Connection
  { connHandle :: Conn.Connection
    -- ^ underlying transport
  , connChanAllocator :: ChannelAllocator
    -- ^ hands out and frees channel numbers
  , connChannels :: MVar (IM.IntMap (Channel, ThreadId))
    -- ^ open channels by ID, with each channel's receiver thread
  , connMaxFrameSize :: Int
    -- ^ negotiated maximum frame size
  , connClosed :: MVar (Maybe (CloseType, String))
    -- ^ 'Nothing' while open; close type and reason once closed
  , connClosedLock :: MVar ()
    -- ^ filled when the receiver thread has finished its cleanup
  , connWriteLock :: MVar ()
    -- ^ serialises frame writes to 'connHandle'
  , connClosedHandlers :: MVar [IO ()]
    -- ^ actions run after the connection shuts down
  , connBlockedHandlers :: MVar [(Text -> IO (), IO ())]
    -- ^ (blocked, unblocked) callbacks
  , connLastReceived :: MVar Int64
    -- ^ timestamp of the last received frame (heartbeat watchdog)
  , connLastSent :: MVar Int64
    -- ^ timestamp of the last sent frame (heartbeat watchdog)
  , connServerProperties :: FieldTable
    -- ^ server properties from @connection.start@
  }
-- | Options controlling how 'openConnection''' connects and negotiates.
data ConnectionOpts = ConnectionOpts
  { coServers :: ![(String, PortNumber)]
    -- ^ brokers, tried in order until one accepts the TCP/TLS connection
  , coVHost :: !Text
    -- ^ virtual host to open
  , coAuth :: ![SASLMechanism]
    -- ^ SASL mechanisms in preference order; the first the server offers wins
  , coMaxFrameSize :: !(Maybe Word32)
    -- ^ client bound on frame size, combined with the server's via @chooseMin@
  , coHeartbeatDelay :: !(Maybe Word16)
    -- ^ overrides the server's heartbeat proposal; 0 disables heartbeats
  , coMaxChannel :: !(Maybe Word16)
    -- ^ client bound on channel count; 0 is treated as 65535
  , coTLSSettings :: !(Maybe Conn.SSLContext)
    -- ^ passed through as the transport's @connectionSecure@ setting
  , coName :: !(Maybe Text)
    -- ^ sent to the server as the @connection_name@ client property
  }
-- | A client-side SASL authentication mechanism.
data SASLMechanism = SASLMechanism
  { saslName :: !Text
    -- ^ mechanism name, matched against the server's advertised list
  , saslInitialResponse :: !BS.ByteString
    -- ^ response sent in @connection.start-ok@
  , saslChallengeFunc :: !(Maybe (BS.ByteString -> IO BS.ByteString))
    -- ^ answers @connection.secure@ challenges; with 'Nothing' the
    -- handshake is aborted if the server sends one
  }
-- | Receive loop of a connection: reads frames forever and routes them.
--
-- Frames on channel 0 are connection-level (close, close-ok, heartbeat,
-- blocked/unblocked); all others are pushed to the owning channel's
-- 'inQueue'. An 'IOException' while reading or dispatching kills the
-- connection as 'Abnormal'; a close or close-ok from the server kills this
-- very thread via 'killConnection'.
connectionReceiver :: Connection -> IO ()
connectionReceiver conn = receiveLoop
  where
    receiveLoop = do
      CE.catch receiveOne onIOError
      receiveLoop

    receiveOne = do
      Frame chanID payload <- readFrame (connHandle conn)
      updateLastReceived conn
      dispatch chanID payload

    onIOError (e :: CE.IOException) =
      myThreadId >>= killConnection conn Abnormal (CE.toException e)

    closedByUserEx = ConnectionClosedException Normal "closed by user"

    dispatch 0 payload = onChannelZero payload
    dispatch chanID payload =
      withMVar (connChannels conn) $ \chans ->
        case IM.lookup (fromIntegral chanID) chans of
          Just (chan, _) -> writeChan (inQueue chan) payload
          Nothing -> hPutStrLn stderr $ "ERROR: channel not open " ++ show chanID

    -- Connection-level traffic; clauses are tried top to bottom.
    onChannelZero (MethodPayload Connection_close_ok) =
      myThreadId >>= killConnection conn Normal (CE.toException closedByUserEx)
    onChannelZero (MethodPayload (Connection_close _ (ShortString errorMsg) _ _)) = do
      writeFrame (connHandle conn) $ Frame 0 $ MethodPayload Connection_close_ok
      myThreadId >>= killConnection conn Abnormal (CE.toException . ConnectionClosedException Abnormal . T.unpack $ errorMsg)
    onChannelZero HeartbeatPayload = pure ()
    onChannelZero (MethodPayload (Connection_blocked reason)) = notifyBlocked reason
    onChannelZero (MethodPayload Connection_unblocked) = notifyUnblocked
    onChannelZero payload = hPutStrLn stderr $ "Got unexpected msg on channel zero: " ++ show payload

    -- Listener exceptions are logged, never propagated to the receive loop.
    notifyBlocked (ShortString reason) =
      withMVar (connBlockedHandlers conn) $ \listeners ->
        forM_ listeners $ \(onBlocked, _) ->
          onBlocked reason `CE.catch` \(ex :: CE.SomeException) ->
            hPutStrLn stderr $ "connection blocked listener threw exception: " ++ show ex

    notifyUnblocked =
      withMVar (connBlockedHandlers conn) $ \listeners ->
        forM_ listeners $ \(_, onUnblocked) ->
          onUnblocked `CE.catch` \(ex :: CE.SomeException) ->
            hPutStrLn stderr $ "connection unblocked listener threw exception: " ++ show ex
-- | Connect to the first reachable broker in 'coServers', perform the AMQP
-- 0-9-1 handshake and start the connection's background threads.
--
-- Handshake sequence: protocol header, @connection.start@ / @start-ok@
-- (SASL selection), optional @connection.secure@ rounds, @connection.tune@ /
-- @tune-ok@, then @connection.open@ / @open-ok@.
--
-- Any 'IOException' during the handshake is reported as
-- 'ConnectionClosedException' "Handshake failed...". Whatever exception ends
-- the handshake (IO error, pattern-match failure, 'error', async exception),
-- the freshly opened transport is closed before it propagates, so the socket
-- does not leak.
--
-- After the handshake the receiver thread is forked. Its finaliser closes the
-- socket, records the close reason, kills all channel threads, releases
-- 'closeConnection' waiters and runs the closed handlers. If heartbeats were
-- negotiated (non-zero), a watchdog thread is started and killed when the
-- connection closes.
openConnection'' :: ConnectionOpts -> IO Connection
openConnection'' connOpts = withSocketsDo $ do
  handle <- connect [] $ coServers connOpts
  (maxFrameSize, maxChannel, heartbeatTimeout, serverProps) <-
    CE.handle
      (\(_ :: CE.IOException) -> CE.throwIO $ ConnectionClosedException Abnormal "Handshake failed. Please check the RabbitMQ logs for more information")
      (handshake handle)
      `CE.onException` closeQuietly handle
  cChannels <- newMVar IM.empty
  cClosed <- newMVar Nothing
  cChanAllocator <- newChannelAllocator $ fromIntegral maxChannel
  -- Reserve channel 0 for connection-level traffic.
  _ <- allocateChannel cChanAllocator
  writeLock <- newMVar ()
  ccl <- newEmptyMVar
  cClosedHandlers <- newMVar []
  cBlockedHandlers <- newMVar []
  cLastReceived <- getTimestamp >>= newMVar
  cLastSent <- getTimestamp >>= newMVar
  let conn = Connection handle cChanAllocator cChannels (fromIntegral maxFrameSize) cClosed ccl writeLock cClosedHandlers cBlockedHandlers cLastReceived cLastSent serverProps
  connThread <- forkFinally (connectionReceiver conn) $ \res -> do
    closeQuietly handle
    -- Keep an already-recorded reason; otherwise mark as an unknown abnormal close.
    modifyMVar_ cClosed $ pure . Just . fromMaybe (Abnormal, "unknown reason")
    -- Kill every channel thread with the cause of the connection's death.
    let finaliser = ChanThreadKilledException $ case res of
          Left ex -> ex
          Right _ -> CE.toException CE.ThreadKilled
    modifyMVar_ cChannels $ \x -> do
      mapM_ (flip CE.throwTo finaliser . snd) $ IM.elems x
      pure IM.empty
    -- Release any 'closeConnection' callers waiting on this lock.
    void $ tryPutMVar ccl ()
    withMVar cClosedHandlers sequence_
  case heartbeatTimeout of
    Nothing -> pure ()
    Just timeout -> do
      heartbeatThread <- watchHeartbeats conn (fromIntegral timeout) connThread
      addConnectionClosedHandler conn True (killThread heartbeatThread)
  pure conn
  where
    -- Close the transport, ignoring failures (it may already be closed).
    closeQuietly h = CE.catch (Conn.connectionClose h) (\(_ :: CE.SomeException) -> pure ())

    -- The AMQP handshake proper. Returns the negotiated frame size, channel
    -- limit, heartbeat timeout ('Nothing' when disabled) and server properties.
    handshake handle = do
      Conn.connectionPut handle $ BS.append (BC.pack "AMQP") (BS.pack [1, 1, 0, 9])
      Frame 0 (MethodPayload (Connection_start _ _ serverProps (LongString serverMechanisms) _)) <- readFrame handle
      selectedSASL <- selectSASLMechanism handle serverMechanisms
      writeFrame handle $ start_ok selectedSASL
      Frame 0 (MethodPayload (Connection_tune channel_max frame_max sendHeartbeat)) <- handleSecureUntilTune handle selectedSASL
      let maxFrameSize = chooseMin frame_max $ coMaxFrameSize connOpts
          finalHeartbeatSec = fromMaybe sendHeartbeat (coHeartbeatDelay connOpts)
          heartbeatTimeout = mfilter (/= 0) (Just finalHeartbeatSec)
          -- A channel-max of 0 means "no limit"; cap it at the 16-bit maximum.
          fixChanNum x = if x == 0 then 65535 else x
          maxChannel = chooseMin (fixChanNum channel_max) $ fmap fixChanNum $ coMaxChannel connOpts
      writeFrame handle (Frame 0 (MethodPayload (Connection_tune_ok maxChannel maxFrameSize finalHeartbeatSec)))
      writeFrame handle open
      Frame 0 (MethodPayload (Connection_open_ok _)) <- readFrame handle
      pure (maxFrameSize, maxChannel, heartbeatTimeout, serverProps)

    -- Try each broker in turn, remembering the failures for the final error.
    connect excs ((host, port) : rest) = do
      result <- CE.try (Conn.connectTo Conn.ConnectionParams
        { Conn.connectionHostname = host
        , Conn.connectionPort = port
        , Conn.connectionSecure = coTLSSettings connOpts
        })
      either
        (\(ex :: CE.SomeException) -> connect (ex : excs) rest)
        pure
        result
    connect excs [] = CE.throwIO $ ConnectionClosedException Abnormal $ "Could not connect to any of the provided brokers: " ++ show (zip (coServers connOpts) (reverse excs))

    -- Pick the first client mechanism (in 'coAuth' order) that the server offers.
    selectSASLMechanism handle serverMechanisms =
      let serverSaslList = T.split (== ' ') $ E.decodeUtf8 serverMechanisms
          clientMechanisms = coAuth connOpts
          clientSaslList = map saslName clientMechanisms
          maybeSasl = F.find (\(SASLMechanism name _ _) -> elem name serverSaslList) clientMechanisms
      in abortIfNothing maybeSasl handle
           ("None of the provided SASL mechanisms " ++ show clientSaslList ++ " is supported by the server " ++ show serverSaslList ++ ".")

    start_ok sasl = Frame 0 (MethodPayload (Connection_start_ok clientProperties
                                              (ShortString $ saslName sasl)
                                              (LongString $ saslInitialResponse sasl)
                                              (ShortString "en_US")))
      where
        clientProperties = FieldTable $ M.fromList $
          [ ("platform", FVString "Haskell")
          , ("version", FVString . T.pack $ showVersion version)
          , ("capabilities", FVFieldTable clientCapabilities)
          ] ++ maybe [] (\x -> [("connection_name", FVString x)]) (coName connOpts)
        clientCapabilities = FieldTable $ M.fromList
          [ ("consumer_cancel_notify", FVBool True)
          , ("connection.blocked", FVBool True)
          ]

    -- Answer @connection.secure@ challenges until @connection.tune@ arrives.
    handleSecureUntilTune handle sasl = do
      tuneOrSecure <- readFrame handle
      case tuneOrSecure of
        Frame 0 (MethodPayload (Connection_secure (LongString challenge))) -> do
          processChallenge <- abortIfNothing (saslChallengeFunc sasl)
            handle $ "The server provided a challenge, but the selected SASL mechanism " ++ show (saslName sasl) ++ " is not equipped with a challenge processing function."
          challengeResponse <- processChallenge challenge
          writeFrame handle (Frame 0 (MethodPayload (Connection_secure_ok (LongString challengeResponse))))
          handleSecureUntilTune handle sasl
        tune@(Frame 0 (MethodPayload (Connection_tune _ _ _))) -> pure tune
        x -> error $ "handleSecureUntilTune fail. received message: " ++ show x

    open = Frame 0 (MethodPayload (Connection_open
                                     (ShortString $ coVHost connOpts)
                                     (ShortString $ T.pack "")
                                     True))

    abortHandshake handle msg = do
      Conn.connectionClose handle
      CE.throwIO $ ConnectionClosedException Abnormal msg

    abortIfNothing m handle msg = case m of
      Nothing -> abortHandshake handle msg
      Just a -> pure a
-- | Start the heartbeat watchdog for a connection and return its thread.
--
-- Every @rate@ (timeout * 250000 — a quarter of the heartbeat timeout,
-- assuming 'scheduleAtFixedRate' and 'getTimestamp' use microseconds; TODO
-- confirm in Helpers) it:
--
-- * sends a heartbeat frame if nothing was sent for 2 * rate, and
-- * kills the connection as 'Abnormal' if nothing was received for 8 * rate
--   (two full heartbeat timeouts).
--
-- The heartbeat frame is written under 'connWriteLock', like every other
-- post-handshake writer ('writeFrames', 'closeConnection'), so it cannot
-- interleave with a concurrent multi-frame write.
watchHeartbeats :: Connection -> Int -> ThreadId -> IO ThreadId
watchHeartbeats conn timeout connThread = scheduleAtFixedRate rate $ do
  checkSendTimeout
  checkReceiveTimeout
  where
    rate = timeout * 1000 * 250
    receiveTimeout = fromIntegral rate * 4 * 2
    sendTimeout = fromIntegral rate * 2

    skippedBeatEx = ConnectionClosedException Abnormal "killed connection after missing 2 consecutive heartbeats"

    checkReceiveTimeout = doCheck (connLastReceived conn) receiveTimeout $
      killConnection conn Abnormal (CE.toException skippedBeatEx) connThread

    checkSendTimeout = doCheck (connLastSent conn) sendTimeout $
      withMVar (connWriteLock conn) $ \_ ->
        writeFrame (connHandle conn) (Frame 0 HeartbeatPayload)

    -- Run @action@ if at least @limit@ has elapsed since the stored timestamp.
    -- The timestamp MVar stays held while the action runs.
    doCheck var limit action = withMVar var $ \lastFrameTime -> do
      now <- getTimestamp
      when (now >= lastFrameTime + limit) action
-- | Record "now" as the time the last frame was sent on this connection.
updateLastSent :: Connection -> IO ()
updateLastSent conn = modifyMVar_ (connLastSent conn) $ \_ -> getTimestamp
-- | Record "now" as the time the last frame was received on this connection.
updateLastReceived :: Connection -> IO ()
updateLastReceived conn = modifyMVar_ (connLastReceived conn) $ \_ -> getTimestamp
-- | Record the close type and reason (the shown exception) in 'connClosed',
-- then throw the exception to the connection's receiver thread.
--
-- Any previously recorded reason is overwritten.
killConnection :: Connection -> CloseType -> CE.SomeException -> ThreadId -> IO ()
killConnection conn closeType ex connThread = do
  let closedState = Just (closeType, show ex)
  modifyMVar_ (connClosed conn) (\_ -> pure closedState)
  throwTo connThread ex
-- | Ask the server to close the connection and block until the receiver
-- thread has finished its cleanup (signalled via 'connClosedLock').
--
-- An 'IOException' while sending @connection.close@ (e.g. the socket is
-- already gone) is ignored; the call still waits for the cleanup.
closeConnection :: Connection -> IO ()
closeConnection c = do
  CE.handle ignoreIOError $
    withMVar (connWriteLock c) $ \_ ->
      writeFrame (connHandle c) closeFrame
  readMVar (connClosedLock c)
  where
    closeFrame = Frame 0 (MethodPayload (Connection_close 0 (ShortString "") 0 0))

    ignoreIOError :: CE.IOException -> IO ()
    ignoreIOError _ = pure ()
-- | The server properties received in @connection.start@ during the handshake.
getServerProperties :: Connection -> IO FieldTable
getServerProperties conn = pure (connServerProperties conn)
-- | Register an action to run when the connection closes.
--
-- If the connection is already closed and @ifClosed@ is 'True', the handler
-- runs immediately instead of being registered. Otherwise it is prepended to
-- 'connClosedHandlers'. The check and the registration happen while holding
-- 'connClosed'.
addConnectionClosedHandler :: Connection -> Bool -> IO () -> IO ()
addConnectionClosedHandler conn ifClosed handler =
  withMVar (connClosed conn) $ \closedState ->
    case (closedState, ifClosed) of
      (Just _, True) -> handler
      _ -> modifyMVar_ (connClosedHandlers conn) (pure . (handler :))
-- | Register callbacks for @connection.blocked@ (given the server's reason)
-- and @connection.unblocked@ notifications.
addConnectionBlockedHandler :: Connection -> (Text -> IO ()) -> IO () -> IO ()
addConnectionBlockedHandler conn blockedHandler unblockedHandler =
  modifyMVar_ (connBlockedHandlers conn) (pure . (handlerPair :))
  where
    handlerPair = (blockedHandler, unblockedHandler)
-- | Read and decode exactly one AMQP frame from the transport.
--
-- Reads a 7-byte header, takes the payload size from it with peekFrameSize,
-- then reads payload plus the 1-byte frame-end octet. The decoder must
-- consume exactly len+8 bytes; anything else is an 'error'. An empty read
-- raises the IOError "connection not open".
readFrame :: Conn.Connection -> IO Frame
readFrame handle = do
    strictDat <- connectionGetExact handle 7
    let dat = toLazy strictDat
    when (BL.null dat) $ CE.throwIO $ userError "connection not open"
    let len = fromIntegral $ peekFrameSize dat
    -- payload bytes plus the trailing frame-end octet
    strictDat' <- connectionGetExact handle (len+1)
    let dat' = toLazy strictDat'
    when (BL.null dat') $ CE.throwIO $ userError "connection not open"
#if MIN_VERSION_binary(0, 7, 0)
    -- binary >= 0.7 reports parse failures and the consumed byte count
    let ret = runGetOrFail get (BL.append dat dat')
    case ret of
        Left (_, _, errMsg) -> error $ "readFrame fail: " ++ errMsg
        Right (_, consumedBytes, _) | consumedBytes /= fromIntegral (len+8) ->
            error $ "readFrame: parser should read " ++ show (len+8) ++ " bytes; but read " ++ show consumedBytes
        Right (_, _, frame) -> pure frame
#else
    -- older binary: only the consumed byte count can be checked
    let (frame, _, consumedBytes) = runGetState get (BL.append dat dat') 0
    if consumedBytes /= fromIntegral (len+8)
        then error $ "readFrameSock: parser should read "++show (len+8)++" bytes; but read "++show consumedBytes
        else pure ()
    pure frame
#endif
-- | Read exactly @x@ bytes from the transport, looping over short reads.
--
-- 'Conn.connectionGet' may return fewer bytes than requested. An empty chunk
-- signals end of input, so it is reported as @userError "connection not open"@
-- (an 'IOException', handled by the handshake and 'connectionReceiver') instead
-- of retrying forever. Chunks are collected in a list and joined once, avoiding
-- repeated 'BS.append'. A non-positive @x@ yields an empty string.
connectionGetExact :: Conn.Connection -> Int -> IO BS.ByteString
connectionGetExact conn x = loop [] 0
  where
    loop chunks received
      | received >= x = pure (BS.concat (reverse chunks))
      | otherwise = do
          next <- Conn.connectionGet conn (x - received)
          when (BS.null next) $ CE.throwIO $ userError "connection not open"
          loop (next : chunks) (received + BS.length next)
-- | Serialise a frame and send it on the transport in a single put.
--
-- Takes no lock itself; callers that need serialisation hold 'connWriteLock'.
writeFrame :: Conn.Connection -> Frame -> IO ()
writeFrame handle = Conn.connectionPut handle . toStrict . runPut . put
-- | An AMQP channel and its per-channel state.
--
-- The constructor is applied positionally in 'openChannel', so field order is
-- significant.
data Channel = Channel
  { connection :: Connection
    -- ^ owning connection
  , inQueue :: Chan FramePayload
    -- ^ frames forwarded here by 'connectionReceiver'
  , outstandingResponses :: MVar (Seq.Seq (MVar Assembly))
    -- ^ FIFO of callers waiting for a synchronous reply (see 'request')
  , channelID :: Word16
  , lastConsumerTag :: MVar Int
  , nextPublishSeqNum :: MVar Int
  , unconfirmedSet :: TVar IntSet.IntSet
    -- ^ publisher-confirm sequence numbers not yet acked/nacked
  , ackedSet :: TVar IntSet.IntSet
  , nackedSet :: TVar IntSet.IntSet
  , chanActive :: Lock
    -- ^ closed by @channel.flow false@; content writes wait on it
  , chanClosed :: MVar (Maybe (CloseType, String))
    -- ^ 'Nothing' while open; close type and reason once closed
  , consumers :: MVar (M.Map Text ((Message, Envelope) -> IO (),
                                   (ConsumerTag -> IO ())))
    -- ^ consumer tag -> (delivery callback, cancellation callback)
  , returnListeners :: MVar ([(Message, PublishError) -> IO ()])
  , confirmListeners :: MVar ([(Word64, Bool, AckType) -> IO ()])
  , chanExceptionHandlers :: MVar [CE.SomeException -> IO ()]
  }
-- | Exception thrown to a channel's receiver thread when its connection dies.
-- It wraps the exception that ended the connection.
data ChanThreadKilledException = ChanThreadKilledException
  { cause :: CE.SomeException
  }
  deriving (Show, Typeable)

instance CE.Exception ChanThreadKilledException
-- | Peel one 'ChanThreadKilledException' layer to expose its 'cause'.
-- Other exceptions are returned unchanged.
unwrapChanThreadKilledException :: CE.SomeException -> CE.SomeException
unwrapChanThreadKilledException e = case CE.fromException e of
  Just (ChanThreadKilledException inner) -> inner
  Nothing -> e
-- | Build a 'Message' from basic-class content header properties and a body.
--
-- Short-string properties are unwrapped to 'Text'. The delivery mode is
-- decoded with 'intToDeliveryMode', which is partial. Header properties of
-- any class other than 'CHBasic' are an 'error'.
msgFromContentHeaderProperties :: ContentHeaderProperties -> BL.ByteString -> Message
msgFromContentHeaderProperties
  (CHBasic contentType contentEncoding headers deliveryMode priority correlationId replyTo expiration messageId timestamp messageType userId applicationId clusterId)
  body =
    Message
      { msgBody = body
      , msgDeliveryMode = fmap intToDeliveryMode deliveryMode
      , msgTimestamp = timestamp
      , msgID = unShort messageId
      , msgType = unShort messageType
      , msgUserID = unShort userId
      , msgApplicationID = unShort applicationId
      , msgClusterID = unShort clusterId
      , msgContentType = unShort contentType
      , msgContentEncoding = unShort contentEncoding
      , msgReplyTo = unShort replyTo
      , msgPriority = priority
      , msgCorrelationID = unShort correlationId
      , msgExpiration = unShort expiration
      , msgHeaders = headers
      }
  where
    unShort (Just (ShortString s)) = Just s
    unShort _ = Nothing
msgFromContentHeaderProperties c _ = error ("Unknown content header properties: " ++ show c)
-- | Receive loop of a channel thread.
--
-- Repeatedly reads the next 'Assembly' from the channel's 'inQueue' (filled by
-- 'connectionReceiver'). Assemblies that 'isResponse' classifies as
-- synchronous replies are handed, in FIFO order, to the oldest caller waiting
-- in 'outstandingResponses' (see 'request'). Everything else (deliveries,
-- returns, flow, close, publisher confirms, consumer cancellation) is
-- dispatched by 'handleAsync'. The loop only ends through an exception, e.g.
-- the 'ChannelClosedException' this thread throws to itself after a
-- server-initiated @channel.close@.
channelReceiver :: Channel -> IO ()
channelReceiver chan = do
    p <- readAssembly $ inQueue chan
    if isResponse p
        then do
            -- Pop the oldest waiter while holding the queue MVar, but run the
            -- resulting action (wake-up or error) only after releasing it.
            action <- modifyMVar (outstandingResponses chan) $ \val -> do
                case Seq.viewl val of
                    x Seq.:< rest -> do
                        pure (rest, putMVar x p)
                    Seq.EmptyL -> do
                        pure (val, CE.throwIO $ userError "got response, but have no corresponding request")
            action
        else handleAsync p
    channelReceiver chan
  where
    -- Methods the server may send without a preceding client request.
    isResponse :: Assembly -> Bool
    isResponse (ContentMethod (Basic_deliver _ _ _ _ _) _ _) = False
    isResponse (ContentMethod (Basic_return _ _ _ _) _ _) = False
    isResponse (SimpleMethod (Channel_flow _)) = False
    isResponse (SimpleMethod (Channel_close _ _ _ _)) = False
    isResponse (SimpleMethod (Basic_ack _ _)) = False
    isResponse (SimpleMethod (Basic_nack _ _ _)) = False
    isResponse (SimpleMethod (Basic_cancel _ _)) = False
    isResponse _ = True

    -- Delivery: look up the consumer by tag and run its callback. A
    -- ChanThreadKilledException escaping the callback is unwrapped and
    -- rethrown; any other exception is logged. Deliveries for unknown tags are
    -- dropped. NOTE(review): the consumers MVar is held while the callback
    -- runs, so a callback that modifies consumers on this channel would block.
    handleAsync (ContentMethod (Basic_deliver (ShortString consumerTag) deliveryTag redelivered (ShortString exchange)
                                              (ShortString routingKey))
                               properties body) =
        withMVar (consumers chan) (\s -> do
            case M.lookup consumerTag s of
                Just (subscriber, _) -> do
                    let msg = msgFromContentHeaderProperties properties body
                    let env = Envelope {envDeliveryTag = deliveryTag, envRedelivered = redelivered,
                                        envExchangeName = exchange, envRoutingKey = routingKey, envChannel = chan}
                    CE.catches (subscriber (msg, env))
                        [
                            CE.Handler (\(e::ChanThreadKilledException) -> CE.throwIO $ cause e),
                            CE.Handler (\(e::CE.SomeException) -> hPutStrLn stderr $ "AMQP callback threw exception: " ++ show e)
                        ]
                Nothing ->
                    pure ()
        )
    -- Server-initiated close: acknowledge it (ignoring IO errors if the
    -- connection is already gone), mark the channel closed, then end this
    -- thread with an Abnormal ChannelClosedException.
    handleAsync (SimpleMethod (Channel_close _ (ShortString errorMsg) _ _)) = do
        CE.catch (
            writeAssembly' chan (SimpleMethod (Channel_close_ok))
            )
            (\ (_ :: CE.IOException) ->
                pure ()
            )
        closeChannel' chan Abnormal errorMsg
        myThreadId >>= flip CE.throwTo (ChannelClosedException Abnormal . T.unpack $ errorMsg)
    -- Flow control: open or close the lock that content writes wait on.
    handleAsync (SimpleMethod (Channel_flow active)) = do
        if active
            then openLock $ chanActive chan
            else closeLock $ chanActive chan
        pure ()
    -- Returned message: notify every return listener, logging listener failures.
    handleAsync (ContentMethod basicReturn@(Basic_return _ _ _ _) props body) = do
        let msg = msgFromContentHeaderProperties props body
            pubError = basicReturnToPublishError basicReturn
        withMVar (returnListeners chan) $ \listeners ->
            forM_ listeners $ \l -> CE.catch (l (msg, pubError)) $ \(ex :: CE.SomeException) ->
                hPutStrLn stderr $ "return listener on channel ["++(show $ channelID chan)++"] handling error ["++show pubError++"] threw exception: "++show ex
    handleAsync (SimpleMethod (Basic_ack deliveryTag multiple)) = handleConfirm deliveryTag multiple BasicAck
    handleAsync (SimpleMethod (Basic_nack deliveryTag multiple _)) = handleConfirm deliveryTag multiple BasicNack
    handleAsync (SimpleMethod (Basic_cancel consumerTag _)) = handleCancel consumerTag
    handleAsync m = error ("Unknown method: " ++ show m)

    -- Publisher confirm: notify confirm listeners, then atomically move the
    -- confirmed sequence number(s) from unconfirmedSet into ackedSet or
    -- nackedSet. With multiple=True every unconfirmed number <= deliveryTag
    -- moves; otherwise only deliveryTag itself.
    handleConfirm deliveryTag multiple k = do
        withMVar (confirmListeners chan) $ \listeners ->
            forM_ listeners $ \l -> CE.catch (l (deliveryTag, multiple, k)) $ \(ex :: CE.SomeException) ->
                hPutStrLn stderr $ "confirm listener on channel ["++(show $ channelID chan)++"] handling method "++(show k)++" threw exception: "++ show ex
        let seqNum = fromIntegral deliveryTag
        let targetSet = case k of
              BasicAck -> (ackedSet chan)
              BasicNack -> (nackedSet chan)
        atomically $ do
            unconfSet <- readTVar (unconfirmedSet chan)
            let (merge, pending) = if multiple
                    then (IntSet.union confs, pending')
                    else (IntSet.insert seqNum, IntSet.delete seqNum unconfSet)
                    where
                        confs = fst parts
                        pending' = snd parts
                        parts = IntSet.partition (\n -> n <= seqNum) unconfSet
            modifyTVar' targetSet (\ts -> merge ts)
            writeTVar (unconfirmedSet chan) pending

    -- Server-side consumer cancellation: run the consumer's cancel callback,
    -- logging any exception it throws. Unknown tags are ignored.
    handleCancel (ShortString consumerTag) =
        withMVar (consumers chan) (\s -> do
            case M.lookup consumerTag s of
                Just (_, cancelCB) ->
                    CE.catch (cancelCB consumerTag) $ \(ex :: CE.SomeException) ->
                        hPutStrLn stderr $ "consumer cancellation listener "++(show consumerTag)++" on channel ["++(show $ channelID chan)++"] threw exception: "++ show ex
                Nothing ->
                    pure ()
        )

    -- Map a basic.return to a PublishError. Reply codes other than 312/313/404
    -- are an error.
    basicReturnToPublishError (Basic_return code (ShortString errText) (ShortString exchange) (ShortString routingKey)) =
        let replyError = case code of
                312 -> Unroutable errText
                313 -> NoConsumers errText
                404 -> NotFound errText
                num -> error $ "unexpected return error code: " ++ (show num)
            pubError = PublishError replyError (Just exchange) routingKey
        in pubError
    basicReturnToPublishError x = error $ "basicReturnToPublishError fail: "++show x
-- | Register a callback invoked for every message the broker returns on this
-- channel (@basic.return@).
addReturnListener :: Channel -> ((Message, PublishError) -> IO ()) -> IO ()
addReturnListener chan listener =
  modifyMVar_ (returnListeners chan) (pure . (listener :))
-- | Register a handler that receives the exception that ended this channel's
-- receiver thread (see 'openChannel').
addChannelExceptionHandler :: Channel -> (CE.SomeException -> IO ()) -> IO ()
addChannelExceptionHandler chan handler =
  modifyMVar_ (chanExceptionHandlers chan) (pure . (handler :))
-- | Mark a channel closed and release its resources. Idempotent: only the
-- first call has any effect.
--
-- On the first call, while holding 'chanClosed':
--
-- * frees the channel number and removes the channel from 'connChannels',
-- * kills the flow-control lock,
-- * fails every pending 'request' (each waiter receives an @error "channel
--   closed"@ value) and empties the response queue,
-- * records the close type and reason.
closeChannel' :: Channel -> CloseType -> Text -> IO ()
closeChannel' c closeType reason =
  modifyMVar_ (chanClosed c) $ \x ->
    if isNothing x
      then do
        modifyMVar_ (connChannels $ connection c) $ \old -> do
          ret <- freeChannel (connChanAllocator $ connection c) $ fromIntegral $ channelID c
          when (not ret) $ hPutStrLn stderr "closeChannel error: channel already freed"
          pure $ IM.delete (fromIntegral $ channelID c) old
        void $ killLock $ chanActive c
        killOutstandingResponses $ outstandingResponses c
        pure $ Just (closeType, T.unpack reason)
      else pure x
  where
    -- Wake every waiter with an error value, then reset the queue to empty.
    -- Previously the queue was set to 'undefined', so any later read (e.g. a
    -- late response reaching channelReceiver) crashed with Prelude.undefined.
    killOutstandingResponses :: MVar (Seq.Seq (MVar a)) -> IO ()
    killOutstandingResponses outResps =
      modifyMVar_ outResps $ \val -> do
        F.mapM_ (\x -> tryPutMVar x $ error "channel closed") val
        pure Seq.empty
-- | Open a new channel on the connection.
--
-- Allocates a channel number and, while holding 'connChannels', forks the
-- channel's receiver thread ('channelReceiver') and registers
-- (channel, thread). It then sends @channel.open@ and waits for
-- @channel.open-ok@.
--
-- When the receiver thread ends, the channel is closed via 'closeChannel''.
-- If the thread died from an exception, that exception (unwrapped from
-- 'ChanThreadKilledException') goes to every registered exception handler.
-- With no handlers, an abnormal 'ChannelClosedException' is printed to
-- stderr instead.
--
-- The duplicate-ID check runs before the thread is forked. Otherwise, on
-- that error path, a receiver thread would be left behind whose finaliser
-- could later free the ID and remove the live channel that owns it.
openChannel :: Connection -> IO Channel
openChannel c = do
  newInQueue <- newChan
  outRes <- newMVar Seq.empty
  lastConsTag <- newMVar 0
  ca <- newLock
  closed <- newMVar Nothing
  conss <- newMVar M.empty
  retListeners <- newMVar []
  aSet <- newTVarIO IntSet.empty
  nSet <- newTVarIO IntSet.empty
  nxtSeq <- newMVar 0
  unconfSet <- newTVarIO IntSet.empty
  cnfListeners <- newMVar []
  handlers <- newMVar []
  newChannel <- modifyMVar (connChannels c) $ \mp -> do
    newChannelID <- allocateChannel (connChanAllocator c)
    when (IM.member newChannelID mp) $
      CE.throwIO $ userError "openChannel fail: channel already open"
    let newChannel = Channel c newInQueue outRes (fromIntegral newChannelID) lastConsTag nxtSeq unconfSet aSet nSet ca closed conss retListeners cnfListeners handlers
    thrID <- forkFinally (channelReceiver newChannel) $ \res -> do
      closeChannel' newChannel Normal "closed"
      case res of
        Right _ -> pure ()
        Left ex -> do
          let unwrappedExc = unwrapChanThreadKilledException ex
          handlers' <- readMVar handlers
          case (null handlers', fromAbnormalChannelClose unwrappedExc) of
            (True, Just reason) -> hPutStrLn stderr $ "unhandled AMQP channel exception (chanId=" ++ show newChannelID ++ "): " ++ reason
            _ -> mapM_ ($ unwrappedExc) handlers'
    pure (IM.insert newChannelID (newChannel, thrID) mp, newChannel)
  SimpleMethod (Channel_open_ok _) <- request newChannel (SimpleMethod (Channel_open (ShortString "")))
  pure newChannel
  where
    fromAbnormalChannelClose :: CE.SomeException -> Maybe String
    fromAbnormalChannelClose exc =
      case CE.fromException exc :: Maybe AMQPException of
        Just (ChannelClosedException Abnormal reason) -> Just reason
        _ -> Nothing
-- | Close a channel from the client side.
--
-- Sends @channel.close@, waits for @channel.close-ok@, then kills the
-- channel's receiver thread (whose finaliser performs the cleanup).
closeChannel :: Channel -> IO ()
closeChannel c = do
  SimpleMethod Channel_close_ok <- request c . SimpleMethod $ Channel_close 0 (ShortString "") 0 0
  withMVar (connChannels $ connection c) $ \chans ->
    F.forM_ (IM.lookup (fromIntegral $ channelID c) chans) $ \(_, thrID) ->
      killThread thrID
-- | Send payloads as consecutive frames on this channel.
--
-- Holds 'connChannels' for the whole send and fails with @userError "channel
-- not open"@ if the channel is no longer registered. The frames are written
-- under 'connWriteLock' so they stay contiguous. An 'IOException' from the
-- transport is rethrown as @userError "connection not open"@. The connection's
-- last-sent timestamp is updated after a successful write.
writeFrames :: Channel -> [FramePayload] -> IO ()
writeFrames chan payloads =
  withMVar (connChannels conn) $ \chans ->
    if not (IM.member chanKey chans)
      then CE.throwIO $ userError "channel not open"
      else sendAll `CE.catch` \(_ :: CE.IOException) ->
        CE.throwIO $ userError "connection not open"
  where
    conn = connection chan
    chanKey = fromIntegral (channelID chan)
    sendAll = do
      withMVar (connWriteLock conn) $ \_ ->
        forM_ payloads $ writeFrame (connHandle conn) . Frame (channelID chan)
      updateLastSent conn
-- | Write an 'Assembly' to the channel without translating exceptions.
--
-- A content method first waits for the flow-control lock. It then sends the
-- method frame, a content header declaring the body size, and the body split
-- into chunks. Each chunk is at most connMaxFrameSize - 8 bytes, leaving room
-- for the 7-byte frame header and the 1-byte frame-end octet. An empty body
-- sends no body frames.
--
-- Chunking uses 'BL.splitAt' once per chunk. The old code recomputed the
-- lazy 'BL.length' of the remaining body on every step, which was quadratic.
-- A non-positive chunk size (a negotiated frame size of 8 or less) used to
-- recurse forever on @take 0@; the body is now sent as a single frame.
writeAssembly' :: Channel -> Assembly -> IO ()
writeAssembly' chan (ContentMethod m properties msg) = do
  waitLock $ chanActive chan
  let !toWrite =
        MethodPayload m
          : ContentHeaderPayload (getClassIDOf properties) 0 (fromIntegral bodyLen) properties
          : map ContentBodyPayload bodyChunks
  writeFrames chan toWrite
  where
    bodyLen = BL.length msg
    chunkSize = fromIntegral (connMaxFrameSize $ connection chan) - 8
    bodyChunks
      | bodyLen <= 0 = []
      | chunkSize <= 0 = [msg]
      | otherwise = splitLen msg
    splitLen str = case BL.splitAt chunkSize str of
      (chunk, rest)
        | BL.null rest -> [chunk]
        | otherwise -> chunk : splitLen rest
writeAssembly' chan (SimpleMethod m) = writeFrames chan [MethodPayload m]
-- | Write an 'Assembly' to the channel. Any 'AMQPException', 'CE.ErrorCall'
-- or 'CE.IOException' is replaced by the most relevant close exception (see
-- 'throwMostRelevantAMQPException').
writeAssembly :: Channel -> Assembly -> IO ()
writeAssembly chan m = writeAssembly' chan m `CE.catches` handlers
  where
    relevant = throwMostRelevantAMQPException chan
    handlers =
      [ CE.Handler $ \(_ :: AMQPException) -> relevant
      , CE.Handler $ \(_ :: CE.ErrorCall) -> relevant
      , CE.Handler $ \(_ :: CE.IOException) -> relevant
      ]
-- | Send a synchronous method and wait for its reply.
--
-- While holding 'chanClosed', the reply slot is enqueued on
-- 'outstandingResponses' and the assembly is written. If the channel is
-- already closed, @userError "closed"@ is raised instead. The reply is forced
-- as it is taken, so an error value left by 'closeChannel'' surfaces here.
-- 'AMQPException', 'CE.ErrorCall' and 'CE.IOException' are all replaced by
-- the most relevant close exception.
request :: Channel -> Assembly -> IO Assembly
request chan m = do
  res <- newEmptyMVar
  let enqueueAndSend closedState
        | isNothing closedState = do
            modifyMVar_ (outstandingResponses chan) $ \val -> pure $! val Seq.|> res
            writeAssembly' chan m
        | otherwise = CE.throwIO $ userError "closed"
      awaitResponse = do
        withMVar (chanClosed chan) enqueueAndSend
        !r <- takeMVar res
        pure r
  awaitResponse `CE.catches`
    [ CE.Handler $ \(_ :: AMQPException) -> throwMostRelevantAMQPException chan
    , CE.Handler $ \(_ :: CE.ErrorCall) -> throwMostRelevantAMQPException chan
    , CE.Handler $ \(_ :: CE.IOException) -> throwMostRelevantAMQPException chan
    ]
-- | Throw the exception that best explains a failed channel operation.
--
-- A recorded connection close wins. Otherwise a recorded channel close is
-- used. Otherwise the error is an 'Abnormal' connection close with "unknown
-- reason". The channel state is only read if the connection is still open.
throwMostRelevantAMQPException :: Channel -> IO a
throwMostRelevantAMQPException chan = do
  connState <- readMVar (connClosed (connection chan))
  case connState of
    Just (closeType, reason) -> CE.throwIO (ConnectionClosedException closeType reason)
    Nothing -> do
      chanState <- readMVar (chanClosed chan)
      case chanState of
        Just (closeType, reason) -> CE.throwIO (ChannelClosedException closeType reason)
        Nothing -> CE.throwIO (ConnectionClosedException Abnormal "unknown reason")
-- | STM action that retries until no publisher confirms are pending. It then
-- returns the acked and nacked sequence numbers collected so far and resets
-- both sets to empty.
waitForAllConfirms :: Channel -> STM (IntSet.IntSet, IntSet.IntSet)
waitForAllConfirms chan = do
  pending <- readTVar (unconfirmedSet chan)
  check (IntSet.null pending)
  acked <- swapTVar (ackedSet chan) IntSet.empty
  nacked <- swapTVar (nackedSet chan) IntSet.empty
  pure (acked, nacked)