{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE TypeApplications #-}
module Data.Array.Accelerate.Data.HashMap (
HashMap, Hashable,
fromVector,
size,
member,
lookup,
insert, insertWith, insertWithKey,
delete,
adjust, adjustWithKey,
map,
mapWithKey,
keys,
elems,
assocs,
) where
import Data.Array.Accelerate hiding ( size, map )
import Data.Array.Accelerate.Data.Functor
import Data.Array.Accelerate.Unsafe
import Data.Array.Accelerate.Data.Bits
import Data.Array.Accelerate.Data.Maybe
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate.Data.Hashable
import Data.Array.Accelerate.Data.Tree.Radix
import Data.Array.Accelerate.Data.Sort.Quick
import Data.Function
-- | A hash map held entirely in Accelerate arrays.
--
-- The first component is a binary radix tree over the keys' hashes; the
-- second holds the key/value pairs themselves. 'fromVector' stores the pairs
-- sorted by hash, and the tree's leaf pointers are indices into this pair
-- vector. The 'Arrays' instance lets a whole map be passed around as an
-- 'Acc' value.
data HashMap k v = HashMap (Vector Node) (Vector (k,v))
  deriving (Show, Generic, Arrays)
-- | Construct or take apart an embedded 'HashMap'. The pattern exposes its
-- radix tree and its key/value pairs, and works in both directions.
pattern HashMap_
  :: (Elt k, Elt v)
  => Acc (Vector Node)   -- ^ radix tree over the key hashes
  -> Acc (Vector (k,v))  -- ^ key/value pairs, indexed by the tree's leaves
  -> Acc (HashMap k v)
pattern HashMap_ t kv = Pattern (t,kv)
-- Tell GHC that matching on 'HashMap_' alone is exhaustive.
{-# COMPLETE HashMap_ #-}
-- | Number of key/value pairs stored in the map.
size :: (Elt k, Elt v) => Acc (HashMap k v) -> Exp Int
size = length . assocs
-- | Is the key present in the map?
--
-- Returns the 'isJust' test on 'lookup' directly. There is no reason to wrap
-- it in a conditional that re-emits 'True_' / 'False_', because that only
-- adds a redundant branch to the generated expression.
member :: (Eq k, Hashable k, Elt v) => Exp k -> Acc (HashMap k v) -> Exp Bool
member k m = isJust (lookup k m)
-- | Find the value stored for a key, or 'Nothing_' if the key is absent.
-- The entry's position returned by 'lookupWithIndex' is discarded.
lookup :: (Eq k, Hashable k, Elt v) => Exp k -> Acc (HashMap k v) -> Exp (Maybe v)
lookup k = fmap snd . lookupWithIndex k
-- | Find a key in the map. On success, return the entry's position in the
-- key/value vector together with its value; otherwise return 'Nothing_'.
--
-- Maps with fewer than two entries are answered by inspecting the pair
-- vector directly. For larger maps the radix tree is walked from node 0.
-- At each node the branch is chosen by one bit of the key's hash: bit
-- @bits - d - 1@ for the node's @d@ field, counted from the most
-- significant bit. A set bit goes right and a clear bit goes left. Once a
-- leaf pointer is reached, the key stored there is compared with the query.
--
-- A node whose @d@ is at least the hash width has no hash bit left to
-- branch on. The keys below it presumably share an identical hash, so that
-- subtree is searched exhaustively for an equal key.
lookupWithIndex :: (Eq k, Hashable k, Elt v) => Exp k -> Acc (HashMap k v) -> Exp (Maybe (Int, v))
lookupWithIndex key (HashMap_ tree kv) = result
  where
    h = hash key
    -- Length of the node vector. A loop index equal to @n@ means "finished".
    n = length tree
    bits = finiteBitSize (undef @Key)
    -- Child pointers carry a leaf flag in their top bit: 'index' clears the
    -- flag to recover the target index, and 'isLeaf' tests it.
    index (Ptr_ x) = clearBit x (bits - 1)
    isLeaf (Ptr_ x) = testBit x (bits - 1)
    -- Loop state: (current node index, answer so far). 'search' sets the
    -- index to @n@ as soon as the answer (possibly 'Nothing_') is known.
    result =
      if length kv < 2
        then if length kv == 0
               then Nothing_
               else let T2 k v = kv !! 0
                     in k == key ? (Just_ (T2 0 v), Nothing_)
        else
          snd $ while (\(T2 i _) -> i < n) search (T2 0 Nothing_)
    -- One descent step. @d@ is compared with the hash width and selects the
    -- branching bit; presumably it is the node's common-prefix length in
    -- bits. @p@ appears to be the node's parent, and here it bounds the
    -- exhaustive search so that it never climbs above this node.
    search (T2 i _) =
      let Node_ d l r p = tree !! i
          d' = fromIntegral d
       in if d' < bits
            then let m = testBit h (bits - d' - 1) ? (r, l)
                     j = index m
                  in if isLeaf m
                       then let T2 k v = kv !! j
                             in T2 n (k == key ? (Just_ (T2 j v), Nothing_))
                       else T2 j Nothing_
            else
              -- All hash bits consumed: walk the whole subtree rooted at
              -- @i@ and stop on a match or on climbing back out to @p@.
              let T3 _ _ x = while (\(T3 j _ c) -> isNothing c && j /= p)
                                   exhaust
                                   (T3 i (-1) Nothing_)
               in T2 n x
    -- One step of a stack-free subtree traversal. The state is
    -- (current node, index we arrived from, answer so far):
    --   * arrived from the left child  -> visit the right child;
    --   * arrived from the right child -> climb to @p@ (presumably the parent);
    --   * otherwise (arrived from above) -> visit the left child.
    -- A leaf child is checked against the key in place. The state then keeps
    -- the current node and records the leaf as where we came from.
    --
    -- NOTE(review): @prev@ may hold an internal-node index (after descending
    -- or climbing) or a leaf index (after checking a leaf). Likewise,
    -- @index l@ and @index r@ may be either kind. These are separate index
    -- spaces. If a leaf child's index numerically equals the node we
    -- descended from, this would be taken as "came from that child" and the
    -- leaf would be skipped. Verify the radix-tree layout rules this out.
    exhaust (T3 i prev _) =
      let Node_ _ l r p = tree !! i
          fromLeft = index l == prev
          fromRight = index r == prev
       in if fromLeft
            then
              let j = index r
               in if isLeaf r
                    then let T2 k v = kv !! j
                          in T3 i j (k == key ? (Just_ (T2 j v), Nothing_))
                    else T3 j i Nothing_
            else
              if fromRight
                then
                  T3 p i Nothing_
                else
                  let j = index l
                   in if isLeaf l
                        then let T2 k v = kv !! j
                              in T3 i j (k == key ? (Just_ (T2 j v), Nothing_))
                        else T3 j i Nothing_
-- | Insert a batch of key/value pairs. For a key already in the map, the
-- incoming value replaces the stored one.
insert :: (Eq k, Hashable k, Elt v)
       => Acc (Vector (k,v))
       -> Acc (HashMap k v)
       -> Acc (HashMap k v)
insert = insertWithKey (\_ new _ -> new)
-- | Insert a batch of key/value pairs. For a key already in the map, the
-- stored value becomes @f new old@.
insertWith
  :: (Eq k, Hashable k, Elt v)
  => (Exp v -> Exp v -> Exp v)
  -> Acc (Vector (k,v))
  -> Acc (HashMap k v)
  -> Acc (HashMap k v)
insertWith f = insertWithKey (\_ new old -> f new old)
-- | Insert a batch of key/value pairs, merging with existing entries.
--
-- For a key already in the map, the stored value becomes @f key new old@
-- and is written back in place. Keys not yet in the map are appended, which
-- rebuilds the whole map (radix tree included) with 'fromVector'. When every
-- incoming key is already present, the existing tree is reused unchanged.
--
-- NOTE(review): duplicate keys inside the incoming vector are not merged
-- with each other. Absent duplicates are each appended. Present duplicates
-- all target the same slot through @permute const@, and the winning write
-- there is presumably not tied to input order. Confirm callers accept this.
insertWithKey
  :: (Eq k, Hashable k, Elt v)
  => (Exp k -> Exp v -> Exp v -> Exp v)
  -> Acc (Vector (k,v))
  -> Acc (HashMap k v)
  -> Acc (HashMap k v)
insertWithKey f kv hm@(HashMap_ tree kv0) =
  if the fresh == 0
    then HashMap_ tree updated
    else fromVector (updated ++ A.map snd added)
  where
    -- For each incoming pair, produce one of two results:
    --   * key present: the existing entry's slot paired with the merged pair;
    --   * key absent:  a negative slot paired with an unused placeholder.
    probe (T2 k v) =
      let hit = lookupWithIndex k hm
       in if isNothing hit
            then T2 (-1) undef
            else
              let T2 i u = fromJust hit
               in T2 i (T2 k (f k v u))
    (slots, merged) = unzip (A.map probe kv)

    -- The incoming pairs whose keys are absent, and how many there are.
    T2 added fresh = filter (\(T2 i _) -> i < 0) (zip slots kv)

    -- Existing pairs with merged values written over the matching slots.
    -- When nothing matched, skip the permute entirely.
    updated =
      if the fresh == length kv
        then kv0
        else permute const kv0 destination merged
    destination ix =
      let i = slots ! ix
       in i < 0 ? (Nothing_, Just_ (I1 i))
-- | Remove every listed key that is present in the map. Keys not in the map
-- are ignored.
--
-- The surviving pairs keep their relative order, and the radix tree is
-- rebuilt from their hashes. If none of the keys were present, the original
-- map is returned untouched.
delete :: (Eq k, Hashable k, Elt v)
       => Acc (Vector k)
       -> Acc (HashMap k v)
       -> Acc (HashMap k v)
delete ks hm =
  if the found == 0
    then hm
    else HashMap_ (binary_radix_tree hashes) survivors
  where
    -- Slots (in the pair vector) of the keys that actually occur.
    T2 hits found = justs (A.map (\k -> fmap fst (lookupWithIndex k hm)) ks)
    -- Blank out the hit slots, then compact what remains.
    marked = scatter hits (A.map Just_ (assocs hm)) (fill (shape hits) Nothing_)
    T2 survivors _ = justs marked
    hashes = A.map (bitcast . hash . fst) survivors
-- | Apply a function to the values stored for the listed keys. Keys not in
-- the map are ignored.
adjust :: (Eq k, Hashable k, Elt v)
       => (Exp v -> Exp v)
       -> Acc (Vector k)
       -> Acc (HashMap k v)
       -> Acc (HashMap k v)
adjust f = adjustWithKey (\_ v -> f v)
-- | Apply a key-aware function to the values stored for the listed keys.
-- Keys not in the map are ignored.
--
-- Only the matching pairs are overwritten in place. The radix tree and the
-- entry order are reused as-is. If no key matched, the original map is
-- returned.
adjustWithKey
  :: (Eq k, Hashable k, Elt v)
  => (Exp k -> Exp v -> Exp v)
  -> Acc (Vector k)
  -> Acc (HashMap k v)
  -> Acc (HashMap k v)
adjustWithKey f ks hm@(HashMap_ tree kvs) =
  if the hitCount == 0
    then hm
    else HashMap_ tree (scatter slots kvs replacements)
  where
    -- For a present key: its slot and the updated pair. Otherwise Nothing_.
    locate k = fmap (\(T2 i v) -> T2 i (T2 k (f k v))) (lookupWithIndex k hm)
    T2 hits hitCount = justs (A.map locate ks)
    (slots, replacements) = unzip hits
-- | Transform every value in the map. Keys and the radix tree are unchanged.
map :: (Elt k, Elt v1, Elt v2) => (Exp v1 -> Exp v2) -> Acc (HashMap k v1) -> Acc (HashMap k v2)
map f = mapWithKey (\_ v -> f v)
-- | Transform every value in the map, giving the function access to the key.
-- Keys and the radix tree are carried over unchanged.
mapWithKey :: (Elt k, Elt v1, Elt v2) => (Exp k -> Exp v1 -> Exp v2) -> Acc (HashMap k v1) -> Acc (HashMap k v2)
mapWithKey f (HashMap_ tree pairs) = HashMap_ tree (A.map step pairs)
  where
    step (T2 k v) = T2 k (f k v)
-- | All keys in the map, in storage order.
keys :: (Elt k, Elt v) => Acc (HashMap k v) -> Acc (Vector k)
keys = A.map fst . assocs
-- | All values in the map, in storage order.
elems :: (Elt k, Elt v) => Acc (HashMap k v) -> Acc (Vector v)
elems = A.map snd . assocs
-- | The underlying key/value pairs, in storage order.
assocs :: (Elt k, Elt v) => Acc (HashMap k v) -> Acc (Vector (k,v))
assocs hm = case hm of
  HashMap_ _ pairs -> pairs
-- | Build a map from a vector of key/value pairs.
--
-- Each pair is tagged with its key's hash and its original position. The
-- tags are sorted by hash, the pairs are gathered into that sorted order,
-- and a binary radix tree is built over the sorted hashes.
--
-- NOTE(review): keys are not de-duplicated here, so every occurrence of a
-- repeated key ends up in the map. Which of them a lookup finds is not
-- determined by this function.
fromVector :: (Hashable k, Elt v) => Acc (Vector (k,v)) -> Acc (HashMap k v)
fromVector pairs = HashMap_ (binary_radix_tree sortedHashes) (gather order pairs)
  where
    tagged = imap (\(I1 i) (T2 k _) -> T2 (bitcast (hash k)) i) pairs
    (sortedHashes, order) = unzip (sortBy (compare `on` fst) tagged)