{-# LANGUAGE DeriveGeneric     #-}
{-# LANGUAGE MultiWayIf        #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections     #-}
--------------------------------------------------------------------------------
-- |
-- Module      :  Network.MQTT.Broker.Session
-- Copyright   :  (c) Lars Petersen 2016
-- License     :  MIT
--
-- Maintainer  :  info@lars-petersen.net
-- Stability   :  experimental
--------------------------------------------------------------------------------
module Network.MQTT.Broker.Session
  ( publish
  , subscribe
  , unsubscribe
  , disconnect
  , terminate

  , getSubscriptions
  , getConnection
  , getPrincipal
  , getFreePacketIdentifiers

    -- TODO: private
  , reset
  , notePending
  , waitPending
  , publishMessage
  , publishMessages
  , enqueuePingResponse
  , enqueueMessage
  , enqueueSubscribeAcknowledged
  , enqueueUnsubscribeAcknowledged
  , dequeue
  , processPublish
  , processPublishRelease
  , processPublishReceived
  , processPublishComplete
  , processPublishAcknowledged
  , Session (..)
  , SessionIdentifier (..)
  , Connection (..)
  ) where

import           Control.Concurrent.MVar
import           Control.Concurrent.PrioritySemaphore
import           Control.Monad
import           Data.Bool
import           Data.Functor.Identity
import qualified Data.IntMap                           as IM
import qualified Data.IntSet                           as IS
import           Data.Maybe
import           Data.Monoid
import qualified Data.Sequence                         as Seq

import           Network.MQTT.Broker.Authentication    hiding (getPrincipal)
import           Network.MQTT.Broker.Internal
import qualified Network.MQTT.Broker.RetainedMessages  as RM
import qualified Network.MQTT.Broker.SessionStatistics as SS
import           Network.MQTT.Message
import qualified Network.MQTT.Trie                     as R

-- | Publish a message on behalf of a session.
--
--   The message is passed to `publishUpstream` only if its topic matches the
--   principal's publish permissions; otherwise it is accounted as a dropped
--   publication and nothing else happens.
--
--   If a permitted message carries the retain flag, it is additionally stored
--   in the broker's retained store, provided the topic also matches the
--   principal's retain permissions. A denied retention is accounted as dropped,
--   but the message itself is still published. Messages without the retain
--   flag do not touch the retention statistics at all.
publish :: Session auth -> Message -> IO ()
publish session msg = do
  principal <- readMVar (sessionPrincipal session)
  -- A topic is permitted if it yields a match in the publish permission tree.
  if R.matchTopic (msgTopic msg) (principalPublishPermissions principal)
    then do
      -- Only messages that actually request retention take part in the
      -- retention accounting. Counting every non-retained message as a
      -- "dropped retention" would make that statistic meaningless.
      when retain $
        if R.matchTopic (msgTopic msg) (principalRetainPermissions principal)
          then do
            RM.store msg (brokerRetainedStore $ sessionBroker session)
            SS.accountRetentionsAccepted stats 1
          else
            SS.accountRetentionsDropped stats 1
      publishUpstream (sessionBroker session) msg
      SS.accountPublicationsAccepted stats 1
    else
      SS.accountPublicationsDropped stats 1
  where
    stats = sessionStatistics session
    Retain retain = msgRetain msg

-- | Process a subscribe request of a session.
--
--   Each requested filter is checked against the principal's subscribe
--   permissions. Permitted filters are merged into the session's subscription
--   tree (on a clash the greater QoS is kept) and the session identifier is
--   registered for them in the broker's subscription tree.
--
--   Afterwards a SUBACK with one entry per requested filter is enqueued
--   (`Nothing` marks a denied filter), and the retained messages matching the
--   /permitted/ filters are handed to `publishMessages`.
subscribe :: Session auth -> PacketIdentifier -> [(Filter, QoS)] -> IO ()
subscribe session pid filters = do
  principal <- readMVar (sessionPrincipal session)
  checkedFilters <- mapM (checkPermission principal) filters
  let subscribeFilters = mapMaybe (\(filtr,mqos)->(filtr,) <$> mqos) checkedFilters
      qosTree = R.insertFoldable subscribeFilters R.empty
      sidTree = R.map (const $ IS.singleton sid) qosTree
  -- Do the accounting for the session statistics.
  -- TODO: Do this as a transaction below.
  let countAccepted = length subscribeFilters
  let countDenied   = length filters - countAccepted
  SS.accountSubscriptionsAccepted (sessionStatistics session) $ fromIntegral countAccepted
  SS.accountSubscriptionsDenied   (sessionStatistics session) $ fromIntegral countDenied
  -- Force the `qosTree` in order to lock the broker as little as possible.
  -- The `sidTree` is left lazy.
  qosTree `seq` do
    -- The session's subscriptions are updated while the broker state is
    -- held, so both trees change together.
    modifyMVarMasked_ (brokerState $ sessionBroker session) $ \bst-> do
      modifyMVarMasked_
        ( sessionSubscriptions session )
        ( pure . R.unionWith max qosTree )
      pure $ bst { brokerSubscriptions = R.unionWith IS.union (brokerSubscriptions bst) sidTree }
    -- The acknowledgement covers every requested filter, denied ones included.
    enqueueSubscribeAcknowledged session pid (fmap snd checkedFilters)
    -- Retained messages must only be looked up for filters that passed the
    -- permission check. Iterating over all requested filters would hand
    -- retained messages of denied filters to the session.
    forM_ subscribeFilters $ \(filtr,_qos)->
      publishMessages session =<< RM.retrieve filtr (brokerRetainedStore $ sessionBroker session)
  where
    SessionIdentifier sid = sessionIdentifier session
    -- Pairs the filter with `Just` the requested QoS when the principal may
    -- subscribe to it and with `Nothing` otherwise.
    checkPermission principal (filtr, qos) = do
      let isPermitted = R.matchFilter filtr (principalSubscribePermissions principal)
      pure (filtr, if isPermitted then Just qos else Nothing)

-- | Remove subscriptions of a session and acknowledge the request.
--
--   The given filters are removed from the session's own subscription tree
--   and the session identifier is deleted from the matching entries of the
--   broker's subscription tree. Afterwards an UNSUBACK carrying the given
--   packet identifier is enqueued.
unsubscribe :: Session auth -> PacketIdentifier -> [Filter] -> IO ()
unsubscribe session pid filters =
  -- Evaluate the removal tree before taking the broker lock so that the
  -- lock is held as briefly as possible.
  removals `seq` do
    modifyMVarMasked_ (brokerState broker) $ \st -> do
      -- The session tree is updated while the broker state is held.
      modifyMVarMasked_ (sessionSubscriptions session) $ \subs ->
        pure (R.differenceWith dropEntry subs removals)
      pure $ st
        { brokerSubscriptions =
            R.differenceWith removeSession (brokerSubscriptions st) removals
        }
    enqueueUnsubscribeAcknowledged session pid
  where
    broker                = sessionBroker session
    SessionIdentifier sid = sessionIdentifier session
    -- One entry per filter, each tagged with this session's identifier.
    removals              = R.insertFoldable (map (,Identity sid) filters) R.empty
    -- Session side: the entry is discarded altogether.
    dropEntry _ _         = Nothing
    -- Broker side: only this session is taken out of the subscriber set.
    removeSession subscribers (Identity i) = Just (IS.delete i subscribers)

-- | Disconnect a session.
--
--   Runs an empty action under `exclusively` on the session's semaphore.
--   This is meant to interrupt the client handler thread that currently
--   holds the semaphore (if any) and thereby disconnect the client; the
--   interruption itself is implemented by the semaphore, not here.
disconnect :: Session auth -> IO ()
disconnect session =
  exclusively semaphore (return ())
  where
    semaphore = sessionSemaphore session

-- | Reset the session's queue after a reconnect.
--
--   Applies `resetQueue` to the session queue:
--
--   * The previous content of the output buffer is discarded.
--   * The output buffer is refilled with the necessary retransmissions.
reset :: Session auth -> IO ()
reset session =
  modifyMVar_ (sessionQueue session) $ \queue ->
    -- Force the new queue record before it is stored.
    return $! resetQueue queue

-- | Schedule a PINGRESP for transmission.
--
--   The packet is put at the /front/ of the output buffer, i.e. ahead of
--   everything that is already enqueued there.
enqueuePingResponse :: Session auth -> IO ()
enqueuePingResponse session = do
  modifyMVar_ (sessionQueue session) (\q -> pure $! prepend q)
  -- IMPORTANT: The sending thread must be told that there is work to do.
  notePending session
  where
    prepend q = q { outputBuffer = ServerPingResponse Seq.<| outputBuffer q }

-- | Append a SUBACK for the given packet identifier to the output buffer
--   and signal that output is pending.
--
--   The list is passed through unchanged as the payload of the
--   acknowledgement.
enqueueSubscribeAcknowledged :: Session auth -> PacketIdentifier -> [Maybe QoS] -> IO ()
enqueueSubscribeAcknowledged session pid mqoss = do
  modifyMVar_ (sessionQueue session) (\q -> pure $! append q)
  notePending session
  where
    suback   = ServerSubscribeAcknowledged pid mqoss
    append q = q { outputBuffer = outputBuffer q Seq.|> suback }

-- | Append an UNSUBACK for the given packet identifier to the output buffer
--   and signal that output is pending.
enqueueUnsubscribeAcknowledged :: Session auth -> PacketIdentifier -> IO ()
enqueueUnsubscribeAcknowledged session pid = do
  modifyMVar_ (sessionQueue session) (\q -> pure $! append q)
  notePending session
  where
    unsuback = ServerUnsubscribeAcknowledged pid
    append q = q { outputBuffer = outputBuffer q Seq.|> unsuback }

-- | Take the next batch of packets that are ready to be sent.
--
--   The queue is normalized first (see `normalizeQueue`). Then:
--
--   * If the output buffer is non-empty, its whole content is returned and
--     the buffer is emptied. It is preferred over the QoS0 queue.
--   * Otherwise, if QoS0 messages are queued, all of them are returned as
--     publish packets and the QoS0 queue is emptied.
--   * Otherwise the empty sequence is returned and the session's `pending`
--     variable is cleared.
--
--   This function itself does not wait for messages to arrive; it only waits
--   for access to the queue.
dequeue :: Session auth -> IO (Seq.Seq ServerPacket)
dequeue session =
  modifyMVar (sessionQueue session) (select . normalizeQueue)
  where
    select q
      | hasPackets = pure (q { outputBuffer = mempty }, packets)
      | hasQoS0    = pure (q { queueQoS0    = mempty }, asPublish <$> queueQoS0 q)
      | otherwise  = do
          -- All queues are empty, so the `pending` variable must be cleared.
          -- ATTENTION: Omitting this would make the consumer rush right
          -- through its blocking call, leading to enormous CPU usage.
          _ <- tryTakeMVar (sessionQueuePending session)
          pure (q, mempty)
      where
        packets    = outputBuffer q
        hasPackets = not (Seq.null packets)
        hasQoS0    = not (Seq.null (queueQoS0 q))
        -- QoS0 publications are sent with the placeholder identifier -1.
        asPublish  = ServerPublish (PacketIdentifier (-1)) (Duplicate False)

-- | Handle a @PUBLISH@ received from the peer, depending on the message's QoS.
--
--   * `QoS0`: the message is published right away.
--   * `QoS1`: the message is published, then a @PUBACK@ is enqueued.
--   * `QoS2`: the message is stored under its packet identifier and a
--     @PUBREC@ is enqueued. It is not published here.
--
--   The duplicate flag is ignored.
processPublish :: Session auth -> PacketIdentifier -> Duplicate -> Message -> IO ()
processPublish session pid@(PacketIdentifier p) _dup msg =
  case msgQoS msg of
    QoS0 -> publish session msg
    QoS1 -> publish session msg >> update acknowledge
    QoS2 -> update hold
  where
    -- Apply a change to the session queue and wake the sender.
    update f = do
      modifyMVar_ (sessionQueue session) (\q -> pure $! f q)
      notePending session
    acknowledge q = q
      { outputBuffer = outputBuffer q Seq.|> ServerPublishAcknowledged pid
      }
    hold q = q
      { outputBuffer = outputBuffer q Seq.|> ServerPublishReceived pid
      , notReleased  = IM.insert p msg (notReleased q)
      }

-- | Complete the transmission of a QoS1 message.
--
--   To be called when a @PUBACK@ has been received from the peer. The
--   message is dropped from the set of unacknowledged messages, which
--   completes the transaction.
processPublishAcknowledged :: Session auth -> PacketIdentifier -> IO ()
processPublishAcknowledged session (PacketIdentifier pid) = do
  modifyMVar_ (sessionQueue session) (\q -> pure $! release q)
  -- Nothing was enqueued, but a freed packet identifier may allow messages
  -- that are still waiting for one to be sent, so the sender is woken up.
  notePending session
  where
    release q = q
      { -- The identifier becomes reusable only if it really belonged to an
        -- unacknowledged message.
        queuePids       = if wasPending
                            then PacketIdentifier pid Seq.<| queuePids q
                            else queuePids q
      , notAcknowledged = IM.delete pid (notAcknowledged q)
      }
      where
        wasPending = IM.member pid (notAcknowledged q)

-- | Note that a QoS2 message has been received by the peer.
--
--   This shall be called when a @PUBREC@ has been received from the peer.
--   This is the completion of the second step in a 4-way handshake.
--   The state changes from _not received_ to _not completed_.
--   We will send a @PUBREL@ to the client and expect a @PUBCOMP@ in return.
--
--   The packet identifier is only moved to the _not completed_ set if a
--   message with that identifier was actually awaiting a @PUBREC@. An
--   identifier that already is in the _not completed_ set stays there, so a
--   repeated @PUBREC@ simply triggers another @PUBREL@. For an unknown
--   identifier only the @PUBREL@ is sent and no state is created.
processPublishReceived :: Session auth -> PacketIdentifier -> IO ()
processPublishReceived session (PacketIdentifier pid) = do
  modifyMVar_ (sessionQueue session) $ \q->
    pure $! q {
      notReceived  = IM.delete pid (notReceived q)
      -- Tracking an identifier we never used for a QoS2 publish would allow
      -- a later @PUBCOMP@ to "free" an identifier that is already free or
      -- in use elsewhere.
    , notComplete  = bool (notComplete q) (IS.insert pid (notComplete q)) (IM.member pid (notReceived q))
    , outputBuffer = outputBuffer q Seq.|> ServerPublishRelease (PacketIdentifier pid)
    }
  notePending session

-- | Release a `QoS2` message.
--
--   This shall be called when @PUBREL@ has been received from the peer.
--   It always enqueues an outgoing @PUBCOMP@ for the packet identifier.
--
--   If a message is still stored under the identifier, it is published
--   first. The message is only released (and the @PUBCOMP@ only enqueued) if
--   `publish` returned without exception; otherwise `modifyMVar_` restores
--   the unmodified queue.
--
--   If no message is stored (it has already been released, e.g. because the
--   peer retransmitted its @PUBREL@ after our @PUBCOMP@ got lost), nothing is
--   published, but the @PUBCOMP@ is sent again so that the peer can complete
--   the handshake.
processPublishRelease :: Session auth -> PacketIdentifier -> IO ()
processPublishRelease session (PacketIdentifier pid) = do
  modifyMVar_ (sessionQueue session) $ \q-> do
    -- NOTE(review): `publish` runs while this session's queue is held;
    -- confirm the publish path never needs this queue synchronously.
    case IM.lookup pid (notReleased q) of
      Nothing  -> pure ()
      Just msg -> publish session msg
    pure $! q { notReleased  = IM.delete pid (notReleased q)
              , outputBuffer = outputBuffer q Seq.|> ServerPublishComplete (PacketIdentifier pid)
              }
  notePending session

-- | Complete the transmission of a QoS2 message.
--
--   This shall be called when a @PUBCOMP@ has been received from the peer
--   to finally free the packet identifier.
--
--   The identifier is only returned to the pool of free identifiers if it
--   actually was in the set of not completed transmissions. A stray or
--   duplicate @PUBCOMP@ therefore cannot put the same identifier into the
--   pool more than once.
processPublishComplete :: Session auth -> PacketIdentifier -> IO ()
processPublishComplete session (PacketIdentifier pid) = do
  modifyMVar_ (sessionQueue session) $ \q-> pure $! q {
      -- The packet identifier is free for reuse only if it actually was in the set of notComplete transmissions.
      queuePids   = bool (queuePids q) (PacketIdentifier pid Seq.<| queuePids q) (IS.member pid (notComplete q))
    , notComplete = IS.delete pid (notComplete q)
    }
  -- Although we did not enqueue something it might still be the case
  -- that we have unsent data that couldn't be sent by now because no more
  -- packet identifiers were available.
  -- In case the output queues are actually empty, the thread will check
  -- them once and immediately sleep again.
  notePending session

-- | Read the session's current subscription tree.
getSubscriptions :: Session auth -> IO (R.Trie QoS)
getSubscriptions = readMVar . sessionSubscriptions

-- | Read the session's connection without blocking.
--
--   Yields `Nothing` when the connection variable is currently empty.
getConnection :: Session auth -> IO (Maybe Connection)
getConnection = tryReadMVar . sessionConnection

-- | Read the principal the session currently acts as.
getPrincipal :: Session auth -> IO Principal
getPrincipal = readMVar . sessionPrincipal

-- | Read the packet identifiers of the session queue that are currently
--   free for use.
getFreePacketIdentifiers :: Session auth -> IO (Seq.Seq PacketIdentifier)
getFreePacketIdentifiers session =
  fmap queuePids (readMVar (sessionQueue session))

-- | Replace the output buffer with the retransmissions of all in-flight
--   messages.
--
--   The previous content of the output buffer is discarded. The new buffer
--   contains, in this order:
--
--   1. a @PUBREL@ for every identifier in `notComplete`,
--   2. a publish packet with the duplicate flag set for every message in
--      `notReceived`,
--   3. a publish packet with the duplicate flag set for every message in
--      `notAcknowledged`.
--
--   All other fields of the queue are left untouched.
resetQueue :: ServerQueue -> ServerQueue
resetQueue q = q {
    outputBuffer = (rePublishQoS1 . rePublishQoS2 . reReleaseQoS2) mempty
  }
  where
    -- Strict left folds: the sequence is built step by step instead of
    -- piling up a chain of thunks as the lazy `foldl` variants would.
    rePublishQoS1 s = IM.foldlWithKey' rePublish s (notAcknowledged q)
    rePublishQoS2 s = IM.foldlWithKey' rePublish s (notReceived     q)
    reReleaseQoS2 s = IS.foldl'        reRelease s (notComplete     q)
    rePublish s' pid msg = s' Seq.|> ServerPublish        (PacketIdentifier pid) (Duplicate True) msg
    reRelease s' pid     = s' Seq.|> ServerPublishRelease (PacketIdentifier pid)

-- | Move as many queued QoS1 and QoS2 messages as possible into the output
--   buffer.
--
--   Every moved message consumes one free packet identifier, so the number
--   of messages moved is limited by the identifiers available. QoS2 messages
--   are served before QoS1 messages. Each moved message is also recorded as
--   in flight (`notReceived` for QoS2, `notAcknowledged` for QoS1) under the
--   identifier it was given.
normalizeQueue :: ServerQueue -> ServerQueue
normalizeQueue = takeQoS1 . takeQoS2
  where
    takeQoS1 queue =
      case allot (queuePids queue) (queueQoS1 queue) of
        Nothing -> queue
        Just (packets, freePids, waiting, track) -> queue
          { outputBuffer    = outputBuffer queue <> packets
          , queuePids       = freePids
          , queueQoS1       = waiting
          , notAcknowledged = track (notAcknowledged queue)
          }

    takeQoS2 queue =
      case allot (queuePids queue) (queueQoS2 queue) of
        Nothing -> queue
        Just (packets, freePids, waiting, track) -> queue
          { outputBuffer    = outputBuffer queue <> packets
          , queuePids       = freePids
          , queueQoS2       = waiting
          , notReceived     = track (notReceived queue)
          }

    -- Pair up free identifiers with waiting messages. Yields `Nothing` when
    -- no message is waiting; otherwise the publish packets to send, the
    -- identifiers left over, the messages still waiting and a function that
    -- records the sent messages in an in-flight map.
    allot pids msgs
      | Seq.null msgs = Nothing
      | otherwise     = Just (packets, freePids, waiting, track)
      where
        count                = min (Seq.length pids) (Seq.length msgs)
        (usedPids, freePids) = Seq.splitAt count pids
        (sending,  waiting)  = Seq.splitAt count msgs
        packets              = Seq.zipWith (flip ServerPublish (Duplicate False)) usedPids sending
        track inFlight       = foldr (\(PacketIdentifier pid, msg) -> IM.insert pid msg)
                                     inFlight
                                     (Seq.zipWith (,) usedPids sending)