{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}
module Entwine.Async (
    AsyncTimeout (..)
  , renderAsyncTimeout
  , waitWithTimeout
  , waitEitherBoth
  ) where


import           Control.Concurrent.Async (Async, waitSTM, waitEither)
import           Control.Concurrent.Async (async, cancel, wait, AsyncCancelled)
import           Control.Concurrent.STM (atomically, orElse, retry)
import           Control.Monad.Catch (catches, finally, throwM, Handler(..))
import           Control.Monad.IO.Class (liftIO)
import           Control.Monad.Trans.Either

import           Data.IORef (newIORef, readIORef, writeIORef)
import qualified Data.Text as T

import           Entwine.P
import           Entwine.Snooze

import           System.IO (IO)

-- | Failure reported by 'waitWithTimeout' when its timer finishes before the
--   awaited 'Async' does. Carries the 'Duration' that was passed as the
--   timeout, which 'renderAsyncTimeout' includes in its message.
data AsyncTimeout =
  AsyncTimeout Duration
  deriving (Eq, Show)

-- | Produce a human readable description of an 'AsyncTimeout', quoting the
--   timeout that was exceeded as a number of seconds.
renderAsyncTimeout :: AsyncTimeout -> Text
renderAsyncTimeout (AsyncTimeout d) =
  mconcat ["Async took greater than '", seconds, " seconds' to return."]
  where
    -- The timeout converted with 'toSeconds' and rendered through 'show'.
    seconds =
      T.pack (show (toSeconds d))

-- | Wait for an 'Async' to finish, giving up once the supplied 'Duration'
--   has elapsed.
--
--   A timer thread running @snooze d@ is raced against @a@:
--
--   * If @a@ finishes first, its result is returned. If it finished by
--     throwing, 'waitEither' rethrows that exception to the caller.
--
--   * If the timer finishes first, @a@ is cancelled and the resulting
--     'AsyncCancelled' is reported as @Left ('AsyncTimeout' d)@. Should @a@
--     manage to complete before the cancellation lands, 'wait' yields its
--     outcome instead.
--
--   The timer thread is cancelled as soon as the race is decided, whatever
--   the outcome, so a long timeout does not leave a sleeping thread behind
--   after every call.
waitWithTimeout :: Async a -> Duration -> EitherT AsyncTimeout IO a
waitWithTimeout a d = do
  r <- liftIO $ newIORef False
  s <- liftIO . async $ snooze d
  -- Stop the timer once the race is over (result, failure or timeout);
  -- cancelling an already finished timer is harmless.
  e <- liftIO $ waitEither a s `finally` cancel s
  case e of
    Left a' ->
      pure a'
    Right _ -> do
      -- Record that the cancellation below was issued here, before issuing it.
      liftIO $ writeIORef r True
      liftIO $ cancel a
      -- 'wait' rethrows 'AsyncCancelled' for a cancelled 'Async'. The flag
      -- selects between reporting a timeout and rethrowing; it is always
      -- set by this point, so (assuming 'bool' takes the False case first,
      -- as in Data.Bool) the timeout branch is the one taken.
      (liftIO $ wait a)
        `catches` [ Handler (\ (ax :: AsyncCancelled) -> do
                                liftIO (readIORef r) >>=
                                  bool (liftIO $ throwM ax) (left $ AsyncTimeout d)
                            )]

-- | Block until either the first 'Async' has finished, or the second and
--   third have both finished, whichever happens first.
--
--   The result is 'Left' with the outcome of @a@, or 'Right' with the paired
--   outcomes of @b@ and @c@. The 'Left' alternative is tried first, so it is
--   the one returned when both sides are ready in the same transaction.
waitEitherBoth :: Async a -> Async b -> Async c -> IO (Either a (b, c))
waitEitherBoth a b c =
  atomically $ fmap Left single `orElse` fmap Right pair
  where
    single =
      waitSTM a

    -- While @b@ is still pending, @c@ is inspected and then the transaction
    -- retries: a successful @c@ keeps us waiting on @b@, while a failure
    -- reported by 'waitSTM' for @c@ surfaces without waiting for @b@.
    pair =
      (waitSTM b `orElse` (waitSTM c >> retry)) >>= \fromB ->
        waitSTM c >>= \fromC ->
          pure (fromB, fromC)