module Network.MQTT.Broker.Internal where
import Control.Concurrent.MVar
import Control.Concurrent.PrioritySemaphore
import Control.Monad
import qualified Data.Binary as B
import qualified Data.ByteString as BS
import Data.Int
import qualified Data.IntMap.Strict as IM
import qualified Data.IntSet as IS
import qualified Data.Map.Strict as M
import qualified Data.Sequence as Seq
import GHC.Generics (Generic)
import Network.MQTT.Broker.Authentication hiding (getPrincipal)
import qualified Network.MQTT.Broker.RetainedMessages as RM
import qualified Network.MQTT.Broker.SessionStatistics as SS
import Network.MQTT.Message
import qualified Network.MQTT.Trie as R
-- | Top-level broker handle, parameterised over the authenticator type @auth@.
--
-- The immutable parts live directly in the record; everything that changes
-- at runtime sits behind the 'MVar' in 'brokerState'.
data Broker auth
  = Broker
  { brokerCreatedAt :: Int64
    -- ^ Creation timestamp. NOTE(review): the unit is not visible in this
    -- module (presumably seconds since the Unix epoch) — confirm.
  , brokerAuthenticator :: auth
    -- ^ The authenticator value supplied for this broker.
  , brokerRetainedStore :: RM.RetainedStore
    -- ^ Store that 'publishDownstream' passes every message to via @RM.store@.
  , brokerState :: MVar (BrokerState auth)
    -- ^ Mutable broker state (sessions and subscriptions), guarded by an 'MVar'.
  }
-- | The mutable part of a 'Broker'. All fields are strict.
data BrokerState auth
  = BrokerState
  { brokerMaxSessionIdentifier :: !SessionIdentifier
    -- ^ NOTE(review): presumably the highest session identifier handed out
    -- so far; the allocation code is not visible here — confirm.
  , brokerSubscriptions :: !(R.Trie IS.IntSet)
    -- ^ Topic trie whose values are sets of raw session identifiers
    -- ('Int's); 'publishDownstream' uses each element as a key into
    -- 'brokerSessions'.
  , brokerSessions :: !(IM.IntMap (Session auth))
    -- ^ Sessions keyed by the 'Int' wrapped in their 'SessionIdentifier'.
  , brokerSessionsByPrincipals :: !(M.Map (PrincipalIdentifier, ClientIdentifier) SessionIdentifier)
    -- ^ Index from (principal, client identifier) to the session identifier.
  }
-- | Identifier of a 'Session'; a thin wrapper around 'Int'. The wrapped value
-- is what the 'IM.IntMap' / 'IS.IntSet' structures in 'BrokerState' use as key.
newtype SessionIdentifier = SessionIdentifier Int deriving (Eq, Ord, Show, Enum, Generic)
-- | A client session registered with a 'Broker'. All fields are strict;
-- the mutable parts are held in 'MVar's.
data Session auth
  = Session
  { sessionBroker :: !(Broker auth)
    -- ^ The broker this session belongs to ('terminate' uses it to reach
    -- the broker state).
  , sessionIdentifier :: !SessionIdentifier
    -- ^ Identifier of the session; the sole basis of the 'Eq' and 'Ord'
    -- instances.
  , sessionClientIdentifier :: !ClientIdentifier
    -- ^ Client identifier; together with the principal identifier it forms
    -- the key of 'brokerSessionsByPrincipals'.
  , sessionPrincipalIdentifier :: !PrincipalIdentifier
    -- ^ Identifier of the principal owning the session.
  , sessionCreatedAt :: !Int64
    -- ^ Creation timestamp. NOTE(review): unit not visible here — confirm.
  , sessionConnection :: !(MVar Connection)
    -- ^ Connection metadata. NOTE(review): when this 'MVar' is filled or
    -- emptied is not visible in this module — confirm against the callers.
  , sessionPrincipal :: !(MVar Principal)
    -- ^ The principal; 'enqueueMessage' reads its quota from here.
  , sessionSemaphore :: !PrioritySemaphore
    -- ^ Semaphore under which 'terminate' runs (via 'exclusively').
  , sessionSubscriptions :: !(MVar (R.Trie QoS))
    -- ^ The session's own subscriptions with their 'QoS'; consulted by
    -- 'publishMessage' to pick the delivery QoS.
  , sessionQueue :: !(MVar ServerQueue)
    -- ^ Outgoing queue state for this session.
  , sessionQueuePending :: !(MVar ())
    -- ^ Signal used by 'notePending' (fill) and 'waitPending' (block until
    -- full).
  , sessionStatistics :: !SS.Statistics
    -- ^ Statistics for this session.
  }
