{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE TypeApplications #-}
module Data.Array.Accelerate.Data.Tree.Radix
where
import Data.Array.Accelerate
import Data.Array.Accelerate.Unsafe
import Data.Array.Accelerate.Data.Bits
import Data.Array.Accelerate.Data.Maybe
import qualified Data.Bits as P
import qualified Prelude as P
-- | One node of the tree produced by 'binary_radix_tree'.
--
-- The field meanings below follow from how 'binary_radix_tree' fills them in
-- (@zipWith4 Node_ deltas lefts rights parents@).
data Node = Node !Word8   -- ^ common-prefix length (in bits) of the range covered by this node
                 !Ptr     -- ^ left child
                 !Ptr     -- ^ right child
                 !Int     -- ^ index of the parent node (left undefined for a node that no other node points to)
  deriving (Show, Generic, Elt)
-- | Embedded-language pattern for constructing and deconstructing a 'Node'
-- from its four components: prefix length, left child, right child and
-- parent index.
pattern Node_ :: Exp Word8 -> Exp Ptr -> Exp Ptr -> Exp Int -> Exp Node
pattern Node_ prefix lchild rchild parent = Pattern (prefix, lchild, rchild, parent)
{-# COMPLETE Node_ #-}
-- | A tagged child reference stored in a 'Node'.
--
-- The wrapped 'Int' is an index. Bit @finiteBitSize (undefined :: Key) - 1@
-- is used as a tag: when set the reference denotes a leaf, otherwise an inner
-- node (see the 'Show' instance and the @leaf@ / @inner@ helpers in
-- 'binary_radix_tree').
newtype Ptr = Ptr Int
  deriving (Generic, Elt)
-- | Renders a 'Ptr' as @Leaf k@ or @Inner k@ depending on its tag bit; for a
-- leaf the tag bit is cleared before the index is shown.
instance Show Ptr where
  showsPrec prec (Ptr raw) = P.showParen (prec P.> 10) body
    where
      -- Position of the bit that distinguishes leaves from inner nodes.
      tagBit = P.finiteBitSize (undefined :: Key) - 1

      body
        | P.testBit raw tagBit = P.showString "Leaf " . P.showsPrec 11 (P.clearBit raw tagBit)
        | P.otherwise          = P.showString "Inner " . P.showsPrec 11 raw
-- | Embedded-language pattern wrapping / unwrapping the raw 'Int' carried by
-- a 'Ptr' (including its tag bit).
pattern Ptr_ :: Exp Int -> Exp Ptr
pattern Ptr_ raw = Pattern raw
{-# COMPLETE Ptr_ #-}
-- | The key type over which 'binary_radix_tree' is built. Its bit width also
-- determines which bit of a 'Ptr' is used as the leaf tag.
type Key = Word
-- | Build the inner nodes of a binary radix tree over a vector of keys.
--
-- For @n@ keys the result contains @max 0 (n-1)@ nodes. Each 'Node' records
-- the common-prefix length of the key range it covers, its left and right
-- children (tagged as leaf or inner via 'Ptr'), and the index of its parent.
-- The parent field of a node that is not referenced as an inner child by any
-- other node is left undefined.
--
-- The per-node computation follows the shape of the construction described
-- by Karras ("Maximizing Parallelism in the Construction of BVHs, Octrees,
-- and k-d Trees"): every node is computed independently from its own index.
--
-- NOTE(review): the prefix-based range search appears to assume that @keys@
-- is sorted in ascending order -- confirm against callers.
binary_radix_tree :: Acc (Vector Key) -> Acc (Vector Node)
binary_radix_tree keys = zipWith4 Node_ deltas lefts rights parents
  where
    n     = length keys
    bits  = finiteBitSize (undef @Key)

    -- Number of inner nodes. Clamped at zero so that an empty key vector
    -- yields empty arrays rather than an array extent of -1.
    inner_count = max 0 (n-1)

    -- Length (in bits) of the common prefix of the keys at positions i and j.
    -- Equal keys are disambiguated by additionally comparing the indices
    -- themselves. A j outside the key vector yields -1.
    delta i j =
      if j >= 0 && j < n
         then
           let li = keys !! i
               lj = keys !! j
            in if li == lj
                  then bits + countLeadingZeros (i `xor` j)
                  else countLeadingZeros (li `xor` lj)
         else -1

    -- Compute everything for inner node i: its prefix length, both children,
    -- and, for each child that is itself an inner node, the index whose
    -- parent entry must be set to i (or -1 if the child is a leaf).
    node i =
      let
          -- Direction (+1 or -1) in which the range of this node extends.
          d         = signum $ delta i (i+1) - delta i (i-1)

          -- Prefix length shared with the neighbour on the other side; every
          -- key inside the range must share strictly more than this.
          delta_min = delta i (i-d)

          -- Upper bound on the range length: start at 128 and quadruple
          -- until the probed position no longer belongs to the range (or is
          -- out of bounds, where delta is -1).
          l_max     = while (\l_max' -> delta i (i+l_max'*d) > delta_min)
                            (*4)
                            128

          -- Binary search for the exact range length l, trying step sizes
          -- l_max/2, l_max/4, ..., 1.
          T2 l _    = while (\(T2 _ t) -> t > 0)
                            (\(T2 l' t) ->
                              let t2 = t `quot` 2 in
                              if delta i (i+(l'+t) * d) > delta_min
                                 then T2 (l' + t) t2
                                 else T2 l'       t2)
                            (T2 0 (l_max `quot` 2))

          -- Other end of the range and the prefix length of the whole range.
          j          = i + l*d
          delta_node = delta i j

          -- Binary search for the split offset s. q doubles every step and
          -- the step size is t = ceiling (l / (2*q)).
          T2 s _    = while (\(T2 _ q) -> q <= l)
                            (\(T2 s' q) ->
                              let r = q*2
                                  t = (l + r - 1) `quot` r
                               in if delta i (i+(s'+t)*d) > delta_node
                                     then T2 (s'+t) r
                                     else T2 s'     r)
                            (T2 0 1)

          -- Split position: the left child covers up to gamma, the right
          -- child starts at gamma+1.
          gamma     = i + s*d + min d 0

          -- A child that covers a single key is a leaf; otherwise it is the
          -- inner node with the same index, whose parent is this node.
          T2 left left_parent =
            if min i j == gamma
               then T2 (leaf gamma)  (-1)
               else T2 (inner gamma) gamma

          T2 right right_parent =
            if max i j == gamma + 1
               then T2 (leaf (gamma+1))  (-1)
               else T2 (inner (gamma+1)) (gamma+1)

          -- Leaves are tagged by setting the top bit of a Key-width word;
          -- inner references are stored untagged.
          leaf x    = Ptr_ (setBit x (bits-1))
          inner x   = Ptr_ x
       in
          T5 (fromIntegral delta_node :: Exp Word8)
             left
             right
             left_parent
             right_parent

    (deltas, lefts, rights, left_parents, right_parents)
      = unzip5
      $ generate (I1 inner_count) (node . unindex1)

    -- Scatter parent indices: entry ix of 'dest' names the inner node whose
    -- parent is 'from ! ix'. The first half of both arrays corresponds to
    -- left children, the second half to right children. Negative
    -- destinations (leaf children) are skipped.
    parents
      = let from = generate (I1 (inner_count*2)) (\(I1 i) -> i < inner_count ? (i, i-inner_count))
            dest = left_parents ++ right_parents
         in permute const
                    (fill (I1 inner_count) undef)
                    (\ix -> let d = dest ! ix in
                            if d < 0 then Nothing_ else Just_ (I1 d))
                    from