{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module Network.GRPC.Server.Handlers where
import Data.Binary.Get (pushChunk, Decoder(..))
import qualified Data.ByteString.Char8 as ByteString
import Data.ByteString.Char8 (ByteString)
import Data.ByteString.Lazy (toStrict)
import Data.ProtoLens.Message (Message)
import Data.ProtoLens.Service.Types (Service(..), HasMethod, HasMethodImpl(..), StreamingType(..))
import Network.GRPC.HTTP2.Encoding (decodeInput, encodeOutput, Encoding(..), Decoding(..))
import Network.GRPC.HTTP2.Types (RPC(..), GRPCStatus(..), GRPCStatusCode(..), path)
import Network.Wai (Request, requestBody, strictRequestBody)
import Network.GRPC.Server.Wai (WaiHandler, ServiceHandler(..), closeEarly)
-- | Application code for a unary RPC: given the incoming WAI 'Request' and
-- the single decoded input message, produce the single output message.
type UnaryHandler s m =
     Request
  -> MethodInput s m
  -> IO (MethodOutput s m)
-- | Application code for a server-streaming RPC: given the incoming WAI
-- 'Request' and the single decoded input message, return an initial state
-- together with the 'ServerStream' that produces the replies from it.
type ServerStreamHandler s m a =
     Request
  -> MethodInput s m
  -> IO (a, ServerStream s m a)
-- | Stateful producer of the reply messages of a server-streaming RPC.
newtype ServerStream s m a = ServerStream
  { serverStreamNext :: a -> IO (Maybe (a, MethodOutput s m))
    -- ^ Given the current state, either yield the next state together with
    -- an output message ('Just'), or end the stream ('Nothing').
  }
-- | Application code for a client-streaming RPC: given the incoming WAI
-- 'Request', return an initial state together with the 'ClientStream' that
-- consumes the decoded input messages.
type ClientStreamHandler s m a =
     Request
  -> IO (a, ClientStream s m a)
-- | Stateful consumer of the input messages of a client-streaming RPC.
data ClientStream s m a = ClientStream
  { clientStreamHandler :: a -> MethodInput s m -> IO a
    -- ^ Fold one decoded input message into the state.
  , clientStreamFinalizer :: a -> IO (MethodOutput s m)
    -- ^ Turn the final state into the single reply message.
  }
-- | Register a unary RPC: the route is the RPC's 'path', and requests on it
-- are served by 'handleUnary' wrapping the given handler.
unary
  :: (Service s, HasMethod s m)
  => RPC s m
  -> UnaryHandler s m
  -> ServiceHandler
unary rpc handler =
  ServiceHandler route $ handleUnary rpc handler
  where
    route = path rpc
-- | Register a server-streaming RPC: the route is the RPC's 'path', and
-- requests on it are served by 'handleServerStream' wrapping the handler.
-- The equality constraint restricts this to methods declared as
-- server-streaming.
serverStream
  :: (Service s, HasMethod s m, MethodStreamingType s m ~ ServerStreaming)
  => RPC s m
  -> ServerStreamHandler s m a
  -> ServiceHandler
serverStream rpc handler =
  ServiceHandler route $ handleServerStream rpc handler
  where
    route = path rpc
-- | Register a client-streaming RPC: the route is the RPC's 'path', and
-- requests on it are served by 'handleClientStream' wrapping the handler.
-- The equality constraint restricts this to methods declared as
-- client-streaming.
clientStream
  :: (Service s, HasMethod s m, MethodStreamingType s m ~ ClientStreaming)
  => RPC s m
  -> ClientStreamHandler s m a
  -> ServiceHandler
clientStream rpc handler =
  ServiceHandler route $ handleClientStream rpc handler
  where
    route = path rpc
-- | 'WaiHandler' for a unary RPC.
--
-- The request body is read with 'strictRequestBody' and decoded as exactly
-- one input message. Outcomes:
--
-- * bytes left over after the message: the call is closed with
--   @INVALID_ARGUMENT@ (see 'errorOnLeftOver');
-- * the body ends before a complete message: the call is closed with
--   @INVALID_ARGUMENT "early end of request body"@;
-- * decode errors: reported by 'handleRequestChunksLoop';
-- * otherwise the handler's output is encoded, written, then flushed.
handleUnary ::
     (Service s, HasMethod s m)
  => RPC s m
  -> UnaryHandler s m
  -> WaiHandler
handleUnary rpc handler decoding encoding req write flush =
  handleRequestChunksLoop initialDecoder onMessage onEndOfInput readBody
  where
    initialDecoder = decodeInput rpc (_getDecodingCompression decoding)
    readBody = fmap toStrict (strictRequestBody req)
    onMessage = errorOnLeftOver $ \input -> handler req input >>= sendReply
    onEndOfInput =
      closeEarly $ GRPCStatus INVALID_ARGUMENT "early end of request body"
    sendReply output = do
      write $ encodeOutput rpc (_getEncodingCompression encoding) output
      flush
-- | 'WaiHandler' for a server-streaming RPC.
--
-- Decoding of the single input message behaves as in 'handleUnary':
-- leftovers or an early end of body close the call with @INVALID_ARGUMENT@.
-- The handler then yields an initial state and a 'ServerStream'.
-- 'serverStreamNext' is called repeatedly; each yielded message is encoded,
-- written and flushed before the next call. The loop stops when the stream
-- returns 'Nothing'.
handleServerStream ::
     (Service s, HasMethod s m)
  => RPC s m
  -> ServerStreamHandler s m a
  -> WaiHandler
handleServerStream rpc handler decoding encoding req write flush =
  handleRequestChunksLoop initialDecoder onMessage onEndOfInput readBody
  where
    initialDecoder = decodeInput rpc (_getDecodingCompression decoding)
    readBody = fmap toStrict (strictRequestBody req)
    onMessage = errorOnLeftOver $ \input -> handler req input >>= streamReplies
    onEndOfInput =
      closeEarly $ GRPCStatus INVALID_ARGUMENT "early end of request body"
    -- Drive the stream's state machine until it reports the end.
    streamReplies (initialState, stream) = step initialState
      where
        step st = do
          next <- serverStreamNext stream st
          case next of
            Nothing -> pure ()
            Just (st', output) -> do
              write $ encodeOutput rpc (_getEncodingCompression encoding) output
              flush
              step st'
-- | 'WaiHandler' for a client-streaming RPC.
--
-- The handler is run once on the 'Request' to obtain an initial state and
-- a 'ClientStream'. Chunks are pulled with 'requestBody'. Each decoded
-- message is folded into the state with 'clientStreamHandler'; any bytes
-- that followed it are pushed into a fresh decoder so decoding can continue.
-- When the body is exhausted while the decoder is waiting for input,
-- 'clientStreamFinalizer' runs and its output is encoded, written and
-- flushed. Decode errors are reported by 'handleRequestChunksLoop'.
--
-- NOTE(review): a truncated trailing message at the end of the body is not
-- reported. The finalizer runs as if the body had ended cleanly.
handleClientStream ::
     (Service s, HasMethod s m)
  => RPC s m
  -> ClientStreamHandler s m a
  -> WaiHandler
handleClientStream rpc handler0 decoding encoding req write flush =
  handler0 req >>= \(initialState, cStream) ->
    let freshDecoder = decodeInput rpc (_getDecodingCompression decoding)
        readChunk = requestBody req
        -- Fold one message into the state, then keep decoding from the
        -- bytes that followed it with a brand-new decoder.
        onMessage st leftover input = do
          st' <- clientStreamHandler cStream st input
          resume leftover st'
        -- The body is exhausted: finalize and send the single reply.
        onEndOfInput st = do
          output <- clientStreamFinalizer cStream st
          write $ encodeOutput rpc (_getEncodingCompression encoding) output
          flush
        resume leftover st =
          handleRequestChunksLoop
            (pushChunk freshDecoder leftover)
            (onMessage st)
            (onEndOfInput st)
            readChunk
    in handleRequestChunksLoop
         freshDecoder
         (onMessage initialState)
         (onEndOfInput initialState)
         readChunk
-- | Drive an incremental decoder over request-body chunks.
--
-- * Successful decode: @handleMsg@ receives the unconsumed input bytes and
--   the decoded value.
-- * Decode errors, either 'Left' inside 'Done' or 'Fail': the call is
--   closed with @INVALID_ARGUMENT@. The message is prefixed with
--   @"done-error: "@ or @"fail-error: "@ respectively.
-- * Decoder needs more input: @nextChunk@ is called. An empty chunk is
--   treated as end of input and runs @handleEof@; otherwise the chunk is
--   pushed into the decoder and decoding continues.
handleRequestChunksLoop
  :: (Message a)
  => Decoder (Either String a)
  -> (ByteString -> a -> IO ())
  -> IO ()
  -> IO ByteString
  -> IO ()
{-# INLINEABLE handleRequestChunksLoop #-}
handleRequestChunksLoop decoder handleMsg handleEof nextChunk = step decoder
  where
    step (Done leftover _ (Right value)) = handleMsg leftover value
    step (Done _ _ (Left err)) = invalid "done-error: " err
    step (Fail _ _ err) = invalid "fail-error: " err
    step pending@(Partial _) = do
      chunk <- nextChunk
      if ByteString.null chunk
        then handleEof
        else step (pushChunk pending chunk)
    invalid prefix err =
      closeEarly (GRPCStatus INVALID_ARGUMENT (ByteString.pack (prefix ++ err)))
-- | Wrap a single-message handler so that it runs only when no input bytes
-- remain after the message.
--
-- If @rest@ is non-empty, the message is ignored and the call is closed
-- with @INVALID_ARGUMENT@, quoting the leftover bytes.
errorOnLeftOver :: (a -> IO b) -> ByteString -> a -> IO b
errorOnLeftOver f rest =
  if ByteString.null rest
    then f
    else \_ -> closeEarly (GRPCStatus INVALID_ARGUMENT ("left-overs: " <> rest))