{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Array.Accelerate.Data.Sort.Merge (
sort,
sortBy,
) where
import Data.Array.Accelerate
import Data.Array.Accelerate.Unsafe
-- | Sort a vector using the ordering given by the element type's 'Ord'
-- instance. This is 'sortBy' specialised to 'compare'.
sort :: Ord a => Acc (Vector a) -> Acc (Vector a)
sort xs = sortBy compare xs
-- | Sort a vector with a user-supplied comparison function.
--
-- The input is first sorted in independent segments of
-- 'insertion_segment_size' elements by 'insertion_sort'. After that, each
-- round computes a destination for every element with 'newIndex' (merging
-- neighbouring blocks pairwise), scatters the elements there, and doubles the
-- block size. Rounds repeat while the block size is smaller than the length
-- of the input.
sortBy :: Elt a => (Exp a -> Exp a -> Exp Ordering) -> Acc (Vector a) -> Acc (Vector a)
sortBy cmp input = asnd (awhile unfinished mergeRound start)
  where
    n = length input

    -- Loop state: (current sorted-block width, partially sorted values).
    start = T2 (unit insertion_segment_size) (insertion_sort cmp n input)

    -- Keep going until a single block spans the whole vector.
    unfinished (T2 width _) = map (< n) width

    -- One merge round: move every element to its place in the merged pair of
    -- blocks, then double the block width.
    mergeRound (T2 width sorted) = T2 (unit (w * 2)) merged
      where
        w       = the width
        targets = imap (newIndex sorted cmp n w) sorted
        -- Every slot is written by the scatter, so the default is never read.
        merged  = scatter targets (fill (I1 n) undef) sorted
-- | Number of elements per segment sorted by 'insertion_sort'. It is also the
-- block width that 'sortBy' starts its merge rounds with.
insertion_segment_size :: Exp Int
insertion_segment_size = constant 32
-- | Sort each consecutive segment of 'insertion_segment_size' elements
-- independently (the last segment may be shorter).
--
-- Every element is ranked within its own segment: its destination is the
-- segment start plus the number of segment members that must precede it.
-- A member precedes @x@ when it compares 'LT_' against @x@, or compares 'EQ_'
-- and sits at a smaller index. The index tie-break makes the ranks unique
-- (so the scatter never collides) and keeps equal elements in their input
-- order, matching the stable tie-breaking used by the merge rounds.
insertion_sort
  :: Elt a
  => (Exp a -> Exp a -> Exp Ordering)  -- ^ comparison function
  -> Exp Int                           -- ^ length of the input vector
  -> Acc (Vector a)                    -- ^ values to sort
  -> Acc (Vector a)
insertion_sort cmp n xs = scatter indices (fill (I1 n) undef) xs
  where
    indices = imap f xs

    f (I1 ix) x = segment_start + offset
      where
        segment       = ix `quot` insertion_segment_size
        segment_start = segment * insertion_segment_size
        -- Clamp so the final, possibly partial, segment stays in bounds.
        segment_end   = ((segment + 1) * insertion_segment_size) `min` n

        -- Walk the whole segment, counting the elements that go before @x@.
        T2 _ offset = while
          (\(T2 i _) -> i < segment_end)
          (\(T2 i c) ->
            let x'      = xs !! i
                smaller = let d = cmp x' x
                          -- Ties are won by the element that came first in
                          -- the input (i < ix), which keeps the sort stable.
                          in d == LT_ || (d == EQ_ && i < ix)
            in
            T2 (i + 1) (c + (smaller ? (1, 0))))
          (T2 segment_start 0)
-- | Compute where an element lands after its block is merged with the
-- neighbouring block.
--
-- Blocks are paired as (even, odd) = (left, right). An element keeps its
-- position relative to its own block and is shifted right by the number of
-- elements of the partner block that go before it. That number is found with
-- 'binarySearch': elements of a left block only yield to strictly smaller
-- partner elements, while elements of a right block also yield to equal ones,
-- so ties always go to the left block.
newIndex
  :: Elt a
  => Acc (Vector a)                    -- ^ values, sorted within each block
  -> (Exp a -> Exp a -> Exp Ordering)  -- ^ comparison function
  -> Exp Int                           -- ^ total number of values
  -> Exp Int                           -- ^ current block size
  -> Exp DIM1                          -- ^ index of the element
  -> Exp a                             -- ^ the element itself
  -> Exp Int
newIndex values cmp valueCount blockSize (I1 index) value = index + offset
  where
    blockIndex      = index `quot` blockSize
    left            = even blockIndex
    otherBlockIndex = blockIndex + (left ? (1, -1))
    -- Both bounds are clamped to the array length. A trailing left block may
    -- have no right partner at all; without clamping the lower bound the
    -- search range would be inverted and the count would come out negative,
    -- sending the element to a wrong (colliding) destination.
    searchMinIndex  = min valueCount $ otherBlockIndex * blockSize
    searchMaxIndex  = min valueCount $ (otherBlockIndex + 1) * blockSize
    countOtherBlock = binarySearch values cmp value (not left) searchMinIndex searchMaxIndex
    -- A right-block element moves back by one block width, because the merged
    -- pair starts at the beginning of the left block.
    offset          = countOtherBlock - (left ? (0, blockSize))
-- | Count the leading elements of the index range
-- @[initialMinIndex, initialMaxIndex)@ of @values@ that go before @query@.
--
-- An element goes before @query@ when it compares 'LT_' against it, or, if
-- @inclusive@ is set, when it compares 'EQ_'. The range is bisected, so the
-- result is only meaningful when the elements that go before @query@ form a
-- prefix of the range (as they do for a range sorted by @cmp@).
--
-- If the range is empty or inverted the loop body never runs and the result
-- is simply @initialMaxIndex - initialMinIndex@; callers should therefore
-- pass @initialMinIndex <= initialMaxIndex@.
binarySearch
  :: Elt a
  => Acc (Vector a)                    -- ^ values to search in
  -> (Exp a -> Exp a -> Exp Ordering)  -- ^ comparison function
  -> Exp a                             -- ^ element being positioned
  -> Exp Bool                          -- ^ whether equal elements also count
  -> Exp Int                           -- ^ first index of the range
  -> Exp Int                           -- ^ one past the last index of the range
  -> Exp Int
binarySearch values cmp query inclusive initialMinIndex initialMaxIndex =
  upper - initialMinIndex
  where
    -- Does this element have to be placed before the query?
    goesBefore e =
      let c = cmp e query
      in c == LT_ || (c == EQ_ && inclusive)

    -- Invariant: everything at or below @lo@ goes before the query, and
    -- nothing at or above @hi@ does.
    open (T2 lo hi) = lo + 1 < hi

    halve (T2 lo hi) =
      let mid = (lo + hi) `quot` 2
      in goesBefore (values !! mid) ? (T2 mid hi, T2 lo mid)

    T2 _ upper = while open halve (T2 (initialMinIndex - 1) initialMaxIndex)