{-# LANGUAGE OverloadedStrings #-}
module WebsocketServer (
ServerState,
acceptConnection,
processUpdates
) where
import Control.Concurrent (modifyMVar_, readMVar)
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TBQueue (readTBQueue)
import Control.Exception (SomeAsyncException, SomeException, finally, fromException, catch, throwIO)
import Control.Monad (forever)
import Data.Aeson (Value)
import Data.Text (Text)
import Data.UUID
import System.Random (randomIO)
import qualified Data.Aeson as Aeson
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Time.Clock.POSIX as Clock
import qualified Network.WebSockets as WS
import qualified Network.HTTP.Types.Header as HttpHeader
import qualified Network.HTTP.Types.URI as Uri
import Config (Config (..))
import Core (Core (..), ServerState, Updated (..), getCurrentValue, withCoreMetrics)
import Store (Path)
import AccessControl (AccessMode(..))
import JwtMiddleware (AuthResult (..), isRequestAuthorized, errorResponseBody)
import qualified Metrics
import qualified Subscription
-- | Produce a fresh random 'UUID' by delegating to 'randomIO'.
-- Used in this module to give each websocket client its own identifier.
newUUID :: IO UUID
newUUID = randomIO
-- | Push a value to connected clients by handing a guarded send action to
-- 'Subscription.broadcast'. Which clients actually receive the value is
-- decided by 'Subscription.broadcast', not here.
--
-- The send action JSON-encodes the value and writes it to the connection as
-- text data. A synchronous failure of a single send is discarded rather than
-- propagated; asynchronous exceptions are rethrown so that cancellation of
-- the calling thread is not suppressed.
broadcast :: [Text] -> Value -> ServerState -> IO ()
broadcast = Subscription.broadcast deliver
  where
    -- Encode and send one value to one connection, ignoring send failures.
    deliver :: WS.Connection -> Value -> IO ()
    deliver conn payload =
      WS.sendTextData conn (Aeson.encode payload) `catch` ignoreSyncFailure

    -- Rethrow anything wrapped in 'SomeAsyncException'; drop everything else.
    ignoreSyncFailure :: SomeException -> IO ()
    ignoreSyncFailure exc =
      case fromException exc of
        Just asyncExc -> throwIO (asyncExc :: SomeAsyncException)
        Nothing -> pure ()
-- | Entry point for an incoming websocket handshake.
--
-- The pending connection is first authorized via
-- 'authorizePendingConnection'. A rejected request is answered with a 401
-- response carrying a JSON error body. An accepted request is upgraded to a
-- websocket connection, a ping thread is started for it (interval argument
-- 30, no extra action after each ping), and the connection is handed to
-- 'handleClient' together with the decoded request path.
acceptConnection :: Core -> WS.PendingConnection -> IO ()
acceptConnection core pending =
  authorizePendingConnection core pending >>= \outcome ->
    case outcome of
      AuthAccepted -> serve
      AuthRejected err -> WS.rejectRequestWith pending (rejection err)
  where
    -- Path segments of the request URI; the query part is dropped.
    subscribedPath =
      fst (Uri.decodePath (WS.requestPath (WS.pendingRequest pending)))

    -- Complete the handshake and run the client handler under a ping thread.
    serve = do
      connection <- WS.acceptRequest pending
      WS.withPingThread connection 30 (pure ())
        (handleClient connection subscribedPath core)

    -- HTTP rejection sent instead of completing the handshake.
    rejection err = WS.RejectRequest
      { WS.rejectCode = 401
      , WS.rejectMessage = "Unauthorized"
      , WS.rejectHeaders = [(HttpHeader.hContentType, "application/json")]
      , WS.rejectBody = LBS.toStrict (errorResponseBody err)
      }
-- | Decide whether a pending websocket request may proceed.
--
-- When JWT authentication is disabled in the configuration the request is
-- always accepted. Otherwise the current POSIX time is read and the request's
-- headers, query and path are passed to 'isRequestAuthorized' together with
-- the configured JWT secret, asking for 'ModeRead' access.
authorizePendingConnection :: Core -> WS.PendingConnection -> IO AuthResult
authorizePendingConnection core conn =
  if configEnableJwtAuth config
    then checkToken <$> Clock.getPOSIXTime
    else pure AuthAccepted
  where
    config = coreConfig core
    request = WS.pendingRequest conn
    (path, query) = Uri.decodePath (WS.requestPath request)

    -- Pure authorization check, parameterised only by the current time.
    checkToken now =
      isRequestAuthorized
        (WS.requestHeaders request)
        query
        now
        (configJwtSecret config)
        path
        ModeRead
-- | Serve a single accepted websocket client for the given path.
--
-- A fresh UUID identifies the client. The client is subscribed and the
-- subscriber metric incremented, the current value at the path is sent as
-- JSON, and then the connection is read from until it fails. Whatever way
-- that sequence ends, the client is unsubscribed and the subscriber metric
-- decremented. A 'WS.ConnectionException' escaping the whole sequence
-- (including the cleanup) is swallowed; other exceptions propagate.
handleClient :: WS.Connection -> Path -> Core -> IO ()
handleClient conn path core = do
  clientId <- newUUID
  let session = register clientId >> sendSnapshot >> keepTalking conn
  -- 'catch' wraps the 'finally', so the cleanup runs before the connection
  -- error is discarded.
  (session `finally` deregister clientId) `catch` ignoreConnectionError
  where
    clients = coreClients core

    -- Add the client to the shared subscription state and count it.
    register clientId = do
      modifyMVar_ clients (pure . Subscription.subscribe path clientId conn)
      withCoreMetrics core Metrics.incrementSubscribers

    -- Remove the client from the shared subscription state and uncount it.
    deregister clientId = do
      modifyMVar_ clients (pure . Subscription.unsubscribe path clientId)
      withCoreMetrics core Metrics.decrementSubscribers

    -- Send the value currently stored at the path.
    sendSnapshot = do
      snapshot <- getCurrentValue core path
      WS.sendTextData conn (Aeson.encode snapshot)

    ignoreConnectionError :: WS.ConnectionException -> IO ()
    ignoreConnectionError _ = pure ()
-- | Receive data messages from the connection in an endless loop, discarding
-- every message. This never returns normally; the loop only ends when an
-- exception is raised, for example by 'WS.receiveDataMessage'.
keepTalking :: WS.Connection -> IO ()
keepTalking conn = WS.receiveDataMessage conn >> keepTalking conn
-- | Drain the core's update queue and forward every update to the clients.
--
-- Each iteration blocks until an item is available on 'coreUpdates'. A
-- @Just ('Updated' path value)@ item is broadcast to the clients currently
-- held in 'coreClients', after which the loop continues. A @Nothing@ item
-- ends the loop and the function returns.
processUpdates :: Core -> IO ()
processUpdates core = loop
  where
    loop =
      atomically (readTBQueue (coreUpdates core)) >>= maybe (pure ()) forward

    -- Broadcast one update against a snapshot of the client state, then
    -- go back to waiting for the next queue item.
    forward (Updated path value) = do
      clients <- readMVar (coreClients core)
      broadcast path value clients
      loop