-- | Metadata about a client connection. Serialisable via the generic
-- 'B.Binary' instance, so the field order is part of the encoding.
data Connection
  = Connection
  { connectionCreatedAt :: !Int64
    -- ^ Creation timestamp. NOTE(review): unit not visible here — confirm.
  , connectionCleanSession :: !Bool
    -- ^ NOTE(review): presumably the clean-session flag of the client's
    -- CONNECT request — confirm.
  , connectionSecure :: !Bool
    -- ^ NOTE(review): presumably whether the transport is encrypted — confirm.
  , connectionWebSocket :: !Bool
    -- ^ NOTE(review): presumably whether the transport is a WebSocket — confirm.
  , connectionRemoteAddress :: !(Maybe BS.ByteString)
    -- ^ Remote address, if known. NOTE(review): the encoding of the bytes is
    -- not visible here — confirm.
  } deriving (Eq, Ord, Show, Generic)
-- | Per-session outgoing queue state. All fields are strict.
--
-- Only 'queuePids' and the three @queueQoS*@ fields are touched in this
-- module; the notes on the remaining fields are inferred from names and types.
data ServerQueue
  = ServerQueue
  { queuePids :: !(Seq.Seq PacketIdentifier)
    -- ^ Pool of packet identifiers, pre-filled by 'emptyServerQueue'.
  , outputBuffer :: !(Seq.Seq ServerPacket)
    -- ^ Buffered server packets. NOTE(review): presumably packets waiting to
    -- be written to the client — confirm.
  , queueQoS0 :: !(Seq.Seq Message)
    -- ^ Messages enqueued with 'QoS0' (see 'enqueueMessage').
  , queueQoS1 :: !(Seq.Seq Message)
    -- ^ Messages enqueued with 'QoS1' (see 'enqueueMessage').
  , queueQoS2 :: !(Seq.Seq Message)
    -- ^ Messages enqueued with 'QoS2' (see 'enqueueMessage').
  , notAcknowledged :: !(IM.IntMap Message)
    -- ^ NOTE(review): presumably QoS 1 messages sent but not yet
    -- acknowledged, keyed by packet identifier — confirm.
  , notReceived :: !(IM.IntMap Message)
    -- ^ NOTE(review): presumably QoS 2 messages sent but not yet confirmed
    -- as received, keyed by packet identifier — confirm.
  , notReleased :: !(IM.IntMap Message)
    -- ^ NOTE(review): presumably QoS 2 messages received from the client but
    -- not yet released, keyed by packet identifier — confirm.
  , notComplete :: !IS.IntSet
    -- ^ NOTE(review): presumably packet identifiers of QoS 2 exchanges that
    -- are awaiting completion — confirm.
  }
-- | Empty instance body: relies on the 'Generic'-based default methods of
-- 'B.Binary', so the encoding follows the derived 'Generic' representation.
instance B.Binary SessionIdentifier
-- | Empty instance body as well; the encoding follows the declaration order
-- of the 'Connection' fields through its derived 'Generic' representation.
instance B.Binary Connection
-- | Two sessions are equal exactly when their session identifiers are equal;
-- no other field takes part in the comparison.
instance Eq (Session auth) where
  lhs == rhs = sessionIdentifier lhs == sessionIdentifier rhs
