{-# LANGUAGE DeriveGeneric       #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections       #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
--------------------------------------------------------------------------------
-- |
-- Module      :  Network.MQTT.Broker.Internal
-- Copyright   :  (c) Lars Petersen 2016
-- License     :  MIT
--
-- Maintainer  :  info@lars-petersen.net
-- Stability   :  experimental
--------------------------------------------------------------------------------
module Network.MQTT.Broker.Internal where

import           Control.Concurrent.MVar
import           Control.Concurrent.PrioritySemaphore
import           Control.Monad
import qualified Data.Binary                           as B
import qualified Data.ByteString                       as BS
import           Data.Int
import qualified Data.IntMap.Strict                    as IM
import qualified Data.IntSet                           as IS
import qualified Data.Map.Strict                       as M
import qualified Data.Sequence                         as Seq
import           GHC.Generics                          (Generic)

import           Network.MQTT.Broker.Authentication    hiding (getPrincipal)
import qualified Network.MQTT.Broker.RetainedMessages  as RM
import qualified Network.MQTT.Broker.SessionStatistics as SS
import           Network.MQTT.Message
import qualified Network.MQTT.Trie                     as R

-- | A broker instance.
--
--   The record itself holds plain values; the parts that change at runtime
--   (sessions and subscriptions) are kept in 'brokerState' behind an 'MVar'.
data Broker auth
   = Broker
   { brokerCreatedAt     :: Int64
     -- ^ Creation time of the broker (presumably a timestamp; the unit is
     --   not visible here — TODO confirm).
   , brokerAuthenticator :: auth
     -- ^ The authenticator value this broker was set up with.
   , brokerRetainedStore :: RM.RetainedStore
     -- ^ Retained message store. 'publishDownstream' hands every message
     --   to it before delivering the message to sessions.
   , brokerState         :: MVar (BrokerState auth)
     -- ^ Mutable broker state (sessions, subscriptions, identifier lookups).
   }

-- | The mutable part of a 'Broker', stored in its 'brokerState' 'MVar'.
data BrokerState auth
   = BrokerState
   { brokerMaxSessionIdentifier   :: !SessionIdentifier
     -- ^ Presumably the largest session identifier handed out so far, used
     --   to allocate fresh identifiers (allocation is not visible here —
     --   TODO confirm).
   , brokerSubscriptions          :: !(R.Trie IS.IntSet)
     -- ^ Subscription trie whose values are sets of unwrapped session
     --   identifiers. 'publishDownstream' looks a topic up here and uses the
     --   resulting identifiers as keys into 'brokerSessions'.
   , brokerSessions               :: !(IM.IntMap (Session auth))
     -- ^ Sessions known to the broker, keyed by their unwrapped
     --   'SessionIdentifier'. 'terminate' removes a session from here.
   , brokerSessionsByPrincipals   :: !(M.Map (PrincipalIdentifier, ClientIdentifier) SessionIdentifier)
     -- ^ Lookup from (principal, client identifier) to the session
     --   registered for that pair. 'terminate' removes the entry.
   }

-- | Broker-local identifier of a session. Its unwrapped 'Int' is the key
--   used in 'brokerSessions' and in the sets stored in
--   'brokerSubscriptions'.
newtype SessionIdentifier
  = SessionIdentifier Int
  deriving (Eq, Ord, Show, Enum, Generic)

-- | A client session living inside a 'Broker'.
--
--   The 'Eq', 'Ord' and 'Show' instances for sessions are written by hand
--   and are based on 'sessionIdentifier' (plus the principal and client
--   identifiers for 'Show').
data Session auth
   = Session
   { sessionBroker              :: !(Broker auth)
     -- ^ The broker this session belongs to.
   , sessionIdentifier          :: !SessionIdentifier
     -- ^ Identifier of this session; its unwrapped value is the key in
     --   'brokerSessions'.
   , sessionClientIdentifier    :: !ClientIdentifier
     -- ^ Client identifier; together with 'sessionPrincipalIdentifier' it
     --   forms the key in 'brokerSessionsByPrincipals'.
   , sessionPrincipalIdentifier :: !PrincipalIdentifier
     -- ^ Identifier of the principal that owns this session.
   , sessionCreatedAt           :: !Int64
     -- ^ Creation time of the session (presumably a timestamp — TODO
     --   confirm unit).
   , sessionConnection          :: !(MVar Connection)
     -- ^ Presumably holds information about the currently connected client,
     --   if any (not used in this excerpt — TODO confirm).
   , sessionPrincipal           :: !(MVar Principal)
     -- ^ The principal of this session; 'enqueueMessage' reads its quota.
   , sessionSemaphore           :: !PrioritySemaphore
     -- ^ Semaphore that 'terminate' acquires exclusively.
   , sessionSubscriptions       :: !(MVar (R.Trie QoS))
     -- ^ The session's subscriptions with their QoS; consulted by
     --   'publishMessage' and used by 'terminate' to clean up
     --   'brokerSubscriptions'.
   , sessionQueue               :: !(MVar ServerQueue)
     -- ^ Outgoing queue state for this session.
   , sessionQueuePending        :: !(MVar ())
     -- ^ "Something was enqueued" flag: filled by 'notePending', awaited by
     --   'waitPending'.
   , sessionStatistics          :: !SS.Statistics
     -- ^ Per-session statistics.
   }

-- | Information about a client connection attached to a session.
data Connection
   = Connection
   { connectionCreatedAt     :: !Int64
     -- ^ Time the connection was established (presumably a timestamp —
     --   TODO confirm unit).
   , connectionCleanSession  :: !Bool
     -- ^ Presumably the clean-session flag requested by the client.
   , connectionSecure        :: !Bool
     -- ^ Presumably whether the transport is encrypted (e.g. TLS) — TODO
     --   confirm.
   , connectionWebSocket     :: !Bool
     -- ^ Presumably whether the client connected via WebSocket.
   , connectionRemoteAddress :: !(Maybe BS.ByteString)
     -- ^ Remote address, if known; its textual/binary format is not
     --   visible here.
   } deriving (Eq, Ord, Show, Generic)

-- | Outgoing state of a session: packet identifier pool, per-QoS message
--   queues and the bookkeeping of in-flight QoS1/QoS2 exchanges.
data ServerQueue
   = ServerQueue
   { queuePids       :: !(Seq.Seq PacketIdentifier)
     -- ^ Pool of packet identifiers, initially filled by 'emptyServerQueue'
     --   (presumably the identifiers not currently in use — TODO confirm).
   , outputBuffer    :: !(Seq.Seq ServerPacket)
     -- ^ Packets waiting to be sent (presumably in order of transmission).
   , queueQoS0       :: !(Seq.Seq Message)
     -- ^ Pending QoS0 messages; 'enqueueMessage' bounds its length by the
     --   principal quota, discarding the oldest entry on overflow.
   , queueQoS1       :: !(Seq.Seq Message)
     -- ^ Pending QoS1 messages; bounded by the principal quota.
   , queueQoS2       :: !(Seq.Seq Message)
     -- ^ Pending QoS2 messages; bounded by the principal quota.
   , notAcknowledged :: !(IM.IntMap Message) -- ^ We sent a `QoS1` message and have not yet received the @PUBACK@.
   , notReceived     :: !(IM.IntMap Message) -- ^ We sent a `QoS2` message and have not yet received the @PUBREC@.
   , notReleased     :: !(IM.IntMap Message) -- ^ We received a `QoS2` message, sent the @PUBREC@ and wait for the @PUBREL@.
   , notComplete     :: !IS.IntSet           -- ^ We sent a @PUBREL@ and have not yet received the @PUBCOMP@.
   }

-- Binary serialisation for these types relies on the 'Generic'-based
-- default methods of "Data.Binary" (both types derive 'Generic').
instance B.Binary SessionIdentifier
instance B.Binary Connection

-- | Two sessions are equal exactly when their 'SessionIdentifier's are.
instance Eq (Session auth) where
  a == b = sessionIdentifier a == sessionIdentifier b

-- | Sessions are ordered by their 'SessionIdentifier'.
instance Ord (Session auth) where
  compare a b = sessionIdentifier a `compare` sessionIdentifier b

-- | Render a session by its identifier, principal and client identifier
--   only; the remaining (mostly mutable) fields are omitted.
instance Show (Session auth) where
  show s = concat
    [ "Session { identifier = ", show (sessionIdentifier s)
    , ", principal = ",          show (sessionPrincipalIdentifier s)
    , ", client = ",             show (sessionClientIdentifier s)
    , " }"
    ]

-- | Inject a message downstream into the broker so that it reaches every
--   session of this broker instance with a matching subscription.
--
--   The message is first handed to the retained message store. Then a
--   snapshot of the broker state is read and the message is offered to each
--   session found for the topic. A subscription entry whose session id is
--   no longer in the session map only produces a warning on stdout.
publishDownstream :: Broker auth -> Message -> IO ()
publishDownstream broker msg = do
  RM.store msg (brokerRetainedStore broker)
  snapshot <- readMVar (brokerState broker)
  let sessions    = brokerSessions snapshot
      subscribers = R.lookup (msgTopic msg) (brokerSubscriptions snapshot)
      deliver key = maybe (putStrLn "WARNING: dead session reference")
                          (`publishMessage` msg)
                          (IM.lookup key sessions)
  mapM_ deliver (IS.toList subscribers)

-- | Publish a message upstream on the broker.
--
--   * While the broker is not clustered, upstream and downstream are the
--     same thing, so this simply delegates to 'publishDownstream'.
--   * FUTURE NOTE: In clustering mode this shall distribute the message
--     to other brokers or upwards when the brokers form a hierarchy.
publishUpstream :: Broker auth -> Message -> IO ()
publishUpstream broker msg = publishDownstream broker msg

-- | Signal the session's sending thread that something has been enqueued.
--   If the pending flag is already set this does nothing (it never blocks).
notePending :: Session auth -> IO ()
notePending session = do
  _ <- tryPutMVar (sessionQueuePending session) ()
  pure ()

-- | Block until the session's pending flag is set. The flag is only read,
--   not taken, so it stays set after this returns.
waitPending :: Session auth -> IO ()
waitPending session = readMVar (sessionQueuePending session) >> pure ()

-- | Create an empty 'ServerQueue' whose packet identifier pool contains the
--   identifiers @1 .. min i 65535@.
--
--   MQTT requires the packet identifier of @PUBLISH@ (QoS > 0),
--   @SUBSCRIBE@ and @UNSUBSCRIBE@ packets to be a /non-zero/ 16-bit value,
--   so @0@ must never be handed out. Starting at @1@ also makes the pool
--   hold exactly @i@ identifiers (capped at 65535); a non-positive @i@
--   yields an empty pool.
emptyServerQueue :: Int -> ServerQueue
emptyServerQueue i = ServerQueue
 { queuePids         = Seq.fromList $ fmap PacketIdentifier [1 .. min i 65535]
 , outputBuffer      = mempty
 , queueQoS0         = mempty
 , queueQoS1         = mempty
 , queueQoS2         = mempty
 , notAcknowledged   = mempty
 , notReceived       = mempty
 , notReleased       = mempty
 , notComplete       = mempty
 }

-- | Offer a message to a single session.
--
--   The session's subscription trie is consulted for the message topic. If
--   no subscription matches, the message is dropped. Otherwise it is
--   enqueued with the lower of its own QoS and the QoS selected from the
--   matching subscriptions. MQTT requires the delivery QoS to be the minimum
--   of the published QoS and the granted QoS. A subscription must therefore
--   never upgrade a message, e.g. deliver a QoS0 publication as QoS2.
publishMessage :: Session auth -> Message -> IO ()
publishMessage session msg = do
  subscriptions <- readMVar (sessionSubscriptions session)
  case R.findMaxBounded (msgTopic msg) subscriptions of
    Nothing         -> pure ()
    Just grantedQoS ->
      enqueueMessage session msg { msgQoS = min grantedQoS (msgQoS msg) }

-- | Offer each message of a container, in fold order, to the session via
--   'publishMessage'.
publishMessages :: Foldable t => Session auth -> t Message -> IO ()
publishMessages session = mapM_ (publishMessage session)

-- | Enqueue a message for transmission to the client and wake up the
--   sending thread.
--
--   Every QoS level has its own queue whose capacity comes from the quota
--   of the session's current principal:
--
--   * QoS0: when the queue is full the oldest entry is discarded to make
--     room, so enqueueing a QoS0 message always succeeds.
--   * QoS1 / QoS2: when the queue is full the message is not enqueued and
--     the session is terminated instead.
--
--   The caller will not notice a termination, and no exception is thrown.
enqueueMessage :: Session auth -> Message -> IO ()
enqueueMessage session msg = do
  quota    <- principalQuota <$> readMVar (sessionPrincipal session)
  accepted <- modifyMVar (sessionQueue session) $ \q -> pure $
    case msgQoS msg of
      QoS0
        | hasRoom (quotaMaxQueueSizeQoS0 quota) (queueQoS0 q) ->
            (, True) $! q { queueQoS0 = queueQoS0 q Seq.|> msg }
        | otherwise ->
            -- Full: append anyway and drop the oldest entry from the front.
            (, True) $! q { queueQoS0 = Seq.drop 1 (queueQoS0 q Seq.|> msg) }
      QoS1
        | hasRoom (quotaMaxQueueSizeQoS1 quota) (queueQoS1 q) ->
            (, True) $! q { queueQoS1 = queueQoS1 q Seq.|> msg }
      QoS2
        | hasRoom (quotaMaxQueueSizeQoS2 quota) (queueQoS2 q) ->
            (, True) $! q { queueQoS2 = queueQoS2 q Seq.|> msg }
      -- The QoS1 or QoS2 queue is full: keep the queue unchanged.
      _ -> (q, False)
  -- On success notify the sending thread; otherwise kill the session.
  if accepted then notePending session else terminate session
  where
    hasRoom limit queued = fromIntegral limit > Seq.length queued

-- | Enqueue each message of a container, in fold order, via
--   'enqueueMessage'.
--
--   TODO: make more efficient (every message currently re-reads the
--   principal and takes the queue lock separately).
enqueueMessages :: Foldable t => Session auth -> t Message -> IO ()
enqueueMessages session = mapM_ (enqueueMessage session)

-- | Terminate a session.
--
--   * A client that may still be connected gets disconnected.
--   * The session subscriptions are removed from the subscription tree,
--     which means that it will receive no more messages.
--   * The session will be unlinked from the broker, which means
--     that clients cannot resume it anymore under this client identifier.
--
--   The @(principal, client) -> session@ link is only removed while it still
--   points at /this/ session. A late or repeated call, e.g. from a thread
--   that still holds a stale broker state snapshot and hits a queue
--   overflow, therefore cannot unlink a newer session that was registered
--   under the same key in the meantime.
terminate :: Session auth -> IO ()
terminate session =
  -- This assures that the client gets disconnected. The code is executed
  -- _after_ the current client handler that has terminated.
  -- TODO Race: New clients may try to connect while we are in here.
  -- This would not make the state inconsistent, but kill this thread.
  -- What we need is another `exclusivelyUninterruptible` function for
  -- the priority semaphore.
  exclusively (sessionSemaphore session) $
    modifyMVarMasked_ (brokerState $ sessionBroker session) $ \st->
      withMVarMasked (sessionSubscriptions session) $ \subscriptions->
        pure st
          { brokerSessions = IM.delete sid ( brokerSessions st )
            -- Remove the (principal, client) -> session id link, but only
            -- if it still refers to this session.
          , brokerSessionsByPrincipals = M.update
              (\linked -> if linked == sessionIdentifier session
                            then Nothing
                            else Just linked)
              ( sessionPrincipalIdentifier session
              , sessionClientIdentifier session )
              (brokerSessionsByPrincipals st)
            -- Remove the session id from each set that the session
            -- subscription tree has a corresponding value for (which is ignored).
          , brokerSubscriptions = R.differenceWith
              (\b _-> Just (IS.delete sid b) )
              ( brokerSubscriptions st ) subscriptions
          }
  where
    SessionIdentifier sid = sessionIdentifier session