{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Array.Accelerate.Data.Sort.Quick (
sort,
sortBy,
) where
import Data.Array.Accelerate
import Data.Array.Accelerate.Unsafe
import Data.Array.Accelerate.Data.Bits
import Data.Array.Accelerate.Data.Maybe
-- | Sort a vector into ascending order according to the element type's
-- 'Ord' instance. This is @'sortBy' 'compare'@.
sort :: Ord a => Acc (Vector a) -> Acc (Vector a)
sort xs = sortBy compare xs
-- | Sort a vector with a user-supplied comparison function.
--
-- The whole input starts out as a single segment. Each iteration of
-- 'awhile' runs one partitioning pass ('step') over every segment. Elements
-- comparing 'LT' to the segment's pivot move in front of it. The loop stops
-- once every head flag is set, i.e. every segment holds one element.
sortBy :: Elt a => (Exp a -> Exp a -> Exp Ordering) -> Acc (Vector a) -> Acc (Vector a)
sortBy cmp input = sorted
  where
    n = length input
    -- One head flag per element plus a sentinel slot at index n. The
    -- sentinel keeps the "just after the last pivot" flag write in bounds.
    noFlags = fill (I1 (1 + n)) False_
    -- Flag index 0 (start of the single initial segment) and the sentinel.
    boundaryIndices = fill (I1 1) 0 ++ fill (I1 1) n
    initialFlags = scatter boundaryIndices noFlags (fill (I1 2) True_)
    T2 sorted _ = awhile condition (step cmp) $ T2 input initialFlags
-- | Loop state threaded through 'awhile': the partially sorted values and
-- their segment head flags. As built by 'sortBy', the flag vector has one
-- more entry than the value vector (a trailing sentinel).
type State a = (Vector a, Vector Bool)
-- | One partitioning pass of the segmented quicksort.
--
-- A segment is a run of elements that starts at a set head flag. Each segment
-- is partitioned in place around its first element (the pivot):
--
-- * elements with @cmp v pivot == LT@ move to the front of the segment;
-- * all other elements, the pivot included, follow them.
--
-- Both sides keep their original relative order. New head flags are then set
-- at the pivot's final position and just after it, so the pivot becomes a
-- one-element segment.
--
-- NOTE(review): assumes @cmp@ is a consistent total order, in particular
-- @cmp p p /= LT@, so that the pivot lands first on the right-hand side.
-- NOTE(review): assumes the flag vector has one more entry than @values@, as
-- set up by 'sortBy'.
step :: Elt a => (Exp a -> Exp a -> Exp Ordering) -> Acc (State a) -> Acc (State a)
step cmp (T2 values headFlags) = (T2 values' headFlags')
  where
    -- Per element, the pivot of its segment: the segment's first element.
    pivots = propagateSegmentHead headFlags values

    -- True for elements that go to the right-hand side (not less than the
    -- pivot). This includes the pivot itself and elements equal to it.
    isLarger = zipWith (\v p -> cmp v p /= LT_) values pivots

    -- Per element, the index at which its segment starts.
    startIndex = propagateSegmentHead headFlags (generate (shape values) unindex1)

    -- Zero-based rank of each element among the larger (resp. smaller)
    -- elements of its segment: a segmented inclusive count, minus one.
    -- Only the entry matching the element's own side is used below.
    indicesLarger, indicesSmaller :: Acc (Vector Int)
    indicesLarger = map (\x -> x - 1) $ postscanSegHead (+) headFlags $ map (? (1, 0)) isLarger
    indicesSmaller = map (\x -> x - 1) $ postscanSegHead (+) headFlags $ map (? (0, 1)) isLarger

    -- Per element, how many smaller elements its segment contains. This is
    -- the smaller-rank held by the segment's last element, plus one; it is 0
    -- when the segment has no smaller elements. It is the offset of the
    -- right-hand side within the segment.
    countSmaller :: Acc (Vector Int)
    countSmaller = map (+1) $ propagateSegmentLast headFlags indicesSmaller

    -- Destination index of every element (see 'partitionPermuteIndex').
    permutation = zipWith5 partitionPermuteIndex isLarger startIndex indicesSmaller indicesLarger countSmaller

    -- Move the elements. Each segment's elements are mapped onto that same
    -- segment's index range, so every undef slot of the default is overwritten.
    values' = scatter permutation (fill (shape values) undef) values

    -- New segment boundaries. For every current segment head, mark:
    -- * the pivot's final position: start + countSmaller
    -- * the position just after it: start + countSmaller + 1
    headFlags' =
      let
          -- Only elements whose head flag is set produce a write.
          f :: Int -> Exp Bool -> Exp Int -> Exp Int -> Exp (Maybe DIM1)
          f inc headF start countSmall =
            if headF
               then Just_ (I1 (start + countSmall + constant inc))
               else Nothing_

          -- zipWith3 truncates to the shortest input, so the trailing
          -- sentinel flag in headFlags is not visited here.
          writes :: Int -> Acc (Vector (Maybe DIM1))
          writes inc = zipWith3 (f inc) headFlags startIndex countSmaller
      in
      -- For the last segment, (writes 1) can target index (length values).
      -- That index is in bounds only because the flag vector carries one
      -- extra sentinel slot.
      writeFlags (writes 0) $ writeFlags (writes 1) $ headFlags
-- | Loop guard for 'awhile'. Partitioning continues while at least one head
-- flag is still unset, i.e. while some segment holds more than one element.
condition :: Elt a => Acc (State a) -> Acc (Scalar Bool)
condition (T2 _ flags) = map not (and flags)
-- | Destination index of one element during a partitioning pass.
--
-- * Smaller elements go to @start + indexIfSmaller@.
-- * The remaining elements go after all @countSmaller@ smaller elements of
--   the segment, to @start + countSmaller + indexIfLarger@.
partitionPermuteIndex :: Exp Bool -> Exp Int -> Exp Int -> Exp Int -> Exp Int -> Exp Int
partitionPermuteIndex isLarger start indexIfSmaller indexIfLarger countSmaller =
  let offsetInSegment =
        if isLarger
          then countSmaller + indexIfLarger
          else indexIfSmaller
  in  start + offsetInSegment
-- | Broadcast, to every element, the value held by the first element (head)
-- of its segment.
--
-- A segment begins wherever the corresponding flag is set. The scan seed is
-- 'undef', so any elements before the first set flag receive 'undef'.
propagateSegmentHead
  :: Elt a
  => Acc (Vector Bool)
  -> Acc (Vector a)
  -> Acc (Vector a)
propagateSegmentHead headFlags values =
  map fst (postscanl keepHead (T2 undef True_) (zip values headFlags))
  where
    -- A head element replaces the running value; otherwise it is kept.
    keepHead acc (T2 x isHead) = isHead ? (T2 x True_, acc)
-- | Broadcast, to every element, the value held by the last element of its
-- segment.
--
-- Element @i@ is the last of its segment when flag @i + 1@ is set. The flags
-- are therefore shifted by one with 'tail' and scanned from the right.
-- Because 'zip' truncates, the flag vector is expected to have one more entry
-- than @values@.
propagateSegmentLast
  :: Elt a
  => Acc (Vector Bool)
  -> Acc (Vector a)
  -> Acc (Vector a)
propagateSegmentLast headFlags values =
  map fst (postscanr takeLast (T2 undef True_) (zip values (tail headFlags)))
  where
    -- A segment's last element replaces the running value; otherwise it is kept.
    takeLast (T2 x isLast) acc = isLast ? (T2 x True_, acc)
-- | Segmented inclusive left scan: like 'postscanl' with @f@, but the scan
-- restarts at every element whose head flag is set.
--
-- Each scan element pairs a value with a flag recording whether a segment
-- head occurs in the combined range. The flags are OR-ed, which keeps the
-- combining operator associative whenever @f@ is. The seed value is 'undef',
-- so the first element's flag is expected to be set.
postscanSegHead
  :: Elt a
  => (Exp a -> Exp a -> Exp a)
  -> Acc (Vector Bool)
  -> Acc (Vector a)
  -> Acc (Vector a)
postscanSegHead f headFlags values =
  map fst (postscanl combine (T2 undef True_) (zip values headFlags))
  where
    combine (T2 accValue accFlag) (T2 value isHead) =
      let newValue = if isHead then value else f accValue value
      in  T2 newValue (accFlag .|. isHead)
-- | Set to 'True' every flag whose index appears in @writes@.
--
-- Each entry of @writes@ is either 'Nothing_' (no write) or the index of a
-- flag to set. All other flags keep their previous value. Several entries may
-- target the same index; this is harmless because every write stores 'True'.
writeFlags
  :: Acc (Vector (Maybe DIM1))
  -> Acc (Vector Bool)
  -> Acc (Vector Bool)
writeFlags writes flags = permute const flags destination trues
  where
    destination ix = writes ! ix
    trues = fill (shape writes) True_