-- | Sessions are ordered by their session identifier alone.
instance Ord (Session auth) where
  compare lhs rhs = sessionIdentifier lhs `compare` sessionIdentifier rhs
-- | Renders only the identifying fields of a session (session identifier,
-- principal identifier and client identifier); the mutable state is omitted.
instance Show (Session auth) where
  show session = concat
    [ "Session { identifier = "
    , show (sessionIdentifier session)
    , ", principal = "
    , show (sessionPrincipalIdentifier session)
    , ", client = "
    , show (sessionClientIdentifier session)
    , " }"
    ]
-- | Publish a message to every session subscribed to its topic.
--
-- The message is first passed to @RM.store@ on the broker's retained store.
-- The broker state is then read once, and each session identifier found in
-- the subscription trie for the topic is resolved against that same snapshot.
-- An identifier without a matching session only produces a warning on stdout.
publishDownstream :: Broker auth -> Message -> IO ()
publishDownstream broker msg = do
  RM.store msg (brokerRetainedStore broker)
  st <- readMVar (brokerState broker)
  let subscribers = R.lookup (msgTopic msg) (brokerSubscriptions st)
      -- Resolve one raw session identifier and hand the message over.
      deliver :: Int -> IO ()
      deliver key =
        maybe
          (putStrLn "WARNING: dead session reference")
          (`publishMessage` msg)
          (IM.lookup key (brokerSessions st))
  mapM_ deliver (IS.elems subscribers)
-- | Publish a message upstream. At present this simply delegates to
-- 'publishDownstream' with the same arguments.
publishUpstream :: Broker auth -> Message -> IO ()
publishUpstream broker msg = publishDownstream broker msg
-- | Signal that the session has pending data by filling
-- 'sessionQueuePending'. Never blocks: if the signal is already set,
-- 'tryPutMVar' does nothing and its result is discarded.
notePending :: Session auth -> IO ()
notePending session = void (tryPutMVar (sessionQueuePending session) ())
-- | Block until 'sessionQueuePending' has been filled (see 'notePending').
-- Uses 'readMVar', so the signal is left in place rather than consumed.
waitPending :: Session auth -> IO ()
waitPending session = void (readMVar (sessionQueuePending session))
-- | Build an empty 'ServerQueue' whose packet identifier pool holds the
-- identifiers @1 .. min i 65535@.
--
-- The pool deliberately starts at 1: MQTT requires packet identifiers to be
-- non-zero 16-bit values, so identifier 0 must never be handed out. Starting
-- at 1 also means the pool contains exactly @min i 65535@ identifiers (and
-- none at all for @i <= 0@), rather than one more than requested.
--
-- NOTE(review): @i@ is presumably the maximum number of in-flight messages
-- for the session — confirm against the caller.
emptyServerQueue :: Int -> ServerQueue
emptyServerQueue i = ServerQueue
  { queuePids = Seq.fromList $ fmap PacketIdentifier [1 .. min i 65535]
  , outputBuffer = mempty
  , queueQoS0 = mempty
  , queueQoS1 = mempty
  , queueQoS2 = mempty
  , notAcknowledged = mempty
  , notReceived = mempty
  , notReleased = mempty
  , notComplete = mempty
  }
-- | Deliver a message to a single session, if the session is subscribed.
--
-- The session's subscription trie is queried with the message topic via
-- @R.findMaxBounded@. When it yields a 'QoS', the message is enqueued with
-- its QoS overwritten by that value; when it yields nothing, the message is
-- silently skipped.
publishMessage :: Session auth -> Message -> IO ()
publishMessage session msg = do
  subscriptions <- readMVar (sessionSubscriptions session)
  forM_ (R.findMaxBounded (msgTopic msg) subscriptions) $ \qos ->
    enqueueMessage session msg { msgQoS = qos }
