-- |This module executes events at specified times.  It runs background
-- threads so that events can be added and cancelled from within STM
-- transactions without requiring later IO actions.  For a simpler system
-- that uses relative times see "Control.Event.Relative".
module Control.Event (
         EventId
        ,EventSystem
        ,noEvent
        ,initEventSystem
        ,addEvent
        ,addEventSTM
        ,cancelEvent
        ,cancelEventSTM
        ,evtSystemSize
        ) where

import Prelude hiding (lookup, catch)
import Control.Concurrent (forkIO, myThreadId, ThreadId, threadDelay)
import Control.Concurrent.STM
import Control.Exception
import Control.Monad (forever, when)
import Data.Dynamic
import Data.List (partition, deleteBy)
import Data.Map (Map, empty, findMin, deleteFindMin, insertLookupWithKey, adjust, size, singleton, toList, insert, updateLookupWithKey, delete, lookup, fold)
import Data.Time

-- Sequence number of an event among the events sharing one expiration time.
type EventNumber = Int

-- All events scheduled for a single time: the next sequence number to hand
-- out, paired with the pending actions keyed by their sequence number.
type EventSet = (EventNumber, Map EventNumber (IO ()))

-- Builds the set holding the first event at some time: the action is stored
-- under number 0 and the next number to hand out is 1.
singletonSet :: IO () -> EventSet
singletonSet act = (nextNumber, singleton firstNumber act)
  where
    firstNumber = 0
    nextNumber  = 1

-- |Identifier of a scheduled event, useful for canceling it later.
-- It pairs the event's expiration time with its sequence number at that time.
data EventId = EvtId UTCTime EventNumber
    deriving (Eq, Ord, Show)

-- |A placeholder id that refers to no event at all.
-- Canceling this event returns True and has no other effect.
noEvent :: EventId
noEvent = EvtId never unusedNumber
  where
    unusedNumber = -1

-- Sentinel time: the day before the Modified Julian Day epoch, with a time of
-- day of minus one second.
never :: UTCTime
never = UTCTime sentinelDay sentinelTime
  where
    sentinelDay  = ModifiedJulianDay (-1)
    sentinelTime = secondsToDiffTime (-1)

-- |The event system must be initialized using 'initEventSystem'.
-- More than one event system can be instantiated at once
-- (eg. for non-interference).
data EventSystem = EvtSys
    { esEvents   :: TVar (Map UTCTime EventSet)  -- Pending events, keyed by expiration time
    , esThread   :: TVar (Maybe ThreadId)        -- Id of the thread that receives TimerReset exceptions
    , esAlarm    :: TVar UTCTime                 -- Time of the soonest event
    , esNewAlarm :: TVar Bool                    -- Set when an event with an earlier expiration was added
    , esExpired  :: TVar [[EventSet]]            -- Batches of expired events waiting to be executed
    }

-- |Creates a new event system.  Initializing one is the only way to get an
-- event system.

-- This sets the internal TVars and forks three threads:
--    trackAlarm (looped): throws TimerReset at the expiring thread whenever
--                         an earlier alarm has been flagged
--    monitorExpiredQueue (looped): executes all events on the expired queue
--    expireEvents: when the alarm sounds, moves expired events to the
--                  expired queue
initEventSystem :: IO EventSystem
initEventSystem = do
    pendingVar <- newTVarIO empty
    threadVar  <- newTVarIO Nothing
    alarmVar   <- newTVarIO never
    flagVar    <- newTVarIO False
    expiredVar <- newTVarIO []
    let evtSys = EvtSys
            { esEvents   = pendingVar
            , esThread   = threadVar
            , esAlarm    = alarmVar
            , esNewAlarm = flagVar
            , esExpired  = expiredVar
            }
    _ <- forkIO (forever (trackAlarm evtSys))
    _ <- forkIO (forever (monitorExpiredQueue expiredVar))
    _ <- forkIO (expireEvents evtSys)
    return evtSys

