{-# LANGUAGE OverloadedStrings #-}
module WebsocketServer (
ServerState,
acceptConnection,
processUpdates
) where
import Control.Concurrent (modifyMVar_, readMVar)
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TBQueue (readTBQueue)
import Control.Exception (SomeAsyncException, SomeException, catch, finally, fromException, mask, throwIO)
import Control.Monad (forever)
import Data.Aeson (Value)
import Data.Text (Text)
import Data.UUID
import System.Random (randomIO)

import qualified Data.Aeson as Aeson
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Time.Clock.POSIX as Clock
import qualified Network.WebSockets as WS
import qualified Network.HTTP.Types.Header as HttpHeader
import qualified Network.HTTP.Types.URI as Uri

import Config (Config (..))
import Core (Core (..), ServerState, Updated (..), getCurrentValue, withCoreMetrics)
import Store (Path)
import AccessControl (AccessMode(..))
import JwtMiddleware (AuthResult (..), isRequestAuthorized, errorResponseBody)

import qualified Metrics
import qualified Subscription
-- | Draw a fresh random 'UUID'. 'handleClient' uses it as the identifier
-- under which a single client connection is (un)subscribed.
newUUID :: IO UUID
newUUID = randomIO
-- | Push a JSON-encoded value to the connections that
-- 'Subscription.broadcast' selects for the given path.
--
-- Delivery is best-effort per connection. A synchronous failure while
-- sending to one client is discarded so that the remaining clients still
-- receive the update. Asynchronous exceptions (e.g. @ThreadKilled@) are
-- re-thrown, because they carry inter-thread control flow and must not be
-- swallowed here.
broadcast :: [Text] -> Value -> ServerState -> IO ()
broadcast = Subscription.broadcast sendTo
  where
    sendTo :: WS.Connection -> Value -> IO ()
    sendTo connection payload =
      WS.sendTextData connection (Aeson.encode payload) `catch` ignoreSyncFailure

    -- Re-raise only if the caught exception is asynchronous; otherwise drop it.
    ignoreSyncFailure :: SomeException -> IO ()
    ignoreSyncFailure failure = maybe (pure ()) rethrowAsync (fromException failure)

    rethrowAsync :: SomeAsyncException -> IO ()
    rethrowAsync = throwIO
-- | Handle a newly arrived websocket handshake.
--
-- The request is first checked with 'authorizePendingConnection'.
--
-- * A rejected request is answered with HTTP 401, a JSON content type and
--   the body produced by 'errorResponseBody'.
-- * An accepted request is upgraded to a websocket connection and served by
--   'handleClient' for the decoded request path. A ping thread runs alongside
--   it every 30 seconds.
acceptConnection :: Core -> WS.PendingConnection -> IO ()
acceptConnection core pending = do
  authResult <- authorizePendingConnection core pending
  case authResult of
    AuthRejected err -> rejectWith err
    AuthAccepted     -> serve
  where
    -- Path segments of the request URL; the decoded query part is ignored.
    requestedPath :: [Text]
    requestedPath = fst (Uri.decodePath (WS.requestPath (WS.pendingRequest pending)))

    rejectWith err =
      WS.rejectRequestWith pending
        (WS.RejectRequest
          { WS.rejectCode    = 401
          , WS.rejectMessage = "Unauthorized"
          , WS.rejectHeaders = [(HttpHeader.hContentType, "application/json")]
          , WS.rejectBody    = LBS.toStrict (errorResponseBody err)
          })

    serve = do
      connection <- WS.acceptRequest pending
      -- The ping thread keeps idle connections open; its ping action is a no-op.
      WS.withPingThread connection 30 (pure ()) (handleClient connection requestedPath core)
-- | Decide whether a pending websocket request may be accepted.
--
-- If 'configEnableJwtAuth' is off, every request is accepted. Otherwise the
-- decision is made by 'isRequestAuthorized'. It receives the request
-- headers, the decoded query, the current POSIX time, 'configJwtSecret', the
-- decoded request path, and 'ModeRead' as the required access mode.
authorizePendingConnection :: Core -> WS.PendingConnection -> IO AuthResult
authorizePendingConnection core conn =
  if configEnableJwtAuth config
    then verifyAt <$> Clock.getPOSIXTime
    else pure AuthAccepted
  where
    config = coreConfig core
    request = WS.pendingRequest conn
    (requestedPath, queryParams) = Uri.decodePath (WS.requestPath request)

    verifyAt now =
      isRequestAuthorized
        (WS.requestHeaders request)
        queryParams
        now
        (configJwtSecret config)
        requestedPath
        ModeRead
-- | Serve one accepted websocket client for the lifetime of its connection.
--
-- Steps:
--
-- * Subscribe the client under @path@ with a fresh 'UUID' and increment the
--   subscriber gauge.
-- * Send the current value at @path@ as JSON.
-- * Block in 'keepTalking' until the connection ends.
--
-- On exit, whether normal or via an exception, the client is unsubscribed
-- and the gauge is decremented. This cleanup happens only if registration
-- completed. A 'WS.ConnectionException' raised by the session is swallowed,
-- which is how the 'keepTalking' loop is expected to end when the client
-- goes away. Other exceptions propagate after cleanup.
handleClient :: WS.Connection -> Path -> Core -> IO ()
handleClient conn path core = do
  uuid <- newUUID
  let state = coreClients core

      -- Register the client and count it as a subscriber.
      onConnect :: IO ()
      onConnect = do
        modifyMVar_ state (pure . Subscription.subscribe path uuid conn)
        withCoreMetrics core Metrics.incrementSubscribers

      -- Undo exactly what 'onConnect' did.
      onDisconnect :: IO ()
      onDisconnect = do
        modifyMVar_ state (pure . Subscription.unsubscribe path uuid)
        withCoreMetrics core Metrics.decrementSubscribers

      sendInitialValue :: IO ()
      sendInitialValue = do
        currentValue <- getCurrentValue core path
        WS.sendTextData conn (Aeson.encode currentValue)

      -- The connection is closed; there is no client state left to recover.
      handleConnectionError :: WS.ConnectionException -> IO ()
      handleConnectionError _ = pure ()

  -- 'bracket_' pattern. 'onConnect' runs with asynchronous exceptions masked,
  -- so it cannot be torn half-way. 'onDisconnect' is armed only after
  -- 'onConnect' has finished. Wrapping all three steps in 'finally' would
  -- run 'onDisconnect' even when 'onConnect' never completed, decrementing
  -- the subscriber gauge without a matching increment.
  mask (\restore -> do
      onConnect
      restore (sendInitialValue >> keepTalking conn) `finally` onDisconnect)
    `catch` handleConnectionError
-- | Keep a client's connection alive by reading from it forever.
--
-- Nothing is sent from here; updates reach clients through 'broadcast'.
-- Reading is still required, because 'WS.receiveDataMessage' answers control
-- messages (e.g. the close handshake) while waiting for data. Any data
-- messages received are discarded. The loop never returns normally; it ends
-- only when an exception is raised, such as the 'WS.ConnectionException'
-- that 'handleClient' expects when the client disconnects.
keepTalking :: WS.Connection -> IO ()
keepTalking conn = forever (WS.receiveDataMessage conn)
-- | Drain the core's update queue, broadcasting every 'Updated' value to the
-- clients subscribed for its path. Stops as soon as a 'Nothing' is dequeued.
--
-- The client tree is re-read for each update, so connections registered
-- since the previous update are included.
processUpdates :: Core -> IO ()
processUpdates core = loop
  where
    loop :: IO ()
    loop = atomically (readTBQueue (coreUpdates core)) >>= maybe (pure ()) publish

    publish :: Updated -> IO ()
    publish (Updated path value) = do
      clients <- readMVar (coreClients core)
      broadcast path value clients
      loop