module Data.Vector.Algorithms.Intro
(
sort
, sortBy
, sortByBounds
, select
, selectBy
, selectByBounds
, partialSort
, partialSortBy
, partialSortByBounds
, Comparison
) where
import Prelude hiding (read, length)
import Control.Monad
import Control.Monad.Primitive
import Data.Bits
import Data.Vector.Generic.Mutable
import Data.Vector.Algorithms.Common (Comparison)
import qualified Data.Vector.Algorithms.Insertion as I
import qualified Data.Vector.Algorithms.Optimal as O
import qualified Data.Vector.Algorithms.Heap as H
-- | Sorts an entire mutable vector in place, ordering elements with the
-- 'compare' method of their 'Ord' instance.
sort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> m ()
sort vec = sortBy compare vec
-- | Sorts an entire mutable vector in place using a caller-supplied
-- 'Comparison'; this is 'sortByBounds' over the full index range.
sortBy :: (PrimMonad m, MVector v e) => Comparison e -> v (PrimState m) e -> m ()
sortBy cmp a = sortByBounds cmp a 0 end
  where
    end = length a
-- | Sorts the half-open index range @[l, u)@ of a mutable vector in place.
--
-- Ranges of length 2, 3 and 4 are delegated to the fixed-size sorts in
-- "Data.Vector.Algorithms.Optimal". Longer ranges go to 'introsort' with a
-- depth budget of @'ilg' (u - l)@. Ranges of length 0 or 1, and inverted
-- ranges (@u < l@), are left untouched.
sortByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e        -- ^ element ordering
  -> v (PrimState m) e   -- ^ vector sorted in place
  -> Int                 -- ^ lower bound @l@ (inclusive)
  -> Int                 -- ^ upper bound @u@ (exclusive)
  -> m ()
sortByBounds cmp a l u
  | len < 2   = return ()
  | len == 2  = O.sort2ByOffset cmp a l
  | len == 3  = O.sort3ByOffset cmp a l
  | len == 4  = O.sort4ByOffset cmp a l
  | otherwise = introsort cmp a (ilg len) l u
  where
    len = u - l
-- | Depth-limited quicksort over the half-open range @[l, u)@, followed by one
-- insertion-sort pass ('I.sortByBounds') over the whole range.
--
-- Sub-ranges shorter than 'threshold' are not partitioned further; the
-- closing insertion-sort pass finishes them. Once the depth budget reaches
-- zero, the current sub-range is handed to 'H.sortByBounds' instead of being
-- partitioned again, which bounds the recursion depth.
introsort
  :: (PrimMonad m, MVector v e)
  => Comparison e        -- ^ element ordering
  -> v (PrimState m) e   -- ^ vector sorted in place
  -> Int                 -- ^ depth budget ('sortByBounds' passes @'ilg' len@)
  -> Int                 -- ^ lower bound @l@ (inclusive)
  -> Int                 -- ^ upper bound @u@ (exclusive)
  -> m ()
introsort cmp a i l u = go i l u >> I.sortByBounds cmp a l u
  where
    go 0 lo hi = H.sortByBounds cmp a lo hi
    go d lo hi
      | hi - lo < threshold = return ()
      | otherwise = do
          -- Pivot selection from three samples (lo, middle, hi - 1). This
          -- presumably leaves their median at index lo; TODO confirm against
          -- Optimal.sort3ByIndex.
          O.sort3ByIndex cmp a c lo (hi - 1)
          p <- unsafeRead a lo
          mid <- partitionBy cmp a p (lo + 1) hi
          -- Move the pivot between the two partitions (its final slot).
          unsafeSwap a lo (mid - 1)
          go (d - 1) mid hi
          go (d - 1) lo (mid - 1)
      where
        c = (hi + lo) `div` 2
-- | Moves the @k@ least elements (by 'Ord') to the front of the vector, in no
-- particular order; see 'selectByBounds'.
select :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> Int -> m ()
select vec k = selectBy compare vec k
-- | Moves the @k@ least elements (under @cmp@) to the front of the vector,
-- in no particular order; this is 'selectByBounds' over the full index range.
selectBy
  :: (PrimMonad m, MVector v e)
  => Comparison e -> v (PrimState m) e -> Int -> m ()
selectBy cmp a k = selectByBounds cmp a k 0 end
  where
    end = length a
-- | Moves the @k@ smallest elements of the half-open range @[l, u)@ (under
-- @cmp@) into @[l, l + k)@, in no particular order.
--
-- Uses quickselect: repeatedly partitions around a sampled pivot and recurses
-- into the side containing position @l + k@. After @'ilg' (u - l)@ rounds it
-- falls back to 'H.selectByBounds'.
--
-- A @k@ outside @(0, u - l)@ is a no-op. For @k <= 0@ nothing has to move,
-- and for @k >= u - l@ every element is already among the @k@ smallest.
-- Guarding these cases keeps the loop from ever calling 'unsafeRead' on an
-- empty range, i.e. at index @u@.
selectByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e        -- ^ element ordering
  -> v (PrimState m) e   -- ^ vector rearranged in place
  -> Int                 -- ^ @k@, number of least elements to gather
  -> Int                 -- ^ lower bound @l@ (inclusive)
  -> Int                 -- ^ upper bound @u@ (exclusive)
  -> m ()
