{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}

module Examples.Raft.Socket.Node where

import Protolude

import Control.Monad.Fail
import Control.Monad.Catch
import Control.Monad.Trans.Class

import Control.Concurrent.STM.TChan
import Control.Concurrent.STM.TVar
import Control.Concurrent.STM.TMVar

import qualified Data.Map as Map
import qualified Data.Serialize as S
import qualified Network.Simple.TCP as NS
import Network.Simple.TCP (HostName, ServiceName)

import Examples.Raft.Socket.Common

import Raft.Client
import Raft.Event
import Raft.Log
import Raft.Monad
import Raft.Persistent
import Raft.RPC
import Raft.StateMachine
import Raft.Types

--------------------------------------------------------------------------------
-- Network
--------------------------------------------------------------------------------

-- | Signal delivered to a connection handler that is waiting for the outcome
-- of a client request.
data ResponseSignal sm v
  = OkResponse (ClientResponse sm v)
    -- ^ We managed to write a valid response to the @TMVar@.
  | DeadResponse
    -- ^ If we get overlapping requests coming in with the same client
    -- id, we "kill" one of them.

-- | Shared state used by the socket-backed Raft node.
data NodeSocketEnv sm v = NodeSocketEnv
  { nsMsgQueue :: TChan (RPCMessage v)
    -- ^ Queue of RPC messages to be processed by event handlers.
  , nsClientReqQueue :: TChan (ClientRequest v)
    -- ^ Queue of client request messages to be processed by event handlers.
  , nsClientReqResps :: TVar (Map ClientId (TMVar (ResponseSignal sm v)))
    -- ^ Map of variables to which responses to a request are written.
    -- N.B.: this assumes a client id uniquely identifies a request; a client
    -- will never send a request without having either 1) given up on a
    -- previous request because of a timeout, or 2) received a response to
    -- each previous request issued.
  }

-- | Monad transformer giving access to the 'NodeSocketEnv' of a node.
newtype RaftSocketT sm v m a = RaftSocketT
  { unRaftSocketT :: ReaderT (NodeSocketEnv sm v) m a
  } deriving
      ( Functor
      , Applicative
      , Monad
      , MonadIO
      , MonadFail
      , MonadReader (NodeSocketEnv sm v)
      , Alternative
      , MonadPlus
      , MonadTrans
      )

deriving instance MonadThrow m => MonadThrow (RaftSocketT sm v m)
deriving instance MonadCatch m => MonadCatch (RaftSocketT sm v m)
deriving instance MonadMask m => MonadMask (RaftSocketT sm v m)

--------------------
-- Raft Instances --
--------------------

instance
  ( RaftStateMachinePure sm v
  , MonadMask m
  , MonadCatch m
  , MonadIO m
  , S.Serialize sm
  , S.Serialize v
  ) => RaftSendClient (RaftSocketT sm v m) sm v where
  -- | Hand a response over to the connection handler registered for
  -- @clientId@. If no handler is registered the response is dropped and a
  -- message is printed.
  sendClient clientId msg = do
    NodeSocketEnv{..} <- ask
    mRespVar <-
      liftIO . atomically $ Map.lookup clientId <$> readTVar nsClientReqResps
    -- We write the response to the TMVar corresponding to the client
    -- id, such that 'acceptConnections' can send it back to the
    -- client.
    case mRespVar of
      Nothing -> liftIO $ putText "sendClient: response lookup failed"
      Just respVar ->
        liftIO . atomically . putTMVar respVar . OkResponse $ msg

instance (MonadIO m, S.Serialize v) => RaftRecvClient (RaftSocketT sm v m) v where
  type RaftRecvClientError (RaftSocketT sm v m) v = Text

  -- | Block until a client request is available on 'nsClientReqQueue'.
  -- Always yields a 'Right'.
  receiveClient = do
    cReq <- asks nsClientReqQueue
    fmap Right . liftIO . atomically $ readTChan cReq

instance
  ( MonadCatch m
  , MonadMask m
  , MonadIO m
  , S.Serialize v
  , Show v
  ) => RaftSendRPC (RaftSocketT sm v m) v where
  -- | Open a connection to the node identified by @nid@ and send the
  -- serialized RPC message. Any exception raised while connecting or sending
  -- is caught and printed rather than propagated.
  sendRPC nid msg = do
      eRes <- Control.Monad.Catch.try $
        NS.connect host port $ \(sock, _) ->
          NS.send sock (S.encode $ RPCMessageEvent msg)
      case eRes of
        Left (err :: SomeException) ->
          putText ("Failed to send RPC: " <> show err)
        Right _ -> pure ()
    where
      (host, port) = nidToHostPort nid

instance (MonadIO m, Show v) => RaftRecvRPC (RaftSocketT sm v m) v where
  type RaftRecvRPCError (RaftSocketT sm v m) v = Text

  -- | Block until an RPC message is available on 'nsMsgQueue'.
  -- Always yields a 'Right'.
  receiveRPC = do
    msgQueue <- asks nsMsgQueue
    fmap Right . liftIO . atomically $ readTChan msgQueue

-- | Run a 'RaftSocketT' computation against the given environment.
runRaftSocketT :: MonadIO m => NodeSocketEnv sm v -> RaftSocketT sm v m a -> m a
runRaftSocketT nodeSocketEnv = flip runReaderT nodeSocketEnv . unRaftSocketT

