-- | Sorting, selection and partial sorting of mutable vectors using an
-- in-place heap.
--
-- The heap is 4-ary: the children of heap position @i@ are stored at
-- positions @4i + 1@ through @4i + 4@ (relative to the start of the slice
-- being processed). The element that compares greatest under the supplied
-- 'Comparison' sits at the root.
module Data.Vector.Algorithms.Heap
(
-- * Sorting
sort
, sortBy
, sortByBounds
-- * Selection
, select
, selectBy
, selectByBounds
-- * Partial sorting
, partialSort
, partialSortBy
, partialSortByBounds
-- * Heap operations
, heapify
, pop
, popTo
, sortHeap
-- * Re-exports
, Comparison
) where
import Prelude hiding (read, length)
import Control.Monad
import Control.Monad.Primitive
import Data.Bits
import Data.Vector.Generic.Mutable
import Data.Vector.Algorithms.Common (Comparison)
import qualified Data.Vector.Algorithms.Optimal as O
-- | Sorts a whole mutable vector in place, in ascending order as
-- determined by the element type's 'Ord' instance.
sort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> m ()
sort vec = sortBy compare vec
-- | Sorts a whole mutable vector in place using a caller-supplied
-- comparison.
sortBy :: (PrimMonad m, MVector v e) => Comparison e -> v (PrimState m) e -> m ()
sortBy cmp vec = sortByBounds cmp vec 0 $ length vec
-- | Sorts the slice @[l, u)@ of a mutable vector in place, ordering the
-- elements by the given comparison.
--
-- Slices of up to four elements are handled directly by the fixed-size
-- sorts from "Data.Vector.Algorithms.Optimal". Larger slices are turned
-- into a heap. Everything above the four smallest elements is then popped
-- into its final position from the back, and those last four are finished
-- with a fixed-size sort.
sortByBounds :: (PrimMonad m, MVector v e)
             => Comparison e -> v (PrimState m) e -> Int -> Int -> m ()
sortByBounds cmp a l u
  | len < 2   = return ()
  | len == 2  = O.sort2ByOffset cmp a l
  | len == 3  = O.sort3ByOffset cmp a l
  | len == 4  = O.sort4ByOffset cmp a l
  | otherwise = do
      heapify cmp a l u
      -- Fills [l + 4, u) with the largest elements in ascending order and
      -- leaves the four smallest, unsorted, in [l, l + 4).
      sortHeap cmp a l (l + 4) u
      O.sort4ByOffset cmp a l
  where
    len = u - l
-- | Moves the @k@ least elements of the whole vector, as ordered by 'Ord',
-- to its front. The selected elements are not left in sorted order.
select :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> Int -> m ()
select vec k = selectBy compare vec k
-- | Like 'select', but orders elements with a caller-supplied comparison.
selectBy :: (PrimMonad m, MVector v e) => Comparison e -> v (PrimState m) e -> Int -> m ()
selectBy cmp vec k = selectByBounds cmp vec k 0 $ length vec
-- | Moves the @k@ least elements of the slice @[l, u)@, by the given
-- comparison, into @[l, l + k)@.
--
-- The selected elements are left arranged as a heap with the largest of
-- them at index @l@. They are not in sorted order. Nothing happens when
-- @k <= 0@ or when @l + k > u@.
selectByBounds :: (PrimMonad m, MVector v e)
               => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m ()
selectByBounds cmp a k l u
  -- A non-positive k would make the scan below walk to indices before l.
  | k <= 0     = return ()
  | l + k <= u = heapify cmp a l m >> scan (u - 1)
  | otherwise  = return ()
  where
    -- One past the end of the heap holding the k smallest elements seen so far.
    m = l + k
    -- Walk the elements outside the heap from the back. Any element that
    -- is smaller than the heap's maximum takes the maximum's place in the
    -- heap, and the maximum is moved out to the scanned slot.
    scan i
      | i < m     = return ()
      | otherwise = do
          top <- unsafeRead a l
          ei  <- unsafeRead a i
          case cmp ei top of
            LT -> popTo cmp a l m i
            _  -> return ()
          scan (i - 1)
-- | Moves the @k@ least elements of the whole vector, as ordered by 'Ord',
-- to its front in ascending order.
partialSort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> Int -> m ()
partialSort vec k = partialSortBy compare vec k
-- | Like 'partialSort', but orders elements with a caller-supplied
-- comparison.
partialSortBy :: (PrimMonad m, MVector v e)
              => Comparison e -> v (PrimState m) e -> Int -> m ()
partialSortBy cmp vec k = partialSortByBounds cmp vec k 0 $ length vec
-- | Moves the @k@ least elements of the slice @[l, u)@, by the given
-- comparison, into @[l, l + k)@ in ascending order. The order of the rest
-- of the slice is unspecified.
--
-- The whole slice is simply sorted in two cases:
--
-- * the slice has at most four elements, or
-- * @k >= u - l@.
--
-- Otherwise, nothing happens when @k <= 0@.
partialSortByBounds :: (PrimMonad m, MVector v e)
                    => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m ()
partialSortByBounds cmp a k l u
  -- For at most four elements, sorting outright is cheap and trivially
  -- satisfies the contract.
  | len < 2    = return ()
  | len == 2   = O.sort2ByOffset cmp a l
  | len == 3   = O.sort3ByOffset cmp a l
  | len == 4   = O.sort4ByOffset cmp a l
  | u <= l + k = sortByBounds cmp a l u
  | k <= 0     = return ()
  -- sortHeap below swaps the heap root into slot l + 4, which is only part
  -- of the selected heap when k > 4. Smaller prefixes are sorted directly.
  | k <= 4     = do
      selectByBounds cmp a k l u
      sortByBounds cmp a l (l + k)
  | otherwise  = do
      selectByBounds cmp a k l u
      sortHeap cmp a l (l + 4) (l + k)
      O.sort4ByOffset cmp a l
  where
    len = u - l
