{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Vector.Algorithms.Merge
( sort
, sortBy
, Comparison
) where
import Prelude hiding (read, length)
import Control.Monad.Primitive
import Data.Bits
import Data.Vector.Generic.Mutable
import Data.Vector.Algorithms.Common (Comparison, copyOffset, midPoint)
import qualified Data.Vector.Algorithms.Optimal as O
import qualified Data.Vector.Algorithms.Insertion as I
-- | Sort a mutable vector in place, ordering elements with the 'Ord'
-- instance's 'compare'. This is 'sortBy' specialised to 'compare'.
{-# INLINABLE sort #-}
sort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> m ()
sort vec = sortBy compare vec
-- | Sort a mutable vector in place using the supplied comparison.
--
-- The strategy is chosen from the vector's length:
--
-- * fewer than 2 elements: nothing to do;
-- * 2, 3 or 4 elements: the fixed-size sorters from
--   "Data.Vector.Algorithms.Optimal", applied at offset 0;
-- * fewer than 'threshold' elements: insertion sort over the whole vector;
-- * otherwise: a scratch buffer is allocated and 'mergeSortWithBuf' runs.
{-# INLINE sortBy #-}
sortBy :: (PrimMonad m, MVector v e) => Comparison e -> v (PrimState m) e -> m ()
sortBy cmp vec
  | n < 2         = return ()
  | n == 2        = O.sort2ByOffset cmp vec 0
  | n == 3        = O.sort3ByOffset cmp vec 0
  | n == 4        = O.sort4ByOffset cmp vec 0
  | n < threshold = I.sortByBounds cmp vec 0 n
  | otherwise     = new scratchLen >>= mergeSortWithBuf cmp vec
  where
    n = length vec
    -- Half the length, rounded up, so the larger half of an odd-length
    -- vector still fits in the scratch buffer.
    scratchLen = (n + 1) `div` 2
-- | Merge sort @src@ in place, using @buf@ as scratch space for merging.
--
-- The index range @[lo, hi)@ is split recursively at 'midPoint'. A range
-- shorter than 'threshold' is sorted directly with insertion sort; a longer
-- one has both halves sorted and is then combined by 'merge'.
--
-- @buf@ must be able to hold the lower half (@mid - lo@ elements) of every
-- range that gets merged.
{-# INLINE mergeSortWithBuf #-}
mergeSortWithBuf :: (PrimMonad m, MVector v e)
                 => Comparison e -> v (PrimState m) e -> v (PrimState m) e -> m ()
mergeSortWithBuf cmp src buf = go 0 (length src)
  where
    go lo hi
      | size >= threshold = do
          go lo mid
          go mid hi
          -- 'merge' works on a slice, so the split point is passed
          -- relative to the start of that slice.
          merge cmp (unsafeSlice lo size src) buf (mid - lo)
      | otherwise = I.sortByBounds cmp src lo hi
      where
        size = hi - lo
        mid  = midPoint hi lo
-- | Merge the runs @[0, mid)@ and @[mid, length src)@ of @src@ in place.
--
-- The lower run is first copied into the first @mid@ slots of @buf@, so the
-- front of @src@ can be overwritten while merging. Elements are then written
-- back to @src@ one at a time. When the comparison does not say the upper
-- element is strictly smaller ('LT'), the lower element is written first.
--
-- Preconditions (unchecked, all accesses are @unsafe*@): both runs are
-- non-empty, since element 0 of each is read unconditionally, and @buf@ has
-- at least @mid@ elements.
{-# INLINE merge #-}
merge :: (PrimMonad m, MVector v e)
      => Comparison e -> v (PrimState m) e -> v (PrimState m) e
      -> Int -> m ()
merge cmp src buf mid = do
  unsafeCopy scratch front
  firstLow  <- unsafeRead scratch 0
  firstHigh <- unsafeRead back 0
  step scratch 0 firstLow back 0 firstHigh 0
  where
    front   = unsafeSlice 0 mid src
    back    = unsafeSlice mid (length src - mid) src
    scratch = unsafeSlice 0 mid buf

    -- Compare the current heads of both runs and write the winner to @src@
    -- at position @out@.
    step !lows !i !x !highs !j !y !out =
      case cmp y x of
        LT -> do
          unsafeWrite src out y
          afterHigh lows i x highs (j + 1) (out + 1)
        _ -> do
          unsafeWrite src out x
          afterLow lows (i + 1) highs j y (out + 1)

    -- An element of the upper run was just written. Fetch the next one, or,
    -- if the upper run is exhausted, copy the rest of the lower run back.
    afterHigh lows i x highs j out
      | j < length highs = do
          y <- unsafeRead highs j
          step lows i x highs j y out
      | otherwise =
          unsafeCopy (unsafeSlice out pending src) (unsafeSlice i pending lows)
      where
        pending = length lows - i

    -- An element of the lower run was just written. Fetch the next one, or
    -- stop if the lower run is exhausted: whatever is left of the upper run
    -- already sits in its final position in @src@.
    afterLow lows i highs j y out
      | i < length lows = do
          x <- unsafeRead lows i
          step lows i x highs j y out
      | otherwise = return ()
-- | Cut-off length for the merge sort: ranges with fewer elements than this
-- are sorted by insertion sort instead of being split and merged.
{-# INLINE threshold #-}
threshold :: Int
threshold = 25