-- | Internal machinery of the MQTT client: the command channel and the
-- helpers that write to it, the connection main loop, incremental parsing
-- of incoming bytes and dispatching of parsed messages.
module Network.MQTT.Internal
(
Config(..)
, Terminated(..)
, Command(..)
, Commands(..)
, mkCommands
, send
, await
, AwaitMessage(..)
, stopWaiting
, sendAwait
, writeCmd
, mainLoop
, WaitTerminate
, SendSignal
, MqttState(..)
, ParseCC
, Input(..)
, waitForInput
, parseBytes
, handleMessage
, publishHandler
, secToMicro
) where
import Control.Applicative ((<$>))
import Control.Concurrent
import qualified Control.Concurrent.Async as Async
import Control.Concurrent.STM
import Control.Exception (bracketOnError)
import Control.Monad (void, forever, filterM)
import Control.Monad.IO.Class (liftIO, MonadIO)
import Control.Monad.Loops (untilJust)
import Control.Monad.State.Strict (evalStateT, gets, modify, StateT)
import Data.Attoparsec.ByteString (IResult(..) , parse)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.Foldable (for_)
import Data.Maybe (isNothing, fromMaybe)
import Data.Singletons (SingI(..))
import Data.Singletons.Decide
import Data.Text (Text)
import Data.Word (Word16)
import Network
import System.IO (Handle, hLookAhead)
import System.Timeout (timeout)
import Network.MQTT.Types
import Network.MQTT.Parser (message)
import Network.MQTT.Encoding (writeTo)
-- | Reasons for which 'mainLoop' returns.
data Terminated
= ParseFailed [String] String
-- ^ The parser failed on the incoming bytes; carries the context list and
-- the error message reported by the parser.
| ConnectFailed ConnectError
-- ^ The handshake failed: the broker answered with a non-zero return code
-- or with something other than a ConnAck.
| UserRequested
-- ^ A 'CmdDisconnect' command was processed.
deriving Show
-- | Commands the main loop reads from the 'Commands' channel.
data Command
= CmdDisconnect
-- ^ Send a Disconnect message and leave the main loop.
| CmdSend SomeMessage
-- ^ Write the wrapped message to the connection handle.
| CmdAwait AwaitMessage
-- ^ Register a waiter for an incoming message (sent by 'await').
| CmdStopWaiting AwaitMessage
-- ^ Sent by 'stopWaiting' to withdraw a waiter registered with 'CmdAwait'.
-- | The channel through which 'Command's reach the main loop.
newtype Commands = Cmds { getCmds :: TChan Command }
-- | Create a fresh, empty command channel.
mkCommands :: IO Commands
mkCommands = fmap Cmds newTChanIO
-- | A registered waiter: the 'MVar' that should receive a message of type
-- @t@ and, optionally, the message ID the message has to carry.
data AwaitMessage where
AwaitMessage :: SingI t => MVar (Message t) -> Maybe MsgID -> AwaitMessage
-- | Two waiters are equal when they wait for the same message type, use the
-- same 'MVar' and expect the same message ID.
instance Eq AwaitMessage where
  AwaitMessage (lhsVar :: MVar (Message a)) lhsID
    == AwaitMessage (rhsVar :: MVar (Message b)) rhsID =
      case (sing :: SMsgType a) %~ (sing :: SMsgType b) of
        -- Different message types can never be the same waiter.
        Disproved _ -> False
        -- Only once the types are known to agree can the MVars be compared.
        Proved Refl -> lhsVar == rhsVar && lhsID == rhsID