-- | Rearranges the slice @[l, u)@ into a 4-ary max-heap, relative to the
-- given comparison, rooted at index @l@.
--
-- The children of heap position @i@ are stored at positions @4i + 1@
-- through @4i + 4@ (relative to @l@).
heapify :: (PrimMonad m, MVector v e)
        => Comparison e -> v (PrimState m) e -> Int -> Int -> m ()
heapify cmp a l u = loop $ (len - 1) `shiftR` 2
  where
    len = u - l
    -- Sift down positions from the last possible parent back to the root.
    -- The starting position may turn out to be a leaf; sifting a leaf just
    -- writes its value back, so that is harmless.
    loop k
      | k < 0     = return ()
      | otherwise = do
          e <- unsafeRead a (l + k)
          siftByOffset cmp a e l k len
          loop (k - 1)
-- | Given a heap stored in @[l, u]@ (inclusive), moves its root to index
-- @u@ and sifts the displaced element back in, so that @[l, u)@ holds the
-- remaining heap.
pop :: (PrimMonad m, MVector v e)
    => Comparison e -> v (PrimState m) e -> Int -> Int -> m ()
pop cmp vec lo hi = popTo cmp vec lo hi hi
-- | Moves the heap root at index @l@ to index @t@. The element previously
-- at @t@ is then sifted down from the root of the heap occupying @[l, u)@.
--
-- Callers in this module pass either @t == u@ (shrinking the heap by one
-- slot, see 'pop') or @t >= u@ (swapping an outside element into the heap,
-- see 'selectByBounds').
popTo :: (PrimMonad m, MVector v e)
      => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m ()
popTo cmp a l u t = do
  al <- unsafeRead a l
  at <- unsafeRead a t
  unsafeWrite a t al
  siftByOffset cmp a at l 0 (u - l)
-- | Given a heap occupying @[l, u)@, moves its largest elements into
-- @[m, u)@ in ascending order by repeatedly popping the root.
--
-- The elements left in @[l, m)@ are in no particular order. Nothing
-- happens when @m >= u@.
sortHeap :: (PrimMonad m, MVector v e)
         => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m ()
sortHeap cmp a l m u = do
    loop (u - 1)
    -- After the loop the heap occupies [l, m + 1) with its maximum at l.
    -- Swap that maximum into slot m only when m is actually inside the
    -- heap; otherwise the swap would touch an element outside it.
    when (m < u) $ unsafeSwap a l m
  where
    loop k
      | m < k     = pop cmp a l k >> loop (k - 1)
      | otherwise = return ()
-- | Sifts @val@ down a 4-ary heap of @len@ elements stored from offset
-- @off@, starting at heap position @start@.
--
-- While the largest child of the current hole compares greater than
-- @val@, that child is moved up into the hole and the hole descends to
-- the child's position. @val@ is finally written into the last hole.
siftByOffset :: (PrimMonad m, MVector v e)
             => Comparison e -> v (PrimState m) e -> e -> Int -> Int -> Int -> m ()
siftByOffset cmp a val off start len = descend start
  where
    descend hole
      | firstChild >= len = unsafeWrite a (off + hole) val
      | otherwise = do
          (bigIx, bigVal) <- maximumChild cmp a off firstChild len
          if cmp val bigVal == LT
            then unsafeWrite a (off + hole) bigVal >> descend bigIx
            else unsafeWrite a (off + hole) val
      where
        firstChild = 4 * hole + 1
-- | Finds the largest of the children that start at heap position
-- @child1@, and returns its heap position together with its value.
--
-- At most four children are examined (@child1@ through @child1 + 3@), and
-- only positions below @len@ are considered. Positions are relative to
-- @off@. Ties go to the earlier child. If no position is below @len@,
-- @child1@ itself is read and returned; the caller in this module only
-- calls it when @child1 < len@.
maximumChild :: (PrimMonad m, MVector v e)
             => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m (Int, e)
maximumChild cmp a off child1 len
  | child1 + 3 < len = do
      e1 <- unsafeRead a (off + child1)
      e2 <- unsafeRead a (off + child1 + 1)
      e3 <- unsafeRead a (off + child1 + 2)
      e4 <- unsafeRead a (off + child1 + 3)
      return (better (better (better (child1, e1) (child1 + 1, e2))
                             (child1 + 2, e3))
                     (child1 + 3, e4))
  | child1 + 2 < len = do
      e1 <- unsafeRead a (off + child1)
      e2 <- unsafeRead a (off + child1 + 1)
      e3 <- unsafeRead a (off + child1 + 2)
      return (better (better (child1, e1) (child1 + 1, e2)) (child1 + 2, e3))
  | child1 + 1 < len = do
      e1 <- unsafeRead a (off + child1)
      e2 <- unsafeRead a (off + child1 + 1)
      return (better (child1, e1) (child1 + 1, e2))
  | otherwise = do
      e1 <- unsafeRead a (off + child1)
      return (child1, e1)
  where
    -- Keep the current champion unless the challenger compares strictly
    -- greater, so that ties favour the lower position.
    better champ@(_, x) challenger@(_, y) =
      case cmp x y of
        LT -> challenger
        _  -> champ