-- |Main thread.  Repeatedly waits for the alarm time and expires due events.
-- An asynchronous 'TimerReset' exception may arrive to indicate a new,
-- earlier, alarm time; it simply restarts the wait.
--
-- Asynchronous exceptions are masked everywhere except around the wait
-- itself, so a 'TimerReset' is only ever delivered inside the handler's scope.
expireEvents :: EventSystem -> IO ()
expireEvents es = mask $ \unmasked -> do
    self <- myThreadId
    let -- Publish this thread's id so TimerReset can be thrown at it.
        publishSelf = atomically (writeTVar (esThread es) (Just self))
        -- A reset carries no data; returning restarts the loop.
        onReset TimerReset = return ()
    forever $ handle onReset (unmasked (publishSelf >> expireEvents' es))

-- |One round of work for 'expireEvents', which only adds the exception
-- handling: block until at least one event is pending, sleep until the
-- soonest one is due, then expire whatever is due.
expireEvents' :: EventSystem -> IO ()
expireEvents' evtSys = do
    soonest <- atomically waitForSoonest
    now     <- getCurrentTime
    threadDelay (timeDiffToMicroSec (diffUTCTime soonest now))
    runExpire evtSys
  where
  -- Retries (blocks) while the pending map is empty, otherwise yields the
  -- smallest expiration time.
  waitForSoonest :: STM UTCTime
  waitForSoonest = do
    pending <- readTVar (esEvents evtSys)
    if size pending == 0
        then retry
        else return (fst (findMin pending))

-- |Moves every event set whose time is strictly before the current time from
-- the pending map onto the expired queue (as one batch) and resets the alarm
-- to the soonest remaining time, or to 'never' when nothing remains.
runExpire :: EventSystem -> IO ()
runExpire evtSys = do
    now <- getCurrentTime
    atomically $ do
        pending <- readTVar (esEvents evtSys)
        let (due, remaining) = splitExpired now pending
        writeTVar (esAlarm evtSys) (nextAlarm remaining)
        writeTVar (esEvents evtSys) remaining
        queued <- readTVar (esExpired evtSys)
        writeTVar (esExpired evtSys) (due : queued)
  where
  -- Peels off minimum keys while they are earlier than the given time.
  -- The expired sets come back in ascending time order.
  splitExpired :: UTCTime -> Map UTCTime EventSet -> ([EventSet], Map UTCTime EventSet)
  splitExpired clk m
      | size m == 0 = ([], m)
      | k < clk     = let (more, rest) = splitExpired clk m'
                      in (es : more, rest)
      | otherwise   = ([], m)
    where
    -- Only demanded after the emptiness guard above has failed.
    ((k, es), m') = deleteFindMin m

  -- Soonest remaining expiration time, or 'never' for an empty map.
  nextAlarm m = if size m == 0 then never else fst (findMin m)

-- |Executes expired events: blocks until the expired queue is non-empty,
-- atomically takes everything on it (leaving it empty), then runs every
-- event set in every batch taken.
monitorExpiredQueue :: TVar [[EventSet]] -> IO ()
monitorExpiredQueue queue = do
    batches <- atomically $ do
        waiting <- readTVar queue
        case waiting of
            [] -> retry
            _  -> do
                writeTVar queue []
                return waiting
    mapM_ (mapM_ runEvents) batches

-- |Runs all provided events (which must have expired), forking a separate
-- thread for each action in ascending sequence-number order.
runEvents :: EventSet -> IO ()
runEvents (_, actions) = mapM_ (forkIO . snd) (toList actions)

-- |Schedule an *action* to be performed at *time* by *system*.  Returns a
-- unique ID that can later be used to cancel the event.
addEvent :: EventSystem -> UTCTime -> IO () -> IO EventId
addEvent sys clk = atomically . addEventSTM sys clk

-- |Atomic version of 'addEvent'.
--
-- The action joins the set of events at the given time (creating the set if
-- needed).  If the new time is earlier than the current alarm, or no alarm is
-- set, the alarm is moved and the new-alarm flag is raised.
addEventSTM :: EventSystem -> UTCTime -> IO () -> STM EventId
addEventSTM sys clk act = do
    pending <- readTVar (esEvents sys)
    let (previous, updated) =
            insertLookupWithKey (\_ _ existing -> append existing) clk (singletonSet act) pending
        -- The first event at a time gets number 0; later ones get the set's
        -- counter as it stood before this insertion.
        num = maybe 0 fst previous
    writeTVar (esEvents sys) updated
    alarm <- readTVar (esAlarm sys)
    when (alarm == never || clk < alarm) $ do
        writeTVar (esAlarm sys) clk
        writeTVar (esNewAlarm sys) True
    return (EvtId clk num)
  where
  -- Stores the action under the set's next number and advances the counter.
  append :: EventSet -> EventSet
  append (next, acts)
      | next == maxBound = error "maxBound events at given time, something is broken."
      | otherwise        = (next + 1, insert next act acts)

-- |Remove a previously scheduled event from the system, returning True on
-- success.
cancelEvent :: EventSystem -> EventId -> IO Bool
cancelEvent sys = atomically . cancelEventSTM sys

-- |Atomic version of 'cancelEvent'.
--
-- Returns True when the event was still pending and has been removed, and
-- also for 'noEvent' (for which nothing is touched).  Returns False when no
-- such event is pending, e.g. because it has already expired or was already
-- canceled; the pending map is left untouched in that case.
cancelEventSTM :: EventSystem -> EventId -> STM Bool
cancelEventSTM sys eid@(EvtId clk num)
    | eid == noEvent = return True
    | otherwise = do
        evts <- readTVar (esEvents sys)
        -- Look the entry up instead of inserting a placeholder: a missing
        -- time slot must never end up with a bogus value in the pending map.
        case lookup clk evts of
            Nothing -> return False
            Just (cnt, acts) ->
                case lookup num acts of
                    Nothing -> return False
                    Just _  -> do
                        -- The time slot and its counter are kept even when
                        -- its last action is removed, so numbers handed out
                        -- for this time are never reused.
                        writeTVar (esEvents sys) (insert clk (cnt, delete num acts) evts)
                        return True

-- |Returns the number of pending events, i.e. the total number of actions
-- over all scheduled times.
evtSystemSize :: EventSystem -> STM Int
evtSystemSize sys = do
    pending <- readTVar (esEvents sys)
    return (sum [size acts | (_, (_, acts)) <- toList pending])

-- |Waits until the new-alarm flag is raised (an earlier event was added) and
-- the expiring thread has published its id, clears the flag, and then throws
-- TimerReset to that thread so it recomputes its delay.
trackAlarm :: EventSystem -> IO ()
trackAlarm sys = do
    target <- atomically $ do
        flagged <- readTVar (esNewAlarm sys)
        check flagged                          -- block until the flag is set
        writeTVar (esNewAlarm sys) False
        -- Block as well while no thread id has been published yet.
        readTVar (esThread sys) >>= maybe retry return
    throwTo target TimerReset

-- |Returns the time difference in whole microseconds (rounded down),
-- saturated to the range @[0, maxBound]@.
--
-- A difference too large for an 'Int' yields 'maxBound' (which is then
-- smaller than the real difference) instead of wrapping around; a zero or
-- negative difference, meaning the time has already passed, yields 0.
timeDiffToMicroSec :: NominalDiffTime -> Int
timeDiffToMicroSec dt
    | micros <= 0     = 0
    | micros >= limit = maxBound
    | otherwise       = fromInteger micros
  where
  -- Computed as an Integer so the bounds can be checked before narrowing.
  micros :: Integer
  micros = floor (dt * 10^6)

  limit :: Integer
  limit = toInteger (maxBound :: Int)

-- Asynchronous exception carrying no data; it is thrown with 'throwTo' and
-- caught by pattern-matching on its constructor.
data TimerReset = TimerReset
    deriving (Eq, Ord, Show, Typeable)

-- Uses the default 'Exception' methods.
instance Exception TimerReset