-- | Apply 'publishMessage' to every message of a container, in traversal
-- order.
publishMessages :: Foldable t => Session auth -> t Message -> IO ()
publishMessages session = mapM_ (publishMessage session)
-- | Append a message to the session queue that matches its QoS, subject to
-- the queue-size quota of the session's principal.
--
-- * 'QoS0': always accepted. When the queue is already at its limit the
--   message is appended and the oldest entry is dropped.
-- * 'QoS1' / 'QoS2': accepted only while the queue is below its limit.
--   Otherwise the queue is left untouched and the session is terminated.
--
-- After an accepted message the session is notified via 'notePending'.
enqueueMessage :: Session auth -> Message -> IO ()
enqueueMessage session msg = do
  quota <- principalQuota <$> readMVar (sessionPrincipal session)
  accepted <- modifyMVar (sessionQueue session) (pure . admit quota)
  if accepted
    then notePending session
    else terminate session
  where
    -- True while the pending queue is strictly shorter than the limit.
    hasRoom :: Integral n => n -> Seq.Seq Message -> Bool
    hasRoom limit pending = fromIntegral limit > Seq.length pending

    -- Force the updated queue before it is written back into the MVar.
    accept :: ServerQueue -> (ServerQueue, Bool)
    accept updated = updated `seq` (updated, True)

    -- Pure admission decision: the new queue plus whether the message was
    -- accepted.
    admit quota queue = case msgQoS msg of
      QoS0
        | hasRoom (quotaMaxQueueSizeQoS0 quota) (queueQoS0 queue) ->
            accept queue { queueQoS0 = queueQoS0 queue Seq.|> msg }
        | otherwise ->
            accept queue { queueQoS0 = Seq.drop 1 (queueQoS0 queue Seq.|> msg) }
      QoS1
        | hasRoom (quotaMaxQueueSizeQoS1 quota) (queueQoS1 queue) ->
            accept queue { queueQoS1 = queueQoS1 queue Seq.|> msg }
        | otherwise ->
            (queue, False)
      QoS2
        | hasRoom (quotaMaxQueueSizeQoS2 quota) (queueQoS2 queue) ->
            accept queue { queueQoS2 = queueQoS2 queue Seq.|> msg }
        | otherwise ->
            (queue, False)
-- | Apply 'enqueueMessage' to every message of a container, in traversal
-- order.
enqueueMessages :: Foldable t => Session auth -> t Message -> IO ()
enqueueMessages session = mapM_ (enqueueMessage session)
-- | Remove a session from its broker.
--
-- Runs under the session's semaphore (via 'exclusively'). While holding the
-- broker state and then the session's subscriptions (both masked), it
-- deletes the session from 'brokerSessions' and
-- 'brokerSessionsByPrincipals', and removes the session identifier from the
-- broker's subscription trie at every position covered by the session's own
-- subscriptions.
terminate :: Session auth -> IO ()
terminate session =
  exclusively (sessionSemaphore session) $
    modifyMVarMasked_ (brokerState (sessionBroker session)) $ \st ->
      withMVarMasked (sessionSubscriptions session) (pure . unregister st)
  where
    SessionIdentifier sid = sessionIdentifier session

    -- Key of this session in the (principal, client) index.
    principalKey =
      (sessionPrincipalIdentifier session, sessionClientIdentifier session)

    -- Take this session out of one subscriber set; the set itself is kept.
    dropSession subscribers _ = Just (IS.delete sid subscribers)

    -- Broker state with every reference to this session removed.
    unregister st subscriptions = st
      { brokerSessions = IM.delete sid (brokerSessions st)
      , brokerSessionsByPrincipals =
          M.delete principalKey (brokerSessionsByPrincipals st)
      , brokerSubscriptions =
          R.differenceWith dropSession (brokerSubscriptions st) subscriptions
      }