{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Vector.Algorithms.Intro
(
sort
, sortBy
, sortByBounds
, select
, selectBy
, selectByBounds
, partialSort
, partialSortBy
, partialSortByBounds
, Comparison
) where
import Prelude hiding (read, length)
import Control.Monad
import Control.Monad.Primitive
import Data.Bits
import Data.Vector.Generic.Mutable
import Data.Vector.Algorithms.Common (Comparison, midPoint)
import qualified Data.Vector.Algorithms.Insertion as I
import qualified Data.Vector.Algorithms.Optimal as O
import qualified Data.Vector.Algorithms.Heap as H
sort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> m ()
sort = sortBy compare
{-# INLINABLE sort #-}
sortBy :: (PrimMonad m, MVector v e) => Comparison e -> v (PrimState m) e -> m ()
sortBy cmp a = sortByBounds cmp a 0 (length a)
{-# INLINE sortBy #-}
sortByBounds
:: (PrimMonad m, MVector v e)
=> Comparison e
-> v (PrimState m) e
-> Int
-> Int
-> m ()
sortByBounds cmp a l u
| len < 2 = return ()
| len == 2 = O.sort2ByOffset cmp a l
| len == 3 = O.sort3ByOffset cmp a l
| len == 4 = O.sort4ByOffset cmp a l
| otherwise = introsort cmp a (ilg len) l u
where len = u - l
{-# INLINE sortByBounds #-}
introsort :: (PrimMonad m, MVector v e)
=> Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m ()
introsort cmp a i l u = sort i l u >> I.sortByBounds cmp a l u
where
sort 0 l u = H.sortByBounds cmp a l u
sort d l u
| len < threshold = return ()
| otherwise = do O.sort3ByIndex cmp a c l (u-1)
p <- unsafeRead a l
mid <- partitionBy cmp a p (l+1) u
unsafeSwap a l (mid - 1)
sort (d-1) mid u
sort (d-1) l (mid - 1)
where
len = u - l
c = midPoint u l
{-# INLINE introsort #-}
select
:: (PrimMonad m, MVector v e, Ord e)
=> v (PrimState m) e
-> Int
-> m ()
select = selectBy compare
{-# INLINE select #-}
selectBy
:: (PrimMonad m, MVector v e)
=> Comparison e
-> v (PrimState m) e
-> Int
-> m ()
selectBy cmp a k = selectByBounds cmp a k 0 (length a)
{-# INLINE selectBy #-}
selectByBounds
:: (PrimMonad m, MVector v e)
=> Comparison e
-> v (PrimState m) e
-> Int
-> Int
-> Int
-> m ()
selectByBounds cmp a k l u
| l >= u = return ()
| otherwise = go (ilg len) l (l + k) u
where
len = u - l
go 0 l m u = H.selectByBounds cmp a (m - l) l u
go n l m u = do O.sort3ByIndex cmp a c l (u-1)
p <- unsafeRead a l
mid <- partitionBy cmp a p (l+1) u
unsafeSwap a l (mid - 1)
if m > mid
then go (n-1) mid m u
else if m < mid - 1
then go (n-1) l m (mid - 1)
else return ()
where c = midPoint u l
{-# INLINE selectByBounds #-}
partialSort
:: (PrimMonad m, MVector v e, Ord e)
=> v (PrimState m) e
-> Int
-> m ()
partialSort = partialSortBy compare
{-# INLINE partialSort #-}
partialSortBy
:: (PrimMonad m, MVector v e)
=> Comparison e
-> v (PrimState m) e
-> Int
-> m ()
partialSortBy cmp a k = partialSortByBounds cmp a k 0 (length a)
{-# INLINE partialSortBy #-}
partialSortByBounds
:: (PrimMonad m, MVector v e)
=> Comparison e
-> v (PrimState m) e
-> Int
-> Int
-> Int
-> m ()
partialSortByBounds cmp a k l u
| l >= u = return ()
| otherwise = go (ilg len) l (l + k) u
where
isort = introsort cmp a
{-# INLINE [1] isort #-}
len = u - l
go 0 l m n = H.partialSortByBounds cmp a (m - l) l u
go n l m u
| l == m = return ()
| otherwise = do O.sort3ByIndex cmp a c l (u-1)
p <- unsafeRead a l
mid <- partitionBy cmp a p (l+1) u
unsafeSwap a l (mid - 1)
case compare m mid of
GT -> do isort (n-1) l (mid - 1)
go (n-1) mid m u
EQ -> isort (n-1) l m
LT -> go n l m (mid - 1)
where c = midPoint u l
{-# INLINE partialSortByBounds #-}
partitionBy :: forall m v e. (PrimMonad m, MVector v e)
=> Comparison e -> v (PrimState m) e -> e -> Int -> Int -> m Int
partitionBy cmp a = partUp
where
partUp :: e -> Int -> Int -> m Int
partUp p l u
| l < u = do e <- unsafeRead a l
case cmp e p of
LT -> partUp p (l+1) u
_ -> partDown p l (u-1)
| otherwise = return l
partDown :: e -> Int -> Int -> m Int
partDown p l u
| l < u = do e <- unsafeRead a u
case cmp p e of
LT -> partDown p l (u-1)
_ -> unsafeSwap a l u >> partUp p (l+1) u
| otherwise = return l
{-# INLINE partitionBy #-}
ilg :: Int -> Int
ilg m = 2 * loop m 0
where
loop 0 !k = k - 1
loop n !k = loop (n `shiftR` 1) (k+1)
threshold :: Int
threshold = 18
{-# INLINE threshold #-}