{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module Network.GRPC.Server.Handlers where
import Control.Concurrent.Async (concurrently)
import Control.Monad (void)
import Data.Binary.Get (pushChunk, Decoder(..))
import qualified Data.ByteString.Char8 as ByteString
import Data.ByteString.Char8 (ByteString)
import Data.ByteString.Lazy (toStrict)
import Data.ProtoLens.Message (Message)
import Data.ProtoLens.Service.Types (Service(..), HasMethod, HasMethodImpl(..), StreamingType(..))
import Network.GRPC.HTTP2.Encoding (decodeInput, encodeOutput, Encoding(..), Decoding(..))
import Network.GRPC.HTTP2.Types (RPC(..), GRPCStatus(..), GRPCStatusCode(..), path)
import Network.Wai (Request, requestBody, strictRequestBody)
import Network.GRPC.Server.Wai (WaiHandler, ServiceHandler(..), closeEarly)
-- | Handler for a unary RPC: receives the WAI 'Request' and the decoded input
-- message, and returns the single output message to send back.
type UnaryHandler s m = Request -> MethodInput s m -> IO (MethodOutput s m)
-- | Handler for a server-streaming RPC: given the request and the single
-- decoded input message, returns an initial state and a 'ServerStream' that
-- generates the output messages.
type ServerStreamHandler s m a = Request -> MethodInput s m -> IO (a, ServerStream s m a)
-- | Generator for the messages of a server stream.
newtype ServerStream s m a = ServerStream {
    serverStreamNext :: a -> IO (Maybe (a, MethodOutput s m))
    -- ^ Given the current state, return 'Just' the next state together with
    -- a message to send, or 'Nothing' to end the stream.
  }
-- | Handler for a client-streaming RPC: given the request, returns an initial
-- state and a 'ClientStream' that folds over the incoming messages.
type ClientStreamHandler s m a = Request -> IO (a, ClientStream s m a)
-- | Callbacks for consuming a client stream.
data ClientStream s m a = ClientStream {
    clientStreamHandler :: a -> MethodInput s m -> IO a
    -- ^ Called for each decoded input message; returns the updated state.
  , clientStreamFinalizer :: a -> IO (MethodOutput s m)
    -- ^ Called with the final state once the request body ends; its result
    -- is sent as the single reply.
  }
-- | Handler for a bidirectional-streaming RPC: given the request, returns an
-- initial state and a 'BiDiStream' state machine.
type BiDiStreamHandler s m a = Request -> IO (a, BiDiStream s m a)
-- | Next action of a bidirectional stream, as chosen by 'bidirNextStep'.
data BiDiStep s m a
  = Abort
  -- ^ Stop processing the stream.
  | WaitInput !(a -> MethodInput s m -> IO a) !(a -> IO a)
  -- ^ Wait for input. The first function handles the next decoded message;
  -- the second is called instead if the request body ends. Both return the
  -- new state.
  | WriteOutput !a (MethodOutput s m)
  -- ^ Send one output message, then continue with the given state.
-- | State machine driving a bidirectional stream.
data BiDiStream s m a = BiDiStream {
    bidirNextStep :: a -> IO (BiDiStep s m a)
    -- ^ Decide the next step from the current state.
  }
-- | Register a unary RPC: the returned 'ServiceHandler' is keyed on the
-- RPC's path and dispatches to 'handleUnary'.
unary ::
  (Service s, HasMethod s m)
  => RPC s m
  -> UnaryHandler s m
  -> ServiceHandler
unary rpc handler = ServiceHandler (path rpc) $ handleUnary rpc handler
-- | Register a server-streaming RPC: the returned 'ServiceHandler' is keyed on
-- the RPC's path and dispatches to 'handleServerStream'.
serverStream ::
  (Service s, HasMethod s m, MethodStreamingType s m ~ 'ServerStreaming)
  => RPC s m
  -> ServerStreamHandler s m a
  -> ServiceHandler
serverStream rpc handler = ServiceHandler (path rpc) $ handleServerStream rpc handler
-- | Register a client-streaming RPC: the returned 'ServiceHandler' is keyed on
-- the RPC's path and dispatches to 'handleClientStream'.
clientStream ::
  (Service s, HasMethod s m, MethodStreamingType s m ~ 'ClientStreaming)
  => RPC s m
  -> ClientStreamHandler s m a
  -> ServiceHandler
clientStream rpc handler = ServiceHandler (path rpc) $ handleClientStream rpc handler
-- | Register a bidirectional-streaming RPC: the returned 'ServiceHandler' is
-- keyed on the RPC's path and dispatches to 'handleBiDiStream'.
bidiStream ::
  (Service s, HasMethod s m, MethodStreamingType s m ~ 'BiDiStreaming)
  => RPC s m
  -> BiDiStreamHandler s m a
  -> ServiceHandler
bidiStream rpc handler = ServiceHandler (path rpc) $ handleBiDiStream rpc handler
-- | Register a general stream, with input and output handled concurrently.
-- The returned 'ServiceHandler' is keyed on the RPC's path and dispatches to
-- 'handleGeneralStream'. Unlike the other constructors, no streaming-type
-- constraint is imposed on the method.
generalStream ::
  (Service s, HasMethod s m)
  => RPC s m
  -> GeneralStreamHandler s m a b
  -> ServiceHandler
generalStream rpc handler = ServiceHandler (path rpc) $ handleGeneralStream rpc handler
-- | Serve a unary RPC.
--
-- Reads the whole request body, decodes exactly one input message, calls the
-- handler and writes its output followed by a flush.
--
-- Fails with @INVALID_ARGUMENT@ in three cases:
--
-- * the body ends before a message is decoded;
-- * bytes remain after the message (see 'errorOnLeftOver');
-- * the message fails to decode.
handleUnary ::
  (Service s, HasMethod s m)
  => RPC s m
  -> UnaryHandler s m
  -> WaiHandler
handleUnary rpc handler decoding encoding req write flush =
  handleRequestChunksLoop initialDecoder onRequest onEndOfBody readWholeBody
  where
    initialDecoder = decodeInput rpc (_getDecodingCompression decoding)
    -- The full body is read at once; a later read yields an empty chunk.
    readWholeBody = toStrict <$> strictRequestBody req
    onEndOfBody = closeEarly (GRPCStatus INVALID_ARGUMENT "early end of request body")
    onRequest = errorOnLeftOver $ \input -> do
      output <- handler req input
      write (encodeOutput rpc (_getEncodingCompression encoding) output)
      flush
-- | Serve a server-streaming RPC.
--
-- Reads the whole request body and decodes exactly one input message. It then
-- repeatedly calls 'serverStreamNext', writing and flushing each produced
-- message until 'Nothing' is returned.
--
-- Fails with @INVALID_ARGUMENT@ in three cases:
--
-- * the body ends early;
-- * bytes follow the message;
-- * decoding fails.
handleServerStream ::
  (Service s, HasMethod s m)
  => RPC s m
  -> ServerStreamHandler s m a
  -> WaiHandler
handleServerStream rpc handler decoding encoding req write flush =
  handleRequestChunksLoop initialDecoder onRequest onEndOfBody readWholeBody
  where
    initialDecoder = decodeInput rpc (_getDecodingCompression decoding)
    readWholeBody = toStrict <$> strictRequestBody req
    onEndOfBody = closeEarly (GRPCStatus INVALID_ARGUMENT "early end of request body")
    onRequest = errorOnLeftOver $ \input -> do
      (seed, stream) <- handler req input
      streamFrom stream seed
    -- Emit messages until the generator signals the end with 'Nothing'.
    streamFrom stream st =
      serverStreamNext stream st
        >>= maybe (pure ()) (\(st', msg) -> send msg >> streamFrom stream st')
    send msg = write (encodeOutput rpc (_getEncodingCompression encoding) msg) >> flush
-- | Serve a client-streaming RPC.
--
-- Reads the request body chunk by chunk and folds every decoded message into
-- the handler's state with 'clientStreamHandler'. When the body ends, it
-- writes (and flushes) the single reply produced by 'clientStreamFinalizer'.
-- Decoding errors abort the call with @INVALID_ARGUMENT@.
handleClientStream ::
  (Service s, HasMethod s m)
  => RPC s m
  -> ClientStreamHandler s m a
  -> WaiHandler
handleClientStream rpc handler0 decoding encoding req write flush =
  handler0 req >>= \(seed, stream) -> consume stream freshDecoder seed
  where
    freshDecoder = decodeInput rpc (_getDecodingCompression decoding)
    consume stream decoder st =
      handleRequestChunksLoop decoder (onMessage stream st) (onEndOfBody stream st) (requestBody req)
    -- Bytes following the decoded message seed the decoder for the next one.
    onMessage stream st leftover msg =
      clientStreamHandler stream st msg >>= consume stream (pushChunk freshDecoder leftover)
    onEndOfBody stream st = do
      out <- clientStreamFinalizer stream st
      write (encodeOutput rpc (_getEncodingCompression encoding) out)
      flush
-- | Serve a bidirectional-streaming RPC by running the handler's state machine.
--
-- Each 'bidirNextStep' result decides what happens next:
--
-- * 'WaitInput' decodes the next message, reading more of the body if
--   needed, or calls the end-of-body callback.
-- * 'WriteOutput' writes and flushes one message.
-- * 'Abort' stops.
--
-- Request bytes that were read but not yet decoded are carried from step to
-- step, so messages that arrive together in one body chunk are all seen.
-- Decoding errors abort the call with @INVALID_ARGUMENT@.
handleBiDiStream ::
  (Service s, HasMethod s m)
  => RPC s m
  -> BiDiStreamHandler s m a
  -> WaiHandler
handleBiDiStream rpc handler0 decoding encoding req write flush = do
  handler0 req >>= go ""
  where
    nextChunk = requestBody req
    reply msg = write (encodeOutput rpc (_getEncodingCompression encoding) msg) >> flush
    -- @chunk@ holds request bytes already read from the body but not yet decoded.
    go chunk (v0, bStream) = do
      let cont dat v1 = go dat (v1, bStream)
      step <- bidirNextStep bStream v0
      case step of
        WaitInput handleMsg handleEof ->
          handleRequestChunksLoop
            (flip pushChunk chunk $ decodeInput rpc $ _getDecodingCompression decoding)
            (\dat msg -> handleMsg v0 msg >>= cont dat)
            -- The body is exhausted, so nothing is left to carry over.
            (handleEof v0 >>= cont "")
            nextChunk
        WriteOutput v1 msg -> do
          reply msg
          -- Keep the undecoded bytes: they may already contain (part of) the
          -- next client message. The previous code passed "" here and
          -- silently dropped them.
          cont chunk v1
        Abort -> return ()
-- | Handler for a general stream, whose input and output sides are run
-- concurrently. Given the request, returns the initial input state, the
-- input callbacks, the initial output state and the output generator.
type GeneralStreamHandler s m a b =
  Request -> IO (a, IncomingStream s m a, b, OutgoingStream s m b)
-- | Callbacks for consuming the incoming side of a general stream.
data IncomingStream s m a = IncomingStream {
    incomingStreamHandler :: a -> MethodInput s m -> IO a
    -- ^ Called for each decoded input message; returns the updated state.
  , incomingStreamFinalizer :: a -> IO ()
    -- ^ Called with the final state when the request body ends.
  }
-- | Generator for the outgoing side of a general stream.
data OutgoingStream s m a = OutgoingStream {
    outgoingStreamNext :: a -> IO (Maybe (a, MethodOutput s m))
    -- ^ Return 'Just' the next state together with a message to send, or
    -- 'Nothing' to stop sending.
  }
-- | Serve a general stream, running the incoming and outgoing sides with
-- 'concurrently'.
--
-- * Incoming side: decodes messages from the request body and feeds them to
--   'incomingStreamHandler'. It calls 'incomingStreamFinalizer' when the body
--   ends.
-- * Outgoing side: writes and flushes each message from 'outgoingStreamNext'
--   until that returns 'Nothing'.
--
-- The final states of both sides are discarded.
handleGeneralStream ::
  (Service s, HasMethod s m)
  => RPC s m
  -> GeneralStreamHandler s m a b
  -> WaiHandler
handleGeneralStream rpc handler0 decoding encoding req write flush = do
  (inState, inStream, outState, outStream) <- handler0 req
  void $ concurrently
    (consume freshDecoder inStream inState)
    (produce outStream outState)
  where
    freshDecoder = decodeInput rpc (_getDecodingCompression decoding)
    emit msg = write (encodeOutput rpc (_getEncodingCompression encoding) msg) >> flush
    produce out st =
      outgoingStreamNext out st
        >>= maybe (pure st) (\(st', msg) -> emit msg >> produce out st')
    consume decoder inc st =
      handleRequestChunksLoop decoder onMessage onEndOfBody (requestBody req)
      where
        -- Bytes following this message seed the decoder for the next one.
        onMessage leftover msg = do
          st' <- incomingStreamHandler inc st msg
          consume (pushChunk freshDecoder leftover) inc st'
        onEndOfBody = incomingStreamFinalizer inc st >> pure st
-- | Run an incremental decoder until it yields one message, pulling more
-- request-body chunks while it still needs input.
--
-- * Success: the message callback receives the bytes left after the message
--   and the decoded value.
-- * Decode failure, either a 'Left' result or a 'Fail' state: the call is
--   aborted with @INVALID_ARGUMENT@.
-- * Empty chunk while the decoder still needs input: the end-of-body
--   callback is invoked.
handleRequestChunksLoop
  :: (Message a)
  => Decoder (Either String a)
  -- ^ Decoder, possibly already fed with buffered bytes.
  -> (ByteString -> a -> IO b)
  -- ^ Called with the unconsumed bytes and the decoded message.
  -> IO b
  -- ^ Called when the body ends while the decoder still needs input.
  -> IO ByteString
  -- ^ Reads the next body chunk; an empty chunk means end of body.
  -> IO b
{-# INLINEABLE handleRequestChunksLoop #-}
handleRequestChunksLoop decoder0 onMessage onEndOfBody readChunk = drive decoder0
  where
    drive decoder = case decoder of
      Partial _ -> do
        chunk <- readChunk
        if ByteString.null chunk
          then onEndOfBody
          else drive (pushChunk decoder chunk)
      Done leftover _ (Right msg) -> onMessage leftover msg
      Done _ _ (Left err) -> reject "done-error: " err
      Fail _ _ err -> reject "fail-error: " err
    reject prefix err =
      closeEarly (GRPCStatus INVALID_ARGUMENT (ByteString.pack (prefix ++ err)))
-- | Guard a message callback so it only runs when no bytes remain after the
-- decoded message. Otherwise the call is aborted with @INVALID_ARGUMENT@ and
-- the leftover bytes are included in the status message.
errorOnLeftOver :: (a -> IO b) -> ByteString -> a -> IO b
errorOnLeftOver handle leftover =
  if ByteString.null leftover
    then handle
    else \_ -> closeEarly (GRPCStatus INVALID_ARGUMENT ("left-overs: " <> leftover))