-- | Serve incoming connections on @host@ / @port@.
--
-- For each connection one message is read from the socket:
--
-- * an RPC message is pushed onto 'nsMsgQueue';
-- * a client request is pushed onto 'nsClientReqQueue', after which the
--   handler waits for 'sendClient' to provide the response and writes it back
--   on the same socket.
acceptConnections
  :: forall sm v m.
     ( S.Serialize sm
     , S.Serialize v
     , S.Serialize (RaftStateMachinePureError sm v)
     , Show (RaftStateMachinePureError sm v)
     , MonadIO m
     )
  => HostName
  -> ServiceName
  -> RaftSocketT sm v m ()
acceptConnections host port = do
    NodeSocketEnv{..} <- ask
    NS.serve (NS.Host host) port $ \(sock, _) -> do
      mVal <- recvSerialized sock
      case mVal of
        Nothing -> putText "Socket was closed on the other end"
        Just (ClientRequestEvent req@(ClientRequest clientId _)) -> do
          let register =
                atomically $ registerRespVar nsClientReqResps clientId
              unregister =
                atomically . unregisterRespVar nsClientReqResps clientId
          -- 'bracket' guarantees the response variable is unregistered even
          -- if sending the response fails (e.g. the client went away).
          Control.Monad.Catch.bracket register unregister $ \respVar -> do
            -- Register request to be handled by event handler
            atomically $ writeTChan nsClientReqQueue req
            -- Wait until response has been written to the TMVar by a
            -- 'sendClient' call and send the response to the client.
            respMsg <- atomically $ takeTMVar respVar
            case respMsg of
              OkResponse okResp -> NS.send sock (S.encode okResp)
              DeadResponse -> pure () -- ignored for now
        Just (RPCMessageEvent msg) ->
          atomically $ writeTChan nsMsgQueue msg
  where
    -- Create and register the TMVar where the response should be written to
    -- for this client id in the 'sendClient' impl.
    registerRespVar respVars clientId = do
      newRespVar <- newEmptyTMVar
      clientReqResps <- readTVar respVars
      -- If there's an outstanding request for this client id, we send a
      -- signal that the request is "dead". Given sufficiently unique client
      -- ids, this case should never occur.
      case Map.lookup clientId clientReqResps of
        Nothing -> pure ()
        Just staleVar -> putTMVar staleVar DeadResponse
      writeTVar respVars (Map.insert clientId newRespVar clientReqResps)
      pure newRespVar

    -- Remove the response variable for the client id, but only if the map
    -- still points at /our/ variable. A superseded handler must not remove
    -- the entry registered by the request that replaced it.
    unregisterRespVar respVars clientId respVar =
      modifyTVar' respVars (Map.update dropIfOwned clientId)
      where
        dropIfOwned registered
          | registered == respVar = Nothing
          | otherwise = Just registered

--------------------------------------------------------------------------------
-- Inherited Instances
--------------------------------------------------------------------------------

-- The instances below delegate to the underlying monad @m@ via 'lift'.

instance (MonadIO m, RaftPersist m) => RaftPersist (RaftSocketT sm v m) where
  type RaftPersistError (RaftSocketT sm v m) = RaftPersistError m
  initializePersistentState = lift initializePersistentState
  writePersistentState ps = lift $ writePersistentState ps
  readPersistentState = lift readPersistentState

instance (MonadIO m, RaftInitLog m v) => RaftInitLog (RaftSocketT sm v m) v where
  type RaftInitLogError (RaftSocketT sm v m) = RaftInitLogError m
  initializeLog p = lift $ initializeLog p

instance RaftWriteLog m v => RaftWriteLog (RaftSocketT sm v m) v where
  type RaftWriteLogError (RaftSocketT sm v m) = RaftWriteLogError m
  writeLogEntries entries = lift $ writeLogEntries entries

instance RaftReadLog m v => RaftReadLog (RaftSocketT sm v m) v where
  type RaftReadLogError (RaftSocketT sm v m) = RaftReadLogError m
  readLogEntry idx = lift $ readLogEntry idx
  readLastLogEntry = lift readLastLogEntry

instance RaftDeleteLog m v => RaftDeleteLog (RaftSocketT sm v m) v where
  type RaftDeleteLogError (RaftSocketT sm v m) = RaftDeleteLogError m
  deleteLogEntriesFrom idx = lift $ deleteLogEntriesFrom idx

instance RaftStateMachine m sm v => RaftStateMachine (RaftSocketT sm v m) sm v where
  validateCmd = lift . validateCmd
  askRaftStateMachinePureCtx = lift askRaftStateMachinePureCtx

instance MonadRaftChan v m => MonadRaftChan v (RaftSocketT sm v m) where
  type RaftEventChan v (RaftSocketT sm v m) = RaftEventChan v m
  readRaftChan = lift . readRaftChan
  writeRaftChan chan = lift . writeRaftChan chan
  newRaftChan = lift (newRaftChan @v @m)

instance (MonadIO m, MonadRaftFork m) => MonadRaftFork (RaftSocketT sm v m) where
  type RaftThreadId (RaftSocketT sm v m) = RaftThreadId m
  -- | Fork in the underlying monad, running the forked computation against
  -- the same 'NodeSocketEnv' as the caller.
  raftFork r m = do
    nodeSocketEnv <- ask
    lift $ raftFork r (runRaftSocketT nodeSocketEnv m)