-- | Settings and shared channels used by the functions in this module.
data Config
= Config
{ cHost :: HostName
-- ^ Not read in this module; presumably the broker host used when opening
-- the connection (confirm against the caller).
, cPort :: PortNumber
-- ^ Not read in this module; presumably the broker port (confirm against
-- the caller).
, cClean :: Bool
-- ^ Passed into the Connect message sent during the handshake.
, cWill :: Maybe Will
-- ^ Passed into the Connect message sent during the handshake.
, cUsername :: Maybe Text
-- ^ Optional user name, sent in the Connect message.
, cPassword :: Maybe Text
-- ^ Optional password, sent in the Connect message.
, cKeepAlive :: Maybe Word16
-- ^ Keep-alive interval in seconds. 'Nothing' means no keep-alive loop is
-- run and 0 is sent in the Connect message.
, cClientID :: Text
-- ^ Client identifier, sent in the Connect message.
, cLogDebug :: String -> IO ()
-- ^ Sink for debug log lines.
, cResendTimeout :: Int
-- ^ Seconds 'sendAwait' waits for a response before the first
-- retransmission; the wait is doubled on every further attempt.
, cPublished :: TChan (Message 'PUBLISH)
-- ^ Channel 'publishHandler' writes received PUBLISH messages to.
, cCommands :: Commands
-- ^ Command channel read by the main loop.
, cInputBufferSize :: Int
-- ^ Maximum number of bytes requested per read from the connection handle.
}
-- | Queue a message for sending by wrapping it in a 'CmdSend' command.
-- The actual write happens later, in the main loop.
send :: SingI t => Config -> Message t -> IO ()
send mqtt msg = writeCmd mqtt (CmdSend (SomeMessage msg))
-- | Register a waiter with the main loop: a matching incoming message will
-- be put into the given 'MVar'. Returns the registration so that it can
-- later be handed to 'stopWaiting'.
await :: SingI t => Config -> MVar (Message t) -> Maybe MsgID
      -> IO AwaitMessage
await mqtt var mMsgID =
    let registration = AwaitMessage var mMsgID
    in writeCmd mqtt (CmdAwait registration) >> return registration
-- | Ask the main loop to withdraw a waiter created by 'await'.
stopWaiting :: Config -> AwaitMessage -> IO ()
stopWaiting mqtt awaitMsg = writeCmd mqtt (CmdStopWaiting awaitMsg)
-- | Send a message and block until a response of type @r@ arrives.
--
-- A waiter is registered before the message is sent. If no response shows
-- up within the current timeout the message is sent again with the
-- duplicate flag set and the timeout is doubled; the first timeout is
-- 'cResendTimeout' seconds. When an exception escapes, the waiter is
-- withdrawn again via 'stopWaiting'.
sendAwait :: (SingI t, SingI r)
          => Config -> Message t -> SMsgType r -> IO (Message r)
sendAwait mqtt msg _responseS = do
    responseVar <- newEmptyMVar
    let expectedID = getMsgID (body msg)

        -- Re-read the MVar until it holds a response with the expected ID
        -- (any response is accepted when the request carries no ID).
        awaitResponse = do
          response <- readMVar responseVar
          if isNothing expectedID || expectedID == getMsgID (body response)
            then return response
            else awaitResponse

        -- Send, wait up to 'limit' microseconds, retransmit on timeout.
        attempt outgoing limit = do
          send mqtt outgoing
          outcome <- timeout limit awaitResponse
          case outcome of
            Just response -> return response
            Nothing -> do
              cLogDebug mqtt "No response within timeout, retransmitting..."
              attempt (setDup outgoing) (limit * 2)

    bracketOnError
      (await mqtt responseVar expectedID)
      (stopWaiting mqtt)
      (\_ -> attempt msg (secToMicro (cResendTimeout mqtt)))
-- | Parser continuation: takes the next chunk of input and yields the
-- parser's result for a 'SomeMessage'.
type ParseCC = ByteString -> IResult ByteString SomeMessage
-- | State threaded through the main loop.
data MqttState
= MqttState
{ msParseCC :: ParseCC
-- ^ Continuation the next chunk of input is fed to.
, msUnconsumed :: BS.ByteString
-- ^ Bytes left over after the last completely parsed message.
, msWaiting :: [AwaitMessage]
-- ^ Waiters currently registered for incoming messages.
}
-- | The events the main loop reacts to.
data Input
= InMsg SomeMessage
-- ^ A message was parsed from the connection.
| InErr Terminated
-- ^ Reading input failed; carries the reason to terminate with.
| InCmd Command
-- ^ A command was read from the command channel.
-- | STM action that 'forkMQTT' waits on alongside the forked work.
type WaitTerminate = STM ()
-- | Filled (without blocking) after every send; the keep-alive loop takes
-- it to notice that something was sent recently.
type SendSignal = MVar ()
-- | Run the client protocol on an open connection handle.
--
-- Starts the keep-alive loop in the background, performs the connect
-- handshake and then processes incoming messages and commands until a
-- terminating condition occurs. The returned 'Terminated' value says why
-- the loop stopped: a parse failure, a failed handshake or a
-- 'CmdDisconnect' command.
mainLoop :: Config -> Handle -> WaitTerminate -> SendSignal -> IO Terminated
mainLoop mqtt h waitTerminate sendSignal = do
    void $ forkMQTT waitTerminate $ keepAliveLoop mqtt sendSignal
    evalStateT
      (handshake >>= maybe (liftIO (cLogDebug mqtt "Connected") >> go) return)
      (MqttState (parse message) BS.empty [])
  where
    -- Process one input at a time until something terminates the loop.
    go :: StateT MqttState IO Terminated
    go = do
      input <- waitForInput mqtt h
      case input of
        InErr err -> return err
        InMsg someMsg -> do
          liftIO $ cLogDebug mqtt $ "Received " ++ show (toMsgType' someMsg)
          handleMessage mqtt waitTerminate someMsg
          go
        InCmd cmd -> case cmd of
          CmdDisconnect -> do
            doSend msgDisconnect
            return UserRequested
          CmdSend (SomeMessage msg) -> do
            doSend msg
            go
          CmdAwait awaitMsg -> do
            modify $ \s -> s { msWaiting = awaitMsg : msWaiting s }
            go
          CmdStopWaiting awaitMsg -> do
            -- Drop only the withdrawn waiter. Keeping the elements equal to
            -- it (as '==' would) discards every other registered waiter.
            modify $ \s -> s { msWaiting = filter (/= awaitMsg) $ msWaiting s }
            go

    -- Send Connect and read until the first complete message; anything but
    -- a ConnAck with return code 0 is a failed connection.
    handshake :: StateT MqttState IO (Maybe Terminated)
    handshake = do
        doSend msgConnect
        input <- untilJust (getSome mqtt h >>= parseBytes)
        case input of
          InErr err -> return $ Just err
          InMsg someMsg -> return $ case someMsg of
            SomeMessage (Message _ (ConnAck retCode)) ->
              if retCode /= 0
                then Just $ ConnectFailed $ toConnectError retCode
                else Nothing
            _ -> Just $ ConnectFailed InvalidResponse
          InCmd _ -> error "parseBytes returned InCmd, this should not happen."
      where
        msgConnect = Message
                       (Header False NoConfirm False)
                       (Connect
                         (cClean mqtt)
                         (cWill mqtt)
                         (MqttText $ cClientID mqtt)
                         (MqttText <$> cUsername mqtt)
                         (MqttText <$> cPassword mqtt)
                         (fromMaybe 0 $ cKeepAlive mqtt))

    msgDisconnect = Message (Header False NoConfirm False) Disconnect

    -- Log and write a message, then signal the keep-alive loop that
    -- something was sent ('tryPutMVar' so an already-set signal never blocks).
    doSend :: (MonadIO io, SingI t) => Message t -> io ()
    doSend msg = liftIO $ do
      cLogDebug mqtt $ "Sending " ++ show (toMsgType msg)
      writeTo h msg
      void $ tryPutMVar sendSignal ()
-- | Produce the next 'Input' for the main loop.
--
-- Bytes left over from the previous parse are consumed first. Otherwise the
-- function waits until either the handle has data or the command channel
-- has a command, whichever happens first, and handles that source.
waitForInput :: Config -> Handle -> StateT MqttState IO Input
waitForInput mqtt h = gets msUnconsumed >>= \leftover ->
    if BS.null leftover
      then awaitFresh
      else feed leftover
  where
    commandChan = getCmds (cCommands mqtt)

    -- Parse a chunk; an incomplete message means we need more input.
    feed chunk = parseBytes chunk >>= maybe (waitForInput mqtt h) return

    -- Both sides only peek, so the losing source does not lose any data.
    awaitFresh = do
      winner <- liftIO $ Async.race
                  (void $ hLookAhead h)
                  (void $ atomically $ peekTChan commandChan)
      case winner of
        Right () -> InCmd <$> liftIO (atomically (readTChan commandChan))
        Left () -> getSome mqtt h >>= feed
-- | Feed a chunk of bytes to the stored parser continuation.
--
-- Returns 'Nothing' when the message is still incomplete (the continuation
-- is stored for the next chunk), an 'InErr' when parsing failed, or an
-- 'InMsg' when a message was completed; in the last case the parser is
-- reset and the remaining bytes are kept in 'msUnconsumed'.
parseBytes :: Monad m => ByteString -> StateT MqttState m (Maybe Input)
parseBytes bytes = gets msParseCC >>= \continue -> step (continue bytes)
  where
    step (Done rest parsed) = do
      modify $ \st -> st { msParseCC = parse message, msUnconsumed = rest }
      return (Just (InMsg parsed))
    step (Partial next) = do
      modify $ \st -> st { msParseCC = next, msUnconsumed = BS.empty }
      return Nothing
    step (Fail _rest ctx reason) =
      return (Just (InErr (ParseFailed ctx reason)))
-- | Dispatch a received message.
--
-- PUBLISH messages are handled by 'publishHandler' in a forked thread.
-- Every other message is offered to the registered waiters; waiters that
-- received it are removed from 'msWaiting'.
handleMessage :: Config -> WaitTerminate -> SomeMessage -> StateT MqttState IO ()
handleMessage mqtt waitTerminate (SomeMessage msg) =
    case toSMsgType msg %~ SPUBLISH of
      Disproved _ -> do
        pending <- gets msWaiting
        stillPending <- liftIO (filterM keepWaiting pending)
        modify $ \st -> st { msWaiting = stillPending }
      Proved Refl ->
        liftIO . void . forkMQTT waitTerminate $ publishHandler mqtt msg
  where
    incomingID = getMsgID (body msg)

    -- 'True' keeps the waiter registered; 'False' means it got the message.
    keepWaiting :: AwaitMessage -> IO Bool
    keepWaiting (AwaitMessage (slot :: MVar (Message w)) wantedID)
      | isNothing incomingID || incomingID == wantedID =
          case toSMsgType msg %~ (sing :: SMsgType w) of
            Proved Refl -> putMVar slot msg >> return False
            Disproved _ -> return True
      | otherwise = return True
-- | Keep the connection alive while nothing else is being sent.
--
-- Does nothing when 'cKeepAlive' is 'Nothing'. Otherwise it loops forever:
-- if the send signal is not raised within the keep-alive interval, a
-- PingReq is sent and the PINGRESP response awaited.
keepAliveLoop :: Config -> SendSignal -> IO ()
keepAliveLoop mqtt signal =
    case cKeepAlive mqtt of
      Nothing -> return ()
      Just interval ->
        forever (pingWhenIdle (secToMicro (fromIntegral interval)))
  where
    pingReq = Message (Header False NoConfirm False) PingReq

    sendPing = void $ sendAwait mqtt pingReq SPINGRESP

    -- Taking the signal within the limit means something was sent recently.
    pingWhenIdle idleLimit =
      timeout idleLimit (takeMVar signal) >>= maybe sendPing (const (return ()))
-- | Handle a received PUBLISH message according to its QoS.
--
-- * 'Confirm' with a message ID: release the message, then send PubAck.
-- * 'Handshake' with a message ID: send PubRec and wait for PUBREL, then
--   release the message and send PubComp.
-- * anything else: just release the message.
--
-- Releasing means writing the message to the 'cPublished' channel.
publishHandler :: Config -> Message 'PUBLISH -> IO ()
publishHandler mqtt msg =
    case (qos (header msg), pubMsgID (body msg)) of
      (Handshake, Just msgid) -> do
        _ <- sendAwait mqtt (plain (PubRec msgid)) SPUBREL
        release
        send mqtt (plain (PubComp msgid))
      (Confirm, Just msgid) -> do
        release
        send mqtt (plain (PubAck msgid))
      _ -> release
  where
    release = writeTChanIO (cPublished mqtt) msg

    -- All replies use the same header.
    plain payload = Message (Header False NoConfirm False) payload
-- | Read whatever is available from the handle, at most
-- 'cInputBufferSize' bytes.
getSome :: MonadIO m => Config -> Handle -> m ByteString
getSome mqtt h = liftIO $ BS.hGetSome h bufferSize
  where
    bufferSize = cInputBufferSize mqtt
-- | Run an action in the background until it finishes or the terminate
-- action completes, whichever comes first.
forkMQTT :: WaitTerminate -> IO () -> IO (Async.Async ())
forkMQTT waitTerminate action = Async.async supervised
  where
    supervised = Async.withAsync action untilDone
    untilDone worker =
      atomically (waitTerminate `orElse` Async.waitSTM worker)
-- | Write a value to a 'TChan' in its own transaction.
writeTChanIO :: TChan a -> a -> IO ()
writeTChanIO chan value = atomically (writeTChan chan value)
-- | Put a command on the command channel of the given configuration.
writeCmd :: Config -> Command -> IO ()
writeCmd mqtt cmd = writeTChanIO commandChan cmd
  where
    commandChan = getCmds (cCommands mqtt)
-- | Convert seconds to microseconds.
secToMicro :: Int -> Int
secToMicro seconds = seconds * microsPerSecond
  where
    microsPerSecond = 1000000