module Network.Mattermost.WebSocket
( MMWebSocket
, MMWebSocketTimeoutException
, mmWithWebSocket
, mmCloseWebSocket
, mmGetConnectionHealth
, module Network.Mattermost.WebSocket.Types
) where
import Control.Concurrent (ThreadId, forkIO, killThread, myThreadId, threadDelay)
import Control.Exception (Exception, SomeException, catch, finally, throwIO, throwTo, try)
import Control.Monad (forever)
import Data.IORef
import Data.Monoid ((<>))
import Data.Typeable ( Typeable )

import qualified Control.Concurrent.STM.TQueue as Queue
import Control.Monad.STM (atomically)
import Data.Aeson (toJSON)
import qualified Data.ByteString.Char8 as B
import Data.ByteString.Lazy (toStrict)
import qualified Data.Text as T
import Data.Time.Clock (NominalDiffTime, UTCTime, diffUTCTime, getCurrentTime)
import Network.Connection ( Connection
                          , connectionClose
                          , connectionGet
                          , connectionPut
                          )
import qualified Network.WebSockets as WS
import Network.WebSockets.Stream (Stream, makeStream)

import Network.Mattermost.Util
import Network.Mattermost.Types.Base
import Network.Mattermost.Types.Internal
import Network.Mattermost.Types
import Network.Mattermost.WebSocket.Types
-- | Adapt a 'Connection' to the 'Stream' interface used by the
-- websockets library.
--
-- Reads fetch up to 1024 bytes at a time; an empty read is reported as
-- 'Nothing'. Writing 'Nothing' closes the underlying connection, while
-- writing @Just bytes@ sends the (strictified) payload.
connectionToStream :: Connection -> IO Stream
connectionToStream con = makeStream readChunk writeChunk
  where
    readChunk = do
      chunk <- connectionGet con 1024
      if B.null chunk
        then return Nothing
        else return (Just chunk)

    writeChunk = maybe (connectionClose con) (connectionPut con . toStrict)
-- | Handle to an open websocket: the underlying websocket connection
-- paired with a mutable cell holding the most recent connection-health
-- measurement (read via 'mmGetConnectionHealth').
data MMWebSocket = MMWS WS.Connection (IORef NominalDiffTime)
-- | Exception raised by the ping/pong watchdog when no pong has been
-- recorded by the time the watchdog's delay after a ping has elapsed.
data MMWebSocketTimeoutException = MMWebSocketTimeoutException
  deriving (Show, Typeable)

-- All 'Exception' methods use their defaults.
instance Exception MMWebSocketTimeoutException
-- | A timestamped ping or pong event, as queued by the actions that
-- 'createPingPongTimeouts' returns.
data PEvent = P UTCTime
-- | Set up ping/pong bookkeeping and fork a watchdog thread.
--
-- Arguments, in order: the thread to notify on timeout, the cell that
-- receives the latest ping-to-pong latency, the timeout in seconds, and
-- a logging callback.
--
-- Returns an action to run whenever a ping is sent, an action to run
-- whenever a pong is received, and the 'ThreadId' of the watchdog.
--
-- For each recorded ping the watchdog sleeps for the timeout and then
-- checks whether any pong has been recorded. If none has, it throws
-- 'MMWebSocketTimeoutException' to the given thread and stops. Otherwise
-- it stores the difference between the pong and ping timestamps in the
-- health cell and waits for the next ping.
createPingPongTimeouts :: ThreadId
                       -> IORef NominalDiffTime
                       -> Int
                       -> (LogEventType -> IO ())
                       -> IO (IO (), IO (), ThreadId)
