module Control.Concurrent.CachedIO (
    cachedIO,
    cachedIOWith,
    cachedIO',
    cachedIOWith'
    ) where

import Control.Concurrent.STM (atomically, newTVar, readTVar, writeTVar, retry, TVar)
import Control.Monad (join)
import Control.Monad.Catch (MonadCatch, onException)
import Control.Monad.IO.Class (liftIO, MonadIO)
import Data.Time.Clock (NominalDiffTime, addUTCTime, getCurrentTime, UTCTime)

-- | Internal state of a cache cell.
data State a
  = Uninitialized     -- ^ No value has been computed yet.
  | Initializing      -- ^ Another caller is computing the first value; readers
                      --   observing this state block until it changes.
  | Updating a        -- ^ A refresh is in progress; holds the stale value that
                      --   other readers receive in the meantime.
  | Fresh UTCTime a   -- ^ The cached value and the time it was stored.

-- | Cache an IO action, producing a version of this IO action that is cached
-- for @interval@ seconds. The cache begins uninitialized.
--
-- The outer IO is responsible for setting up the cache. Use the inner one to
-- either get the cached value or refresh, if the cache is older than
-- @interval@ seconds.
cachedIO :: (MonadIO m, MonadIO t, MonadCatch t)
         => NominalDiffTime -- ^ Number of seconds before refreshing cache
         -> t a             -- ^ IO action to cache
         -> m (t a)
cachedIO interval = cachedIOWith (secondsPassed interval)

-- | Like 'cachedIO', but the cached action receives the previously cached
-- value (if there is one) together with the time it was stored, so it can
-- perform its own staleness checks. The cache begins uninitialized.
--
-- The outer IO is responsible for setting up the cache. Use the inner one to
-- either get the cached value or refresh, if the cache is older than
-- @interval@ seconds.
cachedIO' :: (MonadIO m, MonadIO t, MonadCatch t)
          => NominalDiffTime -- ^ Number of seconds before refreshing cache
          -> (Maybe (UTCTime, a) -> t a) -- ^ action to cache. The stale value and its refresh date
          -- are passed so that the action can perform external staleness checks
          -> m (t a)
cachedIO' interval = cachedIOWith' (secondsPassed interval)

-- | Check whether @start + seconds@ is still strictly after @end@.
--
-- Despite the name, this returns 'True' while fewer than @seconds@ have
-- elapsed between @start@ and @end@ (i.e. the cache is still fresh), and
-- 'False' once exactly @seconds@ or more have elapsed.
secondsPassed :: NominalDiffTime  -- ^ Seconds
              -> UTCTime          -- ^ Start time
              -> UTCTime          -- ^ End time
              -> Bool
secondsPassed interval start end = addUTCTime interval start > end

-- | Cache an IO action. The cache begins uninitialized.
--
-- The outer IO is responsible for setting up the cache. Use the inner one to
-- either get the cached value or refresh it.
cachedIOWith
    :: (MonadIO m, MonadIO t, MonadCatch t)
    => (UTCTime -> UTCTime -> Bool) -- ^ Test function:
    --   If @isCacheStillFresh lastUpdated now@ returns 'True'
    --   the cache is considered still fresh and the cached value is returned
    -> t a -- ^ action to cache.
    -> m (t a)
cachedIOWith f io = cachedIOWith' f (const io)

-- | Cache an IO action. The cache begins uninitialized.
--
-- The outer IO is responsible for setting up the cache. Use the inner one to
-- either get the cached value or refresh it.
--
-- While one caller refreshes a stale cache, concurrent callers receive the
-- stale value. While the first value is being computed, concurrent callers
-- block until the cache state changes. If the action throws, the previous
-- cache state is restored and the exception propagates to the caller.
cachedIOWith'
    :: (MonadIO m, MonadIO t, MonadCatch t)
    => (UTCTime -> UTCTime -> Bool) -- ^ Test function:
    --   If @isCacheStillFresh lastUpdated now@ returns 'True'
    --   the cache is considered still fresh and the cached value is returned
    -> (Maybe (UTCTime, a) -> t a) -- ^ action to cache. The stale value and its refresh date
    -- are passed so that the action can perform external staleness checks
    -> m (t a)
cachedIOWith' isCacheStillFresh io = do
  cachedT <- liftIO (atomically (newTVar Uninitialized))
  return $ do
    now <- liftIO getCurrentTime
    -- Decide what to do in a single STM transaction, then run the chosen
    -- action (which may perform arbitrary effects) outside of it.
    join . liftIO . atomically $ do
      cached <- readTVar cachedT
      case cached of
        previousState@(Fresh lastUpdated value)
          -- There's data in the cache and it's recent. Just return.
          | isCacheStillFresh lastUpdated now -> return (return value)
          -- There's data in the cache, but it's stale. Update the cache state
          -- to prevent a second thread from also executing the action. The
          -- second thread will get the stale data instead.
          | otherwise -> do
              writeTVar cachedT (Updating value)
              return $ refreshCache previousState cachedT
        -- Another thread is already updating the cache, just return the stale
        -- value.
        Updating value -> return (return value)
        -- The cache is uninitialized. Mark the cache as initializing (in the
        -- same transaction) so that concurrent callers block in the
        -- 'Initializing' branch instead of also running the action.
        Uninitialized -> do
          writeTVar cachedT Initializing
          return $ refreshCache Uninitialized cachedT
        -- The cache is uninitialized and another thread is already attempting
        -- to initialize it. Block until the state changes.
        Initializing -> retry
  where
    -- Run the action, store its result as 'Fresh', and return it. On an
    -- exception, put back the state that was current before the refresh
    -- started (e.g. 'Uninitialized', which wakes blocked callers so one of
    -- them can retry the initialization).
    --
    -- NOTE(review): without a 'MonadMask' constraint the STM state change
    -- above and the installation of the 'onException' handler below are not
    -- atomic with respect to asynchronous exceptions; an async exception
    -- delivered in between could leave the state as 'Updating' or
    -- 'Initializing'. Closing that gap would require changing the constraints.
    refreshCache previousState cachedT = do
      let previous = case previousState of
            Fresh lastUpdated value -> Just (lastUpdated, value)
            _                       -> Nothing
      newValue <- io previous `onException` restoreState previousState cachedT
      now <- liftIO getCurrentTime
      liftIO (atomically (writeTVar cachedT (Fresh now newValue)))
      return newValue

-- | Atomically overwrite the cache cell with @previousState@.
--
-- Used as the 'onException' handler of a refresh, so that a failed action
-- does not leave the cache in the in-progress state that was written before
-- the refresh started.
restoreState :: (MonadIO m) => State a -> TVar (State a) -> m ()
restoreState previousState cachedT =
  liftIO . atomically $ writeTVar cachedT previousState