{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
module Network.GRPC.Server.Handlers.Trans where
import Control.Concurrent.Async (concurrently)
import Control.Monad (void, (>=>))
import Control.Monad.IO.Class
import Data.Binary.Get (pushChunk, Decoder(..))
import qualified Data.ByteString.Char8 as ByteString
import Data.ByteString.Char8 (ByteString)
import Data.ByteString.Lazy (toStrict)
import Network.GRPC.HTTP2.Encoding
import Network.GRPC.HTTP2.Types (path, GRPCStatus(..), GRPCStatusCode(..))
#if MIN_VERSION_wai(3,2,2)
import Network.Wai (Request, getRequestBodyChunk, strictRequestBody)
#else
import Network.Wai (Request, requestBody, strictRequestBody)
#endif
#if MIN_VERSION_base(4,11,0)
#else
import Data.Monoid ((<>))
#endif
import Network.GRPC.Server.Wai (WaiHandler, ServiceHandler(..), closeEarly)
#if !MIN_VERSION_wai(3,2,2)
-- | Compatibility shim for wai releases older than 3.2.2, where this module
-- imports 'requestBody' rather than 'getRequestBodyChunk': expose the former
-- under the newer name so the rest of the module can use a single name.
getRequestBodyChunk :: Request -> IO ByteString
getRequestBodyChunk req = requestBody req
#endif
-- | Handler for a unary RPC: given the WAI 'Request' and the decoded input
-- message, produce the output message in @m@.
type UnaryHandler m i o = Request -> i -> m o
-- | Handler for a server-streaming RPC: given the WAI 'Request' and the
-- decoded input message, return an initial state of type @a@ together with
-- the 'ServerStream' that is stepped to produce the output messages.
type ServerStreamHandler m i o a = Request -> i -> m (a, ServerStream m o a)
-- | Producer of output messages for a server-streaming RPC, threaded through
-- a state of type @a@.
newtype ServerStream m o a = ServerStream
  { serverStreamNext :: a -> m (Maybe (a, o))
    -- ^ Step from the current state: 'Just' the next state and a message to
    -- send, or 'Nothing' to end the stream.
  }
-- | Handler for a client-streaming RPC: given the WAI 'Request', return an
-- initial state of type @a@ together with the 'ClientStream' callbacks used
-- to consume the input messages and build the final output.
type ClientStreamHandler m i o a = Request -> m (a, ClientStream m i o a)
-- | Callbacks consuming the input messages of a client-streaming RPC,
-- threaded through a state of type @a@.
data ClientStream m i o a = ClientStream
  { clientStreamHandler :: a -> i -> m a
    -- ^ Called for each decoded input message with the current state;
    -- returns the next state.
  , clientStreamFinalizer :: a -> m o
    -- ^ Called with the last state once the request body is exhausted;
    -- returns the output message to send.
  }
-- | Handler for a bidirectional-streaming RPC: given the WAI 'Request',
-- return an initial state of type @a@ together with the 'BiDiStream' that
-- decides, state by state, what to do next.
type BiDiStreamHandler m i o a = Request -> m (a, BiDiStream m i o a)
-- | Next action requested by a bidirectional-streaming handler.
data BiDiStep m i o a
  = Abort
    -- ^ Stop: the handler loop returns without reading or writing anything
    -- further.
  | WaitInput !(a -> i -> m a) !(a -> m a)
    -- ^ Wait for one input message. The first function is applied to the
    -- current state and the decoded message; the second is applied to the
    -- current state when the request body ends instead. Either way the
    -- result is the next state.
  | WriteOutput !a o
    -- ^ Send the given output message and continue from the given state.
-- | Driver of a bidirectional-streaming RPC, threaded through a state of
-- type @a@.
newtype BiDiStream m i o a = BiDiStream
  { bidirNextStep :: a -> m (BiDiStep m i o a)
    -- ^ Decide the next 'BiDiStep' from the current state.
  }
-- | Build a 'ServiceHandler' for a unary RPC.
--
-- The first argument runs the handler monad @m@ in 'IO'; the second
-- identifies the RPC (its 'path' is used as the route); the third is the
-- user handler.
unary
  :: (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> UnaryHandler m i o
  -> ServiceHandler
unary runM rpc handler = ServiceHandler route waiHandler
  where
    route = path rpc
    waiHandler = handleUnary runM rpc handler
-- | Build a 'ServiceHandler' for a server-streaming RPC.
--
-- The first argument runs the handler monad @m@ in 'IO'; the second
-- identifies the RPC (its 'path' is used as the route); the third is the
-- user handler.
serverStream
  :: (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> ServerStreamHandler m i o a
  -> ServiceHandler
serverStream runM rpc handler = ServiceHandler route waiHandler
  where
    route = path rpc
    waiHandler = handleServerStream runM rpc handler
-- | Build a 'ServiceHandler' for a client-streaming RPC.
--
-- The first argument runs the handler monad @m@ in 'IO'; the second
-- identifies the RPC (its 'path' is used as the route); the third is the
-- user handler.
clientStream
  :: (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> ClientStreamHandler m i o a
  -> ServiceHandler
clientStream runM rpc handler = ServiceHandler route waiHandler
  where
    route = path rpc
    waiHandler = handleClientStream runM rpc handler
-- | Build a 'ServiceHandler' for a bidirectional-streaming RPC driven by
-- 'BiDiStep's.
--
-- The first argument runs the handler monad @m@ in 'IO'; the second
-- identifies the RPC (its 'path' is used as the route); the third is the
-- user handler.
bidiStream
  :: (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> BiDiStreamHandler m i o a
  -> ServiceHandler
bidiStream runM rpc handler = ServiceHandler route waiHandler
  where
    route = path rpc
    waiHandler = handleBiDiStream runM rpc handler
-- | Build a 'ServiceHandler' for a streaming RPC whose incoming and outgoing
-- sides are processed concurrently.
--
-- The first argument runs the handler monad @m@ in 'IO'; the second
-- identifies the RPC (its 'path' is used as the route); the third is the
-- user handler.
generalStream
  :: (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> GeneralStreamHandler m i o a b
  -> ServiceHandler
generalStream runM rpc handler = ServiceHandler route waiHandler
  where
    route = path rpc
    waiHandler = handleGeneralStream runM rpc handler
-- | WAI-level implementation of a unary RPC.
--
-- The request body is read with 'strictRequestBody' and fed to the input
-- decoder. Once a message is decoded, the handler is run and its result is
-- written and flushed. Bytes remaining after the message are rejected via
-- 'errorOnLeftOver'; a body that ends before a message is complete closes
-- the call with @INVALID_ARGUMENT@.
handleUnary
  :: (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> UnaryHandler m i o
  -> WaiHandler
handleUnary runM rpc handler decoding encoding req write flush =
    runM (handleRequestChunksLoop decoder onMessage onEof readBody)
  where
    decoder = decodeInput rpc (_getDecodingCompression decoding)
    readBody = toStrict <$> strictRequestBody req
    onMessage = errorOnLeftOver $ \input -> do
        output <- handler req input
        liftIO (sendReply output)
    onEof = closeEarly (GRPCStatus INVALID_ARGUMENT "early end of request body")
    sendReply output = do
        write (encodeOutput rpc (_getEncodingCompression encoding) output)
        flush
-- | WAI-level implementation of a server-streaming RPC.
--
-- The request body is read with 'strictRequestBody' and decoded into one
-- input message (leftover bytes are rejected via 'errorOnLeftOver'). The
-- handler then yields an initial state and a 'ServerStream', which is
-- stepped until it returns 'Nothing'; every produced message is written and
-- flushed immediately.
handleServerStream
  :: (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> ServerStreamHandler m i o a
  -> WaiHandler
handleServerStream runM rpc handler decoding encoding req write flush =
    runM (handleRequestChunksLoop decoder onMessage onEof readBody)
  where
    decoder = decodeInput rpc (_getDecodingCompression decoding)
    readBody = toStrict <$> strictRequestBody req
    onMessage = errorOnLeftOver $ \input ->
        handler req input >>= \(start, stream) -> drain stream start
    onEof = closeEarly (GRPCStatus INVALID_ARGUMENT "early end of request body")
    -- Step the stream until it is exhausted, sending each message as it is
    -- produced.
    drain stream state = do
        next <- serverStreamNext stream state
        case next of
            Nothing -> pure ()
            Just (state', output) -> do
                liftIO (write (encodeOutput rpc (_getEncodingCompression encoding) output))
                liftIO flush
                drain stream state'
-- | WAI-level implementation of a client-streaming RPC.
--
-- The handler is run first to obtain the initial state and the
-- 'ClientStream' callbacks. Request body chunks are then decoded one message
-- at a time: each message updates the state through 'clientStreamHandler',
-- and bytes left unconsumed by the decoder are pushed into a fresh decoder
-- for the next message. When the body is exhausted, 'clientStreamFinalizer'
-- produces the reply, which is written and flushed.
handleClientStream
  :: forall m r i o a.
     (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> ClientStreamHandler m i o a
  -> WaiHandler
handleClientStream runM rpc handler0 decoding encoding req write flush =
    runM $ do
        (initial, stream) <- handler0 req
        consume stream freshDecoder initial
  where
    freshDecoder = decodeInput rpc (_getDecodingCompression decoding)
    readChunk = getRequestBodyChunk req

    consume :: ClientStream m i o a -> Decoder (Either String i) -> a -> m ()
    consume stream decoder state =
        handleRequestChunksLoop decoder onMessage onEof readChunk
      where
        onMessage leftover input = do
            state' <- clientStreamHandler stream state input
            consume stream (pushChunk freshDecoder leftover) state'
        onEof = do
            output <- clientStreamFinalizer stream state
            liftIO $ write (encodeOutput rpc (_getEncodingCompression encoding) output)
            liftIO flush
-- | WAI-level implementation of a bidirectional-streaming RPC driven by
-- 'BiDiStep's.
--
-- The handler is run first to obtain the initial state and the
-- 'BiDiStream'. The loop then repeatedly asks the stream for the next step:
--
-- * 'WaitInput': decode one input message (starting from any bytes left
--   unconsumed by the previous decode) and apply the message callback, or
--   apply the end-of-input callback if the request body is exhausted.
-- * 'WriteOutput': write and flush the message, then continue from the new
--   state.
-- * 'Abort': stop.
handleBiDiStream
  :: forall m r i o a.
     (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> BiDiStreamHandler m i o a
  -> WaiHandler
handleBiDiStream f rpc handler0 decoding encoding req write flush =
    f $ handler0 req >>= go ""
  where
    nextChunk = getRequestBodyChunk req
    reply msg = write (encodeOutput rpc (_getEncodingCompression encoding) msg) >> flush

    -- The 'ByteString' argument carries request bytes that were received but
    -- not yet consumed by the decoder.
    go :: ByteString -> (a, BiDiStream m i o a) -> m ()
    go chunk (v0, bStream) = do
        let cont dat v1 = go dat (v1, bStream)
        step <- bidirNextStep bStream v0
        case step of
            WaitInput handleMsg handleEof ->
                handleRequestChunksLoop
                    (flip pushChunk chunk
                        $ decodeInput rpc
                        $ _getDecodingCompression decoding)
                    (\dat msg -> handleMsg v0 msg >>= cont dat)
                    (handleEof v0 >>= cont "")
                    nextChunk
            WriteOutput v1 msg -> do
                liftIO $ reply msg
                -- Writing consumes no input, so the pending bytes must be
                -- kept for the next 'WaitInput'; dropping them here would
                -- lose messages that arrived in the same chunk.
                cont chunk v1
            Abort -> return ()
-- | Handler for a streaming RPC with independent incoming and outgoing
-- sides: given the WAI 'Request', return the initial incoming state with
-- its 'IncomingStream', and the initial outgoing state with its
-- 'OutgoingStream'.
type GeneralStreamHandler m i o a b =
    Request
    -> m ( a
         , IncomingStream m i a
         , b
         , OutgoingStream m o b
         )
-- | Callbacks consuming the input messages of a general streaming RPC,
-- threaded through a state of type @a@.
data IncomingStream m i a = IncomingStream
  { incomingStreamHandler :: a -> i -> m a
    -- ^ Called for each decoded input message with the current state;
    -- returns the next state.
  , incomingStreamFinalizer :: a -> m ()
    -- ^ Called with the last state once the request body is exhausted.
  }
-- | Producer of output messages for a general streaming RPC, threaded
-- through a state of type @a@.
newtype OutgoingStream m o a = OutgoingStream
  { outgoingStreamNext :: a -> m (Maybe (a, o))
    -- ^ Step from the current state: 'Just' the next state and a message to
    -- send, or 'Nothing' to stop producing.
  }
-- | WAI-level implementation of a general streaming RPC.
--
-- The handler is run first to obtain the initial states and both stream
-- descriptions. The incoming loop (decoding request messages) and the
-- outgoing loop (writing replies) are then run with 'concurrently', each
-- through the supplied runner; their final states are discarded.
handleGeneralStream
  :: forall m r i o a b.
     (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> GeneralStreamHandler m i o a b
  -> WaiHandler
handleGeneralStream runM rpc handler0 decoding encoding req write flush =
    void $ runM $ do
        (in0, inStream, out0, outStream) <- handler0 req
        liftIO $ concurrently
            (runM (receive inStream freshDecoder in0))
            (runM (send outStream out0))
  where
    freshDecoder = decodeInput rpc (_getDecodingCompression decoding)
    readChunk = getRequestBodyChunk req
    emit output = write (encodeOutput rpc (_getEncodingCompression encoding) output) >> flush

    -- Step the outgoing stream until it returns 'Nothing', writing and
    -- flushing each message; yields the last state.
    send :: OutgoingStream m o b -> b -> m b
    send stream state = do
        next <- outgoingStreamNext stream state
        case next of
            Nothing -> return state
            Just (state', output) -> liftIO (emit output) >> send stream state'

    -- Decode request messages one at a time, feeding bytes left unconsumed
    -- by one decode into a fresh decoder for the next; runs the finalizer
    -- when the body is exhausted and yields the last state.
    receive :: IncomingStream m i a -> Decoder (Either String i) -> a -> m a
    receive stream decoder state =
        handleRequestChunksLoop decoder onMessage onEof readChunk
      where
        onMessage leftover input = do
            state' <- incomingStreamHandler stream state input
            receive stream (pushChunk freshDecoder leftover) state'
        onEof = do
            incomingStreamFinalizer stream state
            pure state
-- | Drive an incremental decoder until it yields one message.
--
-- * When the decoder is 'Done' with a 'Right' value, the message callback is
--   called with the unconsumed bytes and the decoded value.
-- * When the decoder is 'Done' with a 'Left' error, or has 'Fail'ed, the
--   call is closed with @INVALID_ARGUMENT@ and the error text.
-- * While the decoder is 'Partial', another chunk is fetched; an empty chunk
--   is treated as end of input and runs the end-of-input action, otherwise
--   the chunk is pushed into the decoder and the loop continues.
handleRequestChunksLoop
  :: (MonadIO m)
  => Decoder (Either String a)
  -> (ByteString -> a -> m b)
  -> m b
  -> IO ByteString
  -> m b
{-# INLINEABLE handleRequestChunksLoop #-}
handleRequestChunksLoop decoder handleMsg handleEof nextChunk = step decoder
  where
    step = \case
        Done leftover _ (Right msg) -> handleMsg leftover msg
        Done _ _ (Left err) -> invalid "done-error: " err
        Fail _ _ err -> invalid "fail-error: " err
        pending@(Partial _) -> do
            chunk <- liftIO nextChunk
            if ByteString.null chunk
                then handleEof
                else step (pushChunk pending chunk)
    invalid prefix err =
        closeEarly (GRPCStatus INVALID_ARGUMENT (ByteString.pack (prefix ++ err)))
-- | Guard a message callback against trailing input.
--
-- When the unconsumed bytes are empty, the callback is used as is. Otherwise
-- the message is ignored and the call is closed with @INVALID_ARGUMENT@,
-- with the leftover bytes included in the status message. Nothing is written
-- to stdout: the failure is reported only through 'closeEarly'.
errorOnLeftOver :: MonadIO m => (a -> m b) -> ByteString -> a -> m b
errorOnLeftOver f rest
    | ByteString.null rest = f
    | otherwise = const $
        closeEarly (GRPCStatus INVALID_ARGUMENT ("left-overs: " <> rest))