createPingPongTimeouts pId health n doLog = do
  pings <- Queue.newTQueueIO
  pongs <- Queue.newTQueueIO
  let -- Take a timestamp, log the event, then enqueue the timestamp.
      stamp event queue = do
        now <- getCurrentTime
        doLog event
        atomically $ Queue.writeTQueue queue (P now)

      timeoutMicros = n * 1000 * 1000

      watchdog = do
        P sentAt <- atomically $ Queue.readTQueue pings
        threadDelay timeoutMicros
        noPong <- atomically $ Queue.isEmptyTQueue pongs
        if noPong
          then throwTo pId MMWebSocketTimeoutException
          else do
            P receivedAt <- atomically $ Queue.readTQueue pongs
            atomicWriteIORef health (receivedAt `diffUTCTime` sentAt)
            watchdog
  watchdogId <- forkIO watchdog
  return (stamp WebSocketPing pings, stamp WebSocketPong pongs, watchdogId)
-- | Send a websocket close message with an empty payload on the
-- connection held by the given handle.
mmCloseWebSocket :: MMWebSocket -> IO ()
mmCloseWebSocket (MMWS conn _health) = WS.sendClose conn B.empty
-- | Read the current value of the health cell stored in the handle.
mmGetConnectionHealth :: MMWebSocket -> IO NominalDiffTime
mmGetConnectionHealth (MMWS _conn healthRef) = readIORef healthRef
-- | Endless ping loop: sleep ten seconds, run the supplied ping hook,
-- then send a websocket ping whose payload is a running counter
-- (starting at 0) rendered as text.
pingThread :: IO () -> WS.Connection -> IO ()
pingThread onPingAction conn = go 0
  where
    pingIntervalMicros :: Int
    pingIntervalMicros = 10 * 1000 * 1000

    go :: Int -> IO ()
    go counter =
      threadDelay pingIntervalMicros
        >> onPingAction
        >> WS.sendPing conn (B.pack (show counter))
        >> go (counter + 1)
-- | Open a websocket for the given session and run @body@ with a handle
-- to it.
--
-- While @body@ runs, two helper threads are active: one sends periodic
-- pings (see 'pingThread') and one receives messages, logs them, and
-- hands each one to @recv@ ('Left' carries a parse error message,
-- 'Right' a decoded event). A watchdog created by
-- 'createPingPongTimeouts' throws 'MMWebSocketTimeoutException' to the
-- calling thread when a ping is not answered within 8 seconds.
--
-- All helper threads are stopped when this function exits, whether
-- @body@ returns normally or throws, so no thread outlives the call.
-- In particular the watchdog cannot deliver a timeout exception to the
-- caller after this function has returned. An exception from @body@ is
-- re-raised to the caller after cleanup.
mmWithWebSocket :: Session
                -> (Either String WebsocketEvent -> IO ())
                -> (MMWebSocket -> IO ())
                -> IO ()
mmWithWebSocket (Session cd (Token tk)) recv body = do
  con <- mkConnection cd
  stream <- connectionToStream con
  health <- newIORef 0
  myId <- myThreadId
  let doLog = runLogger cd "websocket"
  (onPing, onPong, watchdogId) <- createPingPongTimeouts myId health 8 doLog
  let action c = do
        pId <- forkIO (pingThread onPing c `catch` cleanup)
        mId <- forkIO $ flip catch cleanup $ forever $ do
          result <- try $ do
            v <- WS.receiveData c
            -- Force the value here so that a parse failure surfaces
            -- inside 'try' rather than later in the callback.
            v `seq` return v
          val <- case result of
            Left (WS.ParseException e) -> return $ Left e
            Left e -> throwIO e
            Right ws -> return $ Right ws
          doLog (WebSocketResponse $ case val of
                    Left s -> Left s
                    Right v -> Right $ toJSON v)
          recv val
        -- Stop both helper threads on every exit path, not only when
        -- the body throws.
        body (MMWS c health) `finally` mapM_ killThread [mId, pId]
  WS.runClientWithStream stream
                         (T.unpack $ cdHostname cd)
                         "/api/v3/users/websocket"
                         WS.defaultConnectionOptions { WS.connectionOnPong = onPong }
                         [ ("Authorization", "Bearer " <> B.pack tk) ]
                         action
    -- Also covers a failed handshake, where 'action' never runs.
    `finally` killThread watchdogId
  where
    -- Helper threads end silently on any exception, including the
    -- 'killThread' issued during cleanup.
    cleanup :: SomeException -> IO ()
    cleanup _ = return ()