{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Vector.Algorithms.Heap
(
sort
, sortBy
, sortByBounds
, select
, selectBy
, selectByBounds
, partialSort
, partialSortBy
, partialSortByBounds
, heapify
, pop
, popTo
, sortHeap
, heapInsert
, Comparison
) where
import Prelude hiding (read, length)
import Control.Monad
import Control.Monad.Primitive
import Data.Bits
import Data.Vector.Generic.Mutable
import Data.Vector.Algorithms.Common (Comparison)
import qualified Data.Vector.Algorithms.Optimal as O
-- | Sorts an entire mutable vector in ascending order (per 'Ord') with
-- heap sort; equivalent to @'sortBy' 'compare'@.
sort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> m ()
sort v = sortBy compare v
{-# INLINABLE sort #-}
-- | Sorts an entire mutable vector using the supplied comparison, by
-- heap-sorting the full index range @[0, length a)@.
sortBy :: (PrimMonad m, MVector v e) => Comparison e -> v (PrimState m) e -> m ()
sortBy cmp a = sortByBounds cmp a 0 n
  where
    n = length a
{-# INLINE sortBy #-}
-- | Sorts the slice @[l, u)@ of a mutable vector in place.
--
-- Slices of at most four elements go straight to the optimal sorting
-- networks from "Data.Vector.Algorithms.Optimal". Larger slices are
-- arranged into a 4-ary max-heap. Every element except the four smallest
-- is then moved into its final position, and those four are finished with
-- 'O.sort4ByOffset'.
sortByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int  -- ^ lower bound (inclusive)
  -> Int  -- ^ upper bound (exclusive)
  -> m ()
sortByBounds cmp a l u =
  case u - l of
    n | n < 2     -> return ()
      | n == 2    -> O.sort2ByOffset cmp a l
      | n == 3    -> O.sort3ByOffset cmp a l
      | n == 4    -> O.sort4ByOffset cmp a l
      | otherwise -> do
          heapify cmp a l u
          -- Pops down to a 5-element heap and parks its maximum at l + 4,
          -- leaving only [l, l + 4) for the sorting network.
          sortHeap cmp a l (l + 4) u
          O.sort4ByOffset cmp a l
{-# INLINE sortByBounds #-}
-- | Moves the @k@ smallest elements (per 'Ord') of the whole vector into
-- its first @k@ positions, arranged as a max-heap rather than sorted.
-- Equivalent to @'selectBy' 'compare'@.
select
  :: (PrimMonad m, MVector v e, Ord e)
  => v (PrimState m) e
  -> Int  -- ^ number of elements to select, @k@
  -> m ()
select a k = selectBy compare a k
{-# INLINE select #-}
-- | Moves the @k@ smallest elements (per @cmp@) of the whole vector into
-- its first @k@ positions, arranged as a max-heap rather than sorted.
-- Does nothing when @k@ exceeds the vector's length.
selectBy
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int  -- ^ number of elements to select, @k@
  -> m ()
selectBy cmp a k = selectByBounds cmp a k lo hi
  where
    lo = 0
    hi = length a
{-# INLINE selectBy #-}
-- | Moves the @k@ smallest elements (per @cmp@) of the slice @[l, u)@ into
-- @[l, l + k)@, leaving them arranged as a max-heap (not sorted). Does
-- nothing when @l + k > u@.
--
-- The first @k@ elements are heapified. Every later element that compares
-- less than the current heap maximum then takes the maximum's place in
-- the heap, and the maximum is moved out to that element's slot.
selectByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int  -- ^ number of elements to select, @k@
  -> Int  -- ^ lower bound (inclusive)
  -> Int  -- ^ upper bound (exclusive)
  -> m ()
selectByBounds cmp a k l u
  | end > u   = return ()
  | otherwise = do
      heapify cmp a l end
      scan (u - 1)
  where
    end = l + k
    -- Walks the candidate positions from the back of the slice towards
    -- the heap, stopping once the index drops below 'end'.
    scan i
      | i < end   = return ()
      | otherwise = do
          top <- unsafeRead a l
          x   <- unsafeRead a i
          case cmp x top of
            LT -> popTo cmp a l end i
            _  -> return ()
          scan (i - 1)
{-# INLINE selectByBounds #-}
-- | Heap-based partial sort of the whole vector by 'compare'. It is meant
-- to leave the @k@ smallest elements, in ascending order, in the first
-- @k@ positions. Equivalent to @'partialSortBy' 'compare'@.
partialSort
  :: (PrimMonad m, MVector v e, Ord e)
  => v (PrimState m) e
  -> Int  -- ^ number of leading elements to sort, @k@
  -> m ()
partialSort a k = partialSortBy compare a k
{-# INLINE partialSort #-}
-- | Heap-based partial sort of the whole vector using @cmp@. It applies
-- 'partialSortByBounds' to the full index range @[0, length a)@.
partialSortBy
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int  -- ^ number of leading elements to sort, @k@
  -> m ()
partialSortBy cmp a k = partialSortByBounds cmp a k 0 $ length a
{-# INLINE partialSortBy #-}
-- | Rearranges the slice @[l, u)@ so that its first @k@ positions hold the
-- @k@ smallest elements (per @cmp@) in ascending order. The order of the
-- remaining elements is unspecified.
--
-- The whole slice is sorted outright in two cases: it has at most four
-- elements, or @k@ covers it. On a larger slice a non-positive @k@
-- leaves the slice untouched.
partialSortByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int  -- ^ number of leading elements to sort, @k@
  -> Int  -- ^ lower bound (inclusive)
  -> Int  -- ^ upper bound (exclusive)
  -> m ()
partialSortByBounds cmp a k l u
  -- For slices of at most four elements, sorting everything with an
  -- optimal network is simpler than selecting first, and obviously correct.
  | len < 2     = return ()
  | len == 2    = O.sort2ByOffset cmp a l
  | len == 3    = O.sort3ByOffset cmp a l
  | len == 4    = O.sort4ByOffset cmp a l
  | u <= l + k  = sortByBounds cmp a l u
  | k <= 0      = return ()
  -- A selected heap of fewer than five elements must not go through
  -- 'sortHeap' with pivot l + 4. That call would pop nothing and then swap
  -- the heap maximum with a[l + 4], which lies outside the heap, dropping
  -- one of the k smallest elements out of the prefix. Sort the tiny prefix
  -- directly instead.
  | k < 5       = do
      selectByBounds cmp a k l u
      sortPrefix
  | otherwise   = do
      selectByBounds cmp a k l u
      sortHeap cmp a l (l + 4) (l + k)
      O.sort4ByOffset cmp a l
  where
    len = u - l
    -- Sorts the 1..4 selected elements; a one-element heap is already the
    -- minimum and needs no work.
    sortPrefix
      | k == 2    = O.sort2ByOffset cmp a l
      | k == 3    = O.sort3ByOffset cmp a l
      | k == 4    = O.sort4ByOffset cmp a l
      | otherwise = return ()
{-# INLINE partialSortByBounds #-}
-- | Rearranges the slice @[l, u)@ into a 4-ary max-heap (per @cmp@). In
-- that heap the element at relative index @i@ is no smaller than its
-- children at relative indices @4i+1 .. 4i+4@ that lie inside the slice.
--
-- Works bottom-up, sifting each internal node down, starting from the
-- parent of the last element.
heapify
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int  -- ^ lower bound (inclusive)
  -> Int  -- ^ upper bound (exclusive)
  -> m ()
heapify cmp a l u = loop $ (len - 2) `shiftR` 2
  where
    len = u - l
    -- (len - 2) `shiftR` 2 is the parent of relative index len - 1, i.e.
    -- the last node with a child. The previous start of (len - 1) `shiftR` 2
    -- overshot onto a leaf when len = 1 (mod 4). 'shiftR' on 'Int' is
    -- arithmetic, so slices with fewer than two elements yield -1 and do no
    -- work.
    loop k
      | k < 0     = return ()
      | otherwise = unsafeRead a (l + k) >>= \e ->
          siftByOffset cmp a e l k len >> loop (k - 1)
{-# INLINE heapify #-}
-- | Moves the element at @l@ to position @u@. It then sifts the element
-- previously at @u@ down from @l@ within @[l, u)@.
--
-- Applied to a heap occupying @[l, u]@ (inclusive), this shrinks the heap
-- to @[l, u)@ and parks its maximum at @u@, just past the new end.
pop
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int  -- ^ lower bound of the heap
  -> Int  -- ^ index of the last heap element; receives the maximum
  -> m ()
pop cmp a l u = popTo cmp a l u lastIx
  where
    lastIx = u
{-# INLINE pop #-}
-- | Moves the element at @l@ to position @t@. It then sifts the element
-- that was at @t@ down from @l@ into the heap @[l, u)@.
--
-- With a heap in @[l, u)@, this replaces the heap maximum with the
-- element taken from @t@ and stores the old maximum at @t@.
popTo
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int  -- ^ lower bound of the heap
  -> Int  -- ^ upper bound of the heap (exclusive)
  -> Int  -- ^ target index that receives the current maximum
  -> m ()
popTo cmp a l u t =
  unsafeRead a l >>= \top ->
  unsafeRead a t >>= \displaced ->
  unsafeWrite a t top >> siftByOffset cmp a displaced l 0 (u - l)
{-# INLINE popTo #-}
-- | Finishes a heap sort of the heap occupying @[l, u)@ down to pivot @m@.
--
-- It pops the maximum into positions @u - 1@, @u - 2@, ..., @m + 1@. Then
-- it swaps the element at @l@ (the maximum of the remaining heap
-- @[l, m]@) into position @m@. Afterwards @[m, u)@ holds the @u - m@
-- largest elements in ascending order. @[l, m)@ holds the rest in
-- unspecified order and is not necessarily a heap.
--
-- Assumes @l <= m < u@. For @m >= u@ no pops happen, and the final swap
-- pulls in the element at @m@ from outside the heap.
sortHeap
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int  -- ^ lower bound of the heap
  -> Int  -- ^ pivot: last position filled by this function
  -> Int  -- ^ upper bound of the heap (exclusive)
  -> m ()
sortHeap cmp a l m u = go (u - 1)
  where
    go k
      | k <= m    = unsafeSwap a l m
      | otherwise = pop cmp a l k >> go (k - 1)
{-# INLINE sortHeap #-}
-- | Inserts @e@ into a 4-ary max-heap occupying @[l, u)@, growing it to
-- @[l, u]@.
--
-- Position @u@ is treated as an empty slot, and its previous contents are
-- overwritten. The slot moves up towards the root while its parent
-- compares less than @e@, and @e@ is written where it stops.
heapInsert
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int  -- ^ lower bound of the heap
  -> Int  -- ^ upper bound of the heap (exclusive); the new slot
  -> e    -- ^ element to insert
  -> m ()
heapInsert cmp v l u e = bubble (u - l)
  where
    -- 'hole' is the relative index of the empty slot moving upward.
    bubble hole
      | hole <= 0 = unsafeWrite v l e
      | otherwise = do
          let parent = (hole - 1) `shiftR` 2
          p <- unsafeRead v (l + parent)
          if cmp p e == LT
            then unsafeWrite v (l + hole) p >> bubble parent
            else unsafeWrite v (l + hole) e
{-# INLINE heapInsert #-}
-- | Sifts @val@ down a 4-ary max-heap of @len@ elements stored starting
-- at vector index @off@, beginning at relative index @start@.
--
-- The slot at @start@ is treated as empty and its current contents are
-- never read. While the largest child of the slot compares greater than
-- @val@, that child moves up into the slot. Finally @val@ is written into
-- the last slot reached.
siftByOffset :: (PrimMonad m, MVector v e)
             => Comparison e -> v (PrimState m) e -> e -> Int -> Int -> Int -> m ()
siftByOffset cmp a val off start len = go start
  where
    go root
      | child >= len = unsafeWrite a (off + root) val
      | otherwise    = do
          (child', big) <- maximumChild cmp a off child len
          case cmp val big of
            LT -> unsafeWrite a (off + root) big >> go child'
            _  -> unsafeWrite a (off + root) val
      where
        -- First of the (up to four) children of 'root'.
        child = (root `shiftL` 2) + 1
{-# INLINE siftByOffset #-}
-- | Finds the largest (per @cmp@) of the up to four consecutive children
-- starting at relative index @child1@. Only children with relative index
-- below @len@ are considered. Returns the winner's relative index and
-- value.
--
-- Candidates are compared left to right, and the current best is replaced
-- only when it is strictly smaller. Ties therefore favour the earlier
-- child. Assumes @child1 < len@, since the first child is always read.
maximumChild :: (PrimMonad m, MVector v e)
             => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m (Int, e)
maximumChild cmp a off child1 len
  | c4 < len = do
      x1 <- rd child1
      x2 <- rd c2
      x3 <- rd c3
      x4 <- rd c4
      return $ (child1, x1) `max2` (c2, x2) `max2` (c3, x3) `max2` (c4, x4)
  | c3 < len = do
      x1 <- rd child1
      x2 <- rd c2
      x3 <- rd c3
      return $ (child1, x1) `max2` (c2, x2) `max2` (c3, x3)
  | c2 < len = do
      x1 <- rd child1
      x2 <- rd c2
      return $ (child1, x1) `max2` (c2, x2)
  | otherwise = do
      x1 <- rd child1
      return (child1, x1)
  where
    c2 = child1 + 1
    c3 = child1 + 2
    c4 = child1 + 3
    rd i = unsafeRead a (i + off)
    -- Keeps the left candidate unless it compares strictly less than the
    -- right one. Used infixl (the default fixity), so candidates fold left.
    max2 lhs@(_, x) rhs@(_, y) =
      case cmp x y of
        LT -> rhs
        _  -> lhs
{-# INLINE maximumChild #-}