-- |
-- Module      : Streamly.Internal.Data.Fold.Concurrent.Channel
-- Copyright   : (c) 2022 Composewell Technologies
-- License     : BSD-3-Clause
-- Maintainer  : streamly@composewell.com
-- Stability   : experimental
-- Portability : GHC

module Streamly.Internal.Data.Fold.Concurrent.Channel
    (
    module Streamly.Internal.Data.Fold.Concurrent.Channel.Type

    -- * Configuration
    , maxBuffer
    , boundThreads
    , inspect

    -- * Fold operations
    , parEval
    )
where

import Control.Concurrent (takeMVar)
import Control.Monad (void, when)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Data.IORef (writeIORef)
import Streamly.Internal.Control.Concurrent (MonadAsync)
import Streamly.Internal.Data.Fold (Fold(..), Step (..))
import Streamly.Internal.Data.Channel.Worker (sendWithDoorBell)
import Streamly.Internal.Data.Time.Clock (Clock(Monotonic), getTime)

import Streamly.Internal.Data.Fold.Concurrent.Channel.Type
import Streamly.Internal.Data.Channel.Types

-------------------------------------------------------------------------------
-- Evaluating a Fold
-------------------------------------------------------------------------------

-- XXX Cleanup the fold if the stream is interrupted. Add a GC hook.

-- | Evaluate a fold asynchronously using a concurrent channel. The driver just
-- queues the input stream values to the fold channel buffer and returns. The
-- fold evaluates the queued values asynchronously. On finalization, 'parEval'
-- waits for the asynchronous fold to complete before it returns.
--
-- The first argument modifies the default channel 'Config' (see 'maxBuffer',
-- 'boundThreads', 'inspect').
--
-- Scanning is not supported: the @extract@ function of the resulting fold
-- raises an error if it is ever called.
--
{-# INLINABLE parEval #-}
parEval :: MonadAsync m => (Config -> Config) -> Fold m a b -> Fold m a b
parEval modifier f =
    Fold step initial extract final

    where

    -- Create the channel that evaluates @f@ asynchronously; the channel
    -- itself is the state of the resulting fold.
    initial = Partial <$> newChannel modifier f

    -- XXX This is not truly asynchronous. If the fold is done we only get to
    -- know when we send the next input unless the stream ends. We could
    -- potentially throw an async exception to the driver to inform it
    -- asynchronously. Alternatively, the stream should not block forever, it
    -- should keep polling the fold status. We can insert a timer tick in the
    -- input stream to do that.
    --
    -- A polled stream abstraction may be useful, it would consist of normal
    -- events and tick events, latter are guaranteed to arrive.
    --
    -- XXX We can use the config to indicate if the fold is a scanning type or
    -- one-shot, or use a separate parEvalScan for scanning. For a scanning
    -- type fold the worker would always send the intermediate values back to
    -- the driver. An intermediate value can be returned on an input, or the
    -- driver can poll even without input, if we have the Skip input support.
    -- When the buffer is full we can return "Skip" and then the next step
    -- without input can wait for an output to arrive. Similarly, when "final"
    -- is called it can return "Skip" to continue or "Done" to indicate
    -- termination.
    --
    -- Queue one input for the worker. A 'Just' result from 'sendToWorker'
    -- carries the fold's result, so we terminate with 'Done'; otherwise we
    -- stay 'Partial' with the same channel.
    step chan a = do
        status <- sendToWorker chan a
        return $ case status of
            Nothing -> Partial chan
            Just b -> Done b

    -- XXX We can use a separate type for non-scanning folds that will
    -- introduce a lot of complexity. Are there combinators that rely on the
    -- "extract" function even in non-scanning use cases?
    -- Instead of making such folds partial we can also make them return a
    -- Maybe type.
    extract _ = error "Concurrent folds do not support scanning"

    -- XXX depending on the use case we may want to either wait for the result
    -- or cancel the ongoing work. We can use the config to control that?
    -- Currently it waits for the work to complete.
    --
    -- Signal end of input to the worker, then block until the result is
    -- available. The stop event is sent exactly once; the previous version
    -- re-entered 'final' on every wake-up and so re-queued a fresh
    -- 'ChildStopChannel' each time.
    final chan = do
        liftIO $ void
            $ sendWithDoorBell
                (outputQueue chan)
                (outputDoorBell chan)
                ChildStopChannel
        waitForResult chan

    -- Poll the fold status. While no result is available, block on the
    -- consumer doorbell and poll again. Once a result arrives, record the
    -- stop time (in inspect mode) and return it.
    waitForResult chan = do
        status <- checkFoldStatus chan
        case status of
            Nothing -> do
                liftIO
                    $ withDiagMVar
                        (svarInspectMode chan)
                        (dumpSVar chan)
                        "parEval: waiting to drain"
                    $ takeMVar (outputDoorBellFromConsumer chan)
                waitForResult chan
            Just b -> do
                when (svarInspectMode chan) $ liftIO $ do
                    t <- getTime Monotonic
                    writeIORef (svarStopTime (svarStats chan)) (Just t)
                    printSVar (dumpSVar chan) "SVar Done"
                return b