selectByBounds cmp a k l u
  | l >= u || k <= 0 || k >= len = return ()
  | otherwise = go (ilg len) l (l + k) u
  where
    len = u - l
    -- Invariant: lo < m < hi, so every range handed to go is non-empty.
    go 0 lo m hi = H.selectByBounds cmp a (m - lo) lo hi
    go n lo m hi = do
      O.sort3ByIndex cmp a c lo (hi - 1)
      p <- unsafeRead a lo
      mid <- partitionBy cmp a p (lo + 1) hi
      -- Afterwards: [lo, mid - 1) <= p, pivot at mid - 1, [mid, hi) >= p.
      unsafeSwap a lo (mid - 1)
      if m > mid
        then go (n - 1) mid m hi
        else if m < mid - 1
          then go (n - 1) lo m (mid - 1)
          else return () -- m already sits on a partition boundary
      where
        c = (hi + lo) `div` 2
-- | Sorts the @k@ least elements (by 'Ord') into the first @k@ positions,
-- leaving the remainder of the vector in unspecified order.
partialSort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> Int -> m ()
partialSort vec k = partialSortBy compare vec k
-- | Sorts the @k@ least elements (under @cmp@) into the first @k@ positions;
-- this is 'partialSortByBounds' over the full index range.
partialSortBy
  :: (PrimMonad m, MVector v e)
  => Comparison e -> v (PrimState m) e -> Int -> m ()
partialSortBy cmp a k = partialSortByBounds cmp a k 0 end
  where
    end = length a
-- | Sorts the @k@ smallest elements of the half-open range @[l, u)@ (under
-- @cmp@) into @[l, l + k)@ in ascending order. The elements left in
-- @[l + k, u)@ are in unspecified order.
--
-- Behaviour at the edges:
--
-- * @k <= 0@: a no-op.
-- * @k >= u - l@: sorts the whole range via 'sortByBounds'. Without this
--   clamp, the loop would read past @u@ with 'unsafeRead'.
--
-- The partitioning loop has a depth budget of @'ilg' (u - l)@, which is
-- decremented on every partitioning round. When it runs out, the current
-- sub-range @[lo, hi)@ is finished by 'H.partialSortByBounds'.
partialSortByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e        -- ^ element ordering
  -> v (PrimState m) e   -- ^ vector rearranged in place
  -> Int                 -- ^ @k@, number of least elements to sort
  -> Int                 -- ^ lower bound @l@ (inclusive)
  -> Int                 -- ^ upper bound @u@ (exclusive)
  -> m ()
partialSortByBounds cmp a k l u
  | l >= u || k <= 0 = return ()
  | k >= len         = sortByBounds cmp a l u
  | otherwise        = go (ilg len) l (l + k) u
  where
    isort = introsort cmp a
    len = u - l
    -- Invariant: lo <= m <= hi. The fallback uses the *current* sub-range
    -- bound hi, not the outer u.
    go 0 lo m hi = H.partialSortByBounds cmp a (m - lo) lo hi
    go n lo m hi
      | lo == m = return ()
      | otherwise = do
          O.sort3ByIndex cmp a c lo (hi - 1)
          p <- unsafeRead a lo
          mid <- partitionBy cmp a p (lo + 1) hi
          -- Afterwards: [lo, mid - 1) <= p, pivot at mid - 1, [mid, hi) >= p.
          unsafeSwap a lo (mid - 1)
          case compare m mid of
            -- Whole left side is wanted: sort it, then continue on the right.
            GT -> do
              isort (n - 1) lo (mid - 1)
              go (n - 1) mid m hi
            -- The boundary lands exactly after the pivot: sort [lo, m) and stop.
            EQ -> isort (n - 1) lo m
            -- Only part of the left side is wanted; recurse with a smaller
            -- budget so the depth bound holds.
            LT -> go (n - 1) lo m (mid - 1)
      where
        c = (hi + lo) `div` 2
-- | Partitions the half-open range @[l, u)@ of the vector around pivot @p@,
-- scanning inward from both ends and swapping out-of-place pairs.
--
-- Returns an index @mid@ with @l <= mid <= u@ such that every element in
-- @[l, mid)@ compares @<= p@ and every element in @[mid, u)@ compares @>= p@.
--
-- NOTE: the local signatures mention the outer @m@ and @e@, so this relies on
-- ScopedTypeVariables being enabled for the module.
partitionBy :: forall m v e. (PrimMonad m, MVector v e)
            => Comparison e -> v (PrimState m) e -> e -> Int -> Int -> m Int
partitionBy cmp a = partUp
  where
    -- Advance the lower cursor past elements strictly below the pivot.
    -- Here @u@ is an exclusive bound.
    partUp :: e -> Int -> Int -> m Int
    partUp p l u
      | l < u = do
          e <- unsafeRead a l
          case cmp e p of
            LT -> partUp p (l + 1) u
            _  -> partDown p l (u - 1)
      | otherwise = return l

    -- Retreat the upper cursor past elements strictly above the pivot.
    -- Here @u@ is an inclusive index. On reaching an element <= p, swap it
    -- with the >= p element at @l@ and resume the upward scan.
    partDown :: e -> Int -> Int -> m Int
    partDown p l u
      | l < u = do
          e <- unsafeRead a u
          case cmp p e of
            LT -> partDown p l (u - 1)
            _  -> unsafeSwap a l u >> partUp p (l + 1) u
      | otherwise = return l
-- | Depth budget for 'introsort': twice the floor of the base-2 logarithm of
-- the argument (@ilg 1 == 0@, @ilg 8 == 6@).
--
-- Non-positive arguments yield @-2@. The loop stops at @n <= 0@ explicitly
-- because an arithmetic 'shiftR' of a negative 'Int' never reaches zero.
ilg :: Int -> Int
ilg m = 2 * loop m 0
  where
    -- k counts the halvings so far; ($!) keeps the accumulator evaluated.
    loop n k
      | n <= 0    = k - 1
      | otherwise = loop (n `shiftR` 1) $! k + 1
-- | Sub-range length below which 'introsort' stops partitioning. Shorter
-- sub-ranges are left for its closing insertion-sort pass.
threshold :: Int
threshold = 18