{-# LANGUAGE Rank2Types, MultiParamTypeClasses, FlexibleInstances, GeneralizedNewtypeDeriving, TypeFamilies, UndecidableInstances #-}

{- | Safe implementation of an array-backed binary min-heap.  The 'HeapT' transformer (and its unboxed counterpart 'UHeapT') requires that the underlying monad provide a 'MonadST' instance, meaning that the bottom-level monad must be 'ST'.  This restriction is critical for referential transparency: it rules out nondeterministic bottom-level monads such as '[]', whose branches would otherwise all read and mutate the same underlying array.  (The 'HeapM' monad takes care of the 'ST' bottom level automatically.)
-}
module Control.Monad.Queue.Heap (HeapM, HeapT, runHeapM, runHeapMOn, runHeapT, runHeapTOn, UHeapT, runUHeapT) where

import Control.Monad.Array.ArrayT
import Control.Monad.Array.Unboxed
import Control.Monad.Array.Class

import Control.Monad.ST
import Control.Monad.ST.Class

import Data.Array.Vector
import Control.Monad.State.Strict
import Control.Monad.RWS.Class
import Control.Monad.Queue.Class

import Control.Monad

-- | Heap computations running directly over 'ST': the 'HeapT' transformer
-- with @'ST' s@ as its bottom-level monad.
type HeapM s e = HeapT e (ST s)

-- | Monad transformer providing a boxed, array-backed binary min-heap of
-- elements of type @e@.  The 'StateT' layer tracks how many elements are
-- currently in the heap; the elements themselves live in the 'ArrayT' layer.
newtype HeapT e m a = HeapT {execHeapT :: StateT Int (ArrayT e m) a}
  deriving (Monad, MonadPlus, MonadFix, MonadReader r, MonadWriter w)

-- | Unboxed counterpart of 'HeapT': identical structure, but the elements
-- are stored in a 'UArrayT'.
newtype UHeapT e m a = UHeapT {execUHeapT :: StateT Int (UArrayT e m) a}
  deriving (Monad, MonadFix, MonadReader r, MonadWriter w)

-- | Lifting passes through both the size-tracking 'StateT' layer and the
-- 'ArrayT' layer beneath it.
instance MonadTrans (HeapT e) where
  lift m = HeapT (lift (lift m))

-- | 'get' and 'put' address the state of the underlying monad @m@; the
-- heap's own size counter is not exposed through this instance.
instance MonadState s m => MonadState s (HeapT e m) where
  get = lift get
  put s = lift (put s)

-- | Runs a 'HeapM' computation starting with an empty heap, returning a
-- pure result.
runHeapM :: Ord e => (forall s . HeapM s e a) -> a
runHeapM m = runST (runHeapT m)

-- | Runs a 'HeapM' computation starting with a heap that already holds the
-- given list.  The 'Int' argument is the list's length; see 'runHeapTOn'.
runHeapMOn :: Ord e => (forall s . HeapM s e a) -> Int -> [e] -> a
runHeapMOn m n l = runST (runHeapTOn m n l)

-- | Runs a 'HeapT' computation starting with an empty heap (size 0) whose
-- backing array is initially allocated with 16 slots.
runHeapT :: (MonadST m, Monad m) => HeapT e m a -> m a
runHeapT = runArrayT_ initialCapacity . flip evalStateT 0 . execHeapT
  where initialCapacity = 16

-- | Runs a 'UHeapT' computation starting with an empty heap (size 0) whose
-- unboxed backing array is initially allocated with 16 slots.
runUHeapT :: (MonadST m, Monad m, UA e, Ord e) => UHeapT e m a -> m a
runUHeapT = evalUArrayT initialCapacity . flip evalStateT 0 . execUHeapT
  where initialCapacity = 16

-- | Runs a 'HeapT' computation starting with a heap initialized to hold the
-- specified list.  The heap is built bottom-up in linear time, which is more
-- efficient than inserting the elements one by one.
runHeapTOn :: (MonadST m, Monad m, Ord e)
  => HeapT e m a  -- ^ The transformer operation.
  -> Int          -- ^ The starting size of the heap (must be equal to the length of the list).
  -> [e]          -- ^ The initial contents of the heap.
  -> m a
runHeapTOn m n l = runArrayT_ n $ flip evalStateT n $ do
  -- Copy the list into slots [0, n).
  mapM_ (uncurry unsafeWriteAt) (zip [0 .. n - 1] l)
  -- Floyd's heap construction: sift down every internal node, deepest first.
  -- Indices >= n `quot` 2 have no children, so heapDown would only write
  -- their element straight back; skipping them gives the same heap with
  -- half the work.  For n <= 1 the range is empty.
  mapM_ (\i -> unsafeReadAt i >>= heapDown n i) [lastParent, lastParent - 1 .. 0]
  execHeapT m
  where lastParent = n `quot` 2 - 1

-- | Min-priority queue backed by a boxed binary heap.  The 'StateT' state is
-- the element count; live elements occupy array indices @[0, size)@.
instance (MonadST m, Monad m, Ord e) => MonadQueue (HeapT e m) where
  type QKey (HeapT e m) = e
  -- The minimum element is always at the root, index 0.
  queuePeek = HeapT $ do
    size <- get
    if size > 0 then liftM Just (unsafeReadAt 0) else return Nothing
  -- Grow the backing array if needed, then sift the new element up from
  -- the first free slot.
  queueInsert x = HeapT $ do
    size <- get
    ensureHeap (size + 1)
    put (size + 1)
    heapUp size x
  -- Remove the root: move the last element into it and sift it down.
  -- Deleting from an empty heap is a no-op rather than driving the size
  -- negative and reading index -1.
  queueDelete = HeapT $ do
    size <- get
    when (size > 0) $ do
      let lastIx = size - 1
      put lastIx
      unsafeReadAt lastIx >>= heapDown lastIx 0
      -- Overwrite the vacated slot so the array no longer keeps the moved
      -- element reachable.  The slot lies outside [0, size), so it should
      -- never be read.
      unsafeWriteAt lastIx (error "Control.Monad.Queue.Heap: read of a vacated heap slot")
  queueSize = HeapT get

-- | Min-priority queue backed by an unboxed binary heap.  The 'StateT' state
-- is the element count; live elements occupy array indices @[0, size)@.
instance (MonadST m, Monad m, UA e, Ord e) => MonadQueue (UHeapT e m) where
  type QKey (UHeapT e m) = e
  -- The minimum element is always at the root, index 0.
  queuePeek = UHeapT $ do
    size <- get
    if size > 0 then liftM Just (unsafeReadAt 0) else return Nothing
  -- Grow the backing array if needed, then sift the new element up from
  -- the first free slot.
  queueInsert x = UHeapT $ do
    size <- get
    ensureHeap (size + 1)
    put (size + 1)
    heapUp size x
  -- Remove the root: move the last element into it and sift it down.
  -- Deleting from an empty heap is a no-op rather than driving the size
  -- negative and reading index -1.  Unlike the boxed heap, no placeholder
  -- is written to the vacated slot; NOTE(review): presumably because
  -- unboxed elements are stored strictly, so a bottom value cannot be written.
  queueDelete = UHeapT $ do
    size <- get
    when (size > 0) $ do
      let lastIx = size - 1
      put lastIx
      unsafeReadAt lastIx >>= heapDown lastIx 0
  queueSize = UHeapT get

{-# INLINE ensureHeap #-}
-- | Ensure the backing array has a slot for index @n - 1@, i.e. room for
-- @n@ elements.  When it does not, the array is resized to @4 * n@ slots.
ensureHeap :: MonadArray m => Int -> m ()
ensureHeap n = askSize >>= \capacity -> when (capacity <= n - 1) (resize (4 * n))

{-# INLINE heapUp #-}
-- | Sift @x@ up a min-heap, starting from the hole at the given index.
-- While @x@ is not @>=@ the parent, the parent moves down into the hole and
-- the search continues from the parent's index.  @x@ is finally written into
-- the hole; index 0 (the root) always accepts it.
heapUp :: (MonadArray m, e ~ ArrayElem m, Ord e) => Int -> e -> m ()
heapUp = siftUp
  where
    siftUp 0 x = unsafeWriteAt 0 x
    siftUp hole x = do
      let parent = (hole - 1) `quot` 2
      above <- unsafeReadAt parent
      if x >= above
        then unsafeWriteAt hole x
        else do
          unsafeWriteAt hole above
          siftUp parent x

{-# INLINE heapDown #-}
-- | Sift @x@ down a min-heap occupying indices @[0, size)@, starting from
-- the hole at index @i@.  While the smaller child is less than @x@, that
-- child moves up into the hole and the search continues from the child's
-- index.  @x@ is finally written into the hole.
heapDown :: (MonadArray m, e ~ ArrayElem m, Ord e) => Int -> Int -> e -> m ()
heapDown size = heapDown'
  where
    heapDown' i x =
      let lch = 2 * i + 1
          rch = lch + 1
      in case compare rch size of
        -- Both children are inside the heap: follow the smaller one
        -- (ties go to the right child).
        LT -> do
          al <- unsafeReadAt lch
          ar <- unsafeReadAt rch
          let (ach, ch) = if al < ar then (al, lch) else (ar, rch)
          if ach < x then unsafeWriteAt i ach >> heapDown' ch x else unsafeWriteAt i x
        -- Only the left child exists (lch == size - 1).  Its own children
        -- would be >= size, so it is a leaf and at most one swap is needed.
        -- lch < size, so the unchecked read is in bounds, exactly like the
        -- reads in the LT branch.
        EQ -> do
          al <- unsafeReadAt lch
          if al < x then unsafeWriteAt i al >> unsafeWriteAt lch x else unsafeWriteAt i x
        -- No children: i is a leaf.
        GT -> unsafeWriteAt i x