{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module Network.GRPC.Server.Handlers where

import           Control.Concurrent.Async (concurrently)
import           Control.Monad (void)
import           Data.Binary.Get (pushChunk, Decoder(..))
import qualified Data.ByteString.Char8 as ByteString
import           Data.ByteString.Char8 (ByteString)
import           Data.ByteString.Lazy (toStrict)
import           Data.ProtoLens.Message (Message)
import           Data.ProtoLens.Service.Types (Service(..), HasMethod, HasMethodImpl(..), StreamingType(..))
import           Network.GRPC.HTTP2.Encoding (decodeInput, encodeOutput, Encoding(..), Decoding(..))
import           Network.GRPC.HTTP2.Types (RPC(..), GRPCStatus(..), GRPCStatusCode(..), path)
import           Network.Wai (Request, requestBody, strictRequestBody)

import Network.GRPC.Server.Wai (WaiHandler, ServiceHandler(..), closeEarly)

-- | Handy type to refer to Handler for 'unary' RPCs handler.
type UnaryHandler s m = Request -> MethodInput s m -> IO (MethodOutput s m)

-- | Handy type for 'server-streaming' RPCs.
--
-- We expect an implementation to:
-- - read the input request
-- - return an initial state and an state-passing action that the server code will call to fetch the output to send to the client (or close an a Nothing)
-- See 'ServerStream' for the type which embodies these requirements.
type ServerStreamHandler s m a = Request -> MethodInput s m -> IO (a, ServerStream s m a)

newtype ServerStream s m a = ServerStream {
    serverStreamNext :: a -> IO (Maybe (a, MethodOutput s m))
  }

-- | Handy type for 'client-streaming' RPCs.
--
-- We expect an implementation to:
-- - acknowledge a the new client stream by returning an initial state and two functions:
-- - a state-passing handler for new client message
-- - a state-aware handler for answering the client when it is ending its stream
-- See 'ClientStream' for the type which embodies these requirements.
type ClientStreamHandler s m a = Request -> IO (a, ClientStream s m a)

data ClientStream s m a = ClientStream {
    clientStreamHandler   :: a -> MethodInput s m -> IO a
  , clientStreamFinalizer :: a -> IO (MethodOutput s m)
  }

-- | Handy type for 'bidirectional-streaming' RPCs.
--
-- We expect an implementation to:
-- - acknowlege a new bidirection stream by returning an initial state and one functions:
-- - a state-passing function that returns a single action step
-- The action may be to
-- - stop immediately
-- - wait and handle some input with a callback and a finalizer (if the client closes the stream on its side) that may change the state
-- - return a value and a new state
--
-- There is no way to stop locally (that would mean sending HTTP2 trailers) and
-- keep receiving messages from the client.
type BiDiStreamHandler s m a = Request -> IO (a, BiDiStream s m a)

data BiDiStep s m a
  = Abort
  | WaitInput !(a -> MethodInput s m -> IO a) !(a -> IO a)
  | WriteOutput !a (MethodOutput s m)

data BiDiStream s m a = BiDiStream {
    bidirNextStep :: a -> IO (BiDiStep s m a)
  }

-- | Construct a handler for handling a unary RPC.
unary
  :: (Service s, HasMethod s m)
  => RPC s m
  -> UnaryHandler s m
  -> ServiceHandler
unary rpc handler =
    ServiceHandler (path rpc) (handleUnary rpc handler)

-- | Construct a handler for handling a server-streaming RPC.
serverStream
  :: (Service s, HasMethod s m,  MethodStreamingType s m ~ 'ServerStreaming)
  => RPC s m
  -> ServerStreamHandler s m a
  -> ServiceHandler
serverStream rpc handler =
    ServiceHandler (path rpc) (handleServerStream rpc handler)

-- | Construct a handler for handling a client-streaming RPC.
clientStream
  :: (Service s, HasMethod s m,  MethodStreamingType s m ~ 'ClientStreaming)
  => RPC s m
  -> ClientStreamHandler s m a
  -> ServiceHandler
clientStream rpc handler =
    ServiceHandler (path rpc) (handleClientStream rpc handler)

-- | Construct a handler for handling a bidirectional-streaming RPC.
bidiStream
  :: (Service s, HasMethod s m,  MethodStreamingType s m ~ 'BiDiStreaming)
  => RPC s m
  -> BiDiStreamHandler s m a
  -> ServiceHandler
bidiStream rpc handler =
    ServiceHandler (path rpc) (handleBiDiStream rpc handler)

-- | Construct a handler for handling a bidirectional-streaming RPC.
generalStream
  :: (Service s, HasMethod s m)
  => RPC s m
  -> GeneralStreamHandler s m a b
  -> ServiceHandler
generalStream rpc handler =
    ServiceHandler (path rpc) (handleGeneralStream rpc handler)

-- | Handle unary RPCs.
handleUnary ::
     (Service s, HasMethod s m)
  => RPC s m
  -> UnaryHandler s m
  -> WaiHandler
handleUnary rpc handler decoding encoding req write flush = do
    handleRequestChunksLoop (decodeInput rpc $ _getDecodingCompression decoding) handleMsg handleEof nextChunk
  where
    nextChunk = toStrict <$> strictRequestBody req
    handleMsg = errorOnLeftOver (\i -> handler req i >>= reply)
    handleEof = closeEarly (GRPCStatus INVALID_ARGUMENT "early end of request body")
    reply msg = write (encodeOutput rpc (_getEncodingCompression encoding) msg) >> flush

-- | Handle Server-Streaming RPCs.
handleServerStream ::
     (Service s, HasMethod s m)
  => RPC s m
  -> ServerStreamHandler s m a
  -> WaiHandler
handleServerStream rpc handler decoding encoding req write flush = do
    handleRequestChunksLoop (decodeInput rpc $ _getDecodingCompression decoding) handleMsg handleEof nextChunk
  where
    nextChunk = toStrict <$> strictRequestBody req
    handleMsg = errorOnLeftOver (\i -> handler req i >>= replyN)
    handleEof = closeEarly (GRPCStatus INVALID_ARGUMENT "early end of request body")
    replyN (v, sStream) = do
        let go v1 = serverStreamNext sStream v1 >>= \case
                Just (v2, msg) -> do
                    write (encodeOutput rpc (_getEncodingCompression encoding) msg) >> flush
                    go v2
                Nothing -> pure ()
        go v

-- | Handle Client-Streaming RPCs.
handleClientStream ::
     (Service s, HasMethod s m)
  => RPC s m
  -> ClientStreamHandler s m a
  -> WaiHandler
handleClientStream rpc handler0 decoding encoding req write flush = do
    handler0 req >>= go
  where
    go (v, cStream) = handleRequestChunksLoop (decodeInput rpc $ _getDecodingCompression decoding) (handleMsg v) (handleEof v) nextChunk
      where
        nextChunk = requestBody req
        handleMsg v0 dat msg = clientStreamHandler cStream v0 msg >>= \v1 -> loop dat v1
        handleEof v0 = clientStreamFinalizer cStream v0 >>= reply
        reply msg = write (encodeOutput rpc (_getEncodingCompression encoding) msg) >> flush
        loop chunk v1 = handleRequestChunksLoop (flip pushChunk chunk $ decodeInput rpc (_getDecodingCompression decoding)) (handleMsg v1) (handleEof v1) nextChunk

-- | Handle Bidirectional-Streaming RPCs.
handleBiDiStream ::
    (Service s, HasMethod s m)
  => RPC s m
  -> BiDiStreamHandler s m a
  -> WaiHandler
handleBiDiStream rpc handler0 decoding encoding req write flush = do
    handler0 req >>= go ""
  where
    nextChunk = requestBody req
    reply msg = write (encodeOutput rpc (_getEncodingCompression encoding) msg) >> flush
    go chunk (v0, bStream) = do
        let cont dat v1 = go dat (v1, bStream)
        step <- (bidirNextStep bStream) v0
        case step of
            WaitInput handleMsg handleEof -> do
                handleRequestChunksLoop (flip pushChunk chunk $ decodeInput rpc $ _getDecodingCompression decoding)
                                        (\dat msg -> handleMsg v0 msg >>= cont dat)
                                        (handleEof v0 >>= cont "")
                                        nextChunk
            WriteOutput v1 msg -> do
                reply msg
                cont "" v1
            Abort -> return ()

-- | A GeneralStreamHandler combining server and client asynchronous streams.
type GeneralStreamHandler s m a b =
    Request -> IO (a, IncomingStream s m a, b, OutgoingStream s m b)

-- | Pair of handlers for reacting to incoming messages.
data IncomingStream s m a = IncomingStream {
    incomingStreamHandler   :: a -> MethodInput s m -> IO a
  , incomingStreamFinalizer :: a -> IO ()
  }

-- | Handler to decide on the next message (if any) to return.
data OutgoingStream s m a = OutgoingStream {
    outgoingStreamNext  :: a -> IO (Maybe (a, MethodOutput s m))
  }

-- | Handler for the somewhat general case where two threads behave concurrently:
-- - one reads messages from the client
-- - one returns messages to the client
handleGeneralStream ::
    (Service s, HasMethod s m)
  => RPC s m
  -> GeneralStreamHandler s m a b
  -> WaiHandler
handleGeneralStream rpc handler0 decoding encoding req write flush = void $ do
    handler0 req >>= go
  where
    newDecoder = decodeInput rpc $ _getDecodingCompression decoding
    nextChunk = requestBody req
    reply msg = write (encodeOutput rpc (_getEncodingCompression encoding) msg) >> flush

    go (in0, instream, out0, outstream) = concurrently
        (incomingLoop newDecoder in0 instream)
        (replyLoop out0 outstream)

    replyLoop v0 sstream@(OutgoingStream next) = do
        next v0 >>= \case
            Nothing          -> return v0
            (Just (v1, msg)) -> reply msg >> replyLoop v1 sstream

    incomingLoop decode v0 cstream = do
        let handleMsg dat msg = do
                v1 <- incomingStreamHandler cstream v0 msg
                incomingLoop (pushChunk newDecoder dat) v1 cstream
        let handleEof = incomingStreamFinalizer cstream v0 >> pure v0
        handleRequestChunksLoop decode handleMsg handleEof nextChunk


-- | Helpers to consume input in chunks.
handleRequestChunksLoop
  :: (Message a)
  => Decoder (Either String a)
  -- ^ Message decoder.
  -> (ByteString -> a -> IO b)
  -- ^ Handler for a single message.
  -- The ByteString corresponds to leftover data.
  -> IO b
  -- ^ Handler for handling end-of-streams.
  -> IO ByteString
  -- ^ Action to retrieve the next chunk.
  -> IO b
{-# INLINEABLE handleRequestChunksLoop #-}
handleRequestChunksLoop decoder handleMsg handleEof nextChunk =
    case decoder of
        (Done unusedDat _ (Right val)) -> do
            handleMsg unusedDat val
        (Done _ _ (Left err)) -> do
            closeEarly (GRPCStatus INVALID_ARGUMENT (ByteString.pack $ "done-error: " ++ err))
        (Fail _ _ err)         ->
            closeEarly (GRPCStatus INVALID_ARGUMENT (ByteString.pack $ "fail-error: " ++ err))
        partial@(Partial _)    -> do
            chunk <- nextChunk
            if ByteString.null chunk
            then
                handleEof
            else
                handleRequestChunksLoop (pushChunk partial chunk) handleMsg handleEof nextChunk

-- | Combinator around message handler to error on left overs.
--
-- This combinator ensures that, unless for client stream, an unparsed piece of
-- data with a correctly-read message is treated as an error.
errorOnLeftOver :: (a -> IO b) -> ByteString -> a -> IO b
errorOnLeftOver f rest
  | ByteString.null rest = f
  | otherwise            = const $ closeEarly $ GRPCStatus INVALID_ARGUMENT ("left-overs: " <> rest)