{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
module Data.HashTable.ST.Linear
( HashTable
, new
, newSized
, delete
, lookup
, insert
, mutate
, mutateST
, mapM_
, foldM
, computeOverhead
) where
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative
import Data.Word
#endif
import Control.Monad hiding (foldM, mapM_)
import Control.Monad.ST
import Data.Bits
import Data.Hashable
import Data.STRef
import Prelude hiding (lookup, mapM_)
import qualified Data.HashTable.Class as C
import Data.HashTable.Internal.Array
import Data.HashTable.Internal.Linear.Bucket (Bucket)
import qualified Data.HashTable.Internal.Linear.Bucket as Bucket
import Data.HashTable.Internal.Utils
#ifdef DEBUG
import System.IO
#endif
-- | A mutable hash table based on linear hashing, living in @ST s@.
-- The table state is kept behind an @STRef@ so that operations which grow
-- the table (a bucket split) can store a new state record back into it.
newtype HashTable s k v = HT (STRef s (HashTable_ s k v))

-- | The table geometry (level and split pointer) plus the bucket array.
data HashTable_ s k v = HashTable
    { _level    :: {-# UNPACK #-} !Int
      -- ^ Current level. The bucket array is traversed as @power2 _level@
      -- slots. A key is addressed with the low @_level - 1@ hash bits,
      -- or with the low @_level@ bits when that first choice is below the
      -- split pointer (see @hashKey@).
    , _splitptr :: {-# UNPACK #-} !Int
      -- ^ Index of the next bucket that @split@ will split.
    , _buckets  :: {-# UNPACK #-} !(MutableArray s (Bucket s k v))
      -- ^ The bucket array.
    }
-- | Connect the generic hash-table class to this linear-hashing
-- implementation. Every method delegates to the top-level function of the
-- same name defined in this module.
instance C.HashTable HashTable where
    -- construction
    new             = new
    newSized        = newSized
    -- point updates
    insert          = insert
    delete          = delete
    mutate          = mutate
    mutateST        = mutateST
    -- queries
    lookup          = lookup
    lookupIndex     = lookupIndex
    nextByIndex     = nextByIndex
    -- traversal and diagnostics
    foldM           = foldM
    mapM_           = mapM_
    computeOverhead = computeOverhead
-- | A table is an opaque mutable reference, so it shows as a fixed
-- placeholder and not as its contents.
instance Show (HashTable s k v) where
    show = const "<HashTable>"
-- | Create an empty table with the smallest geometry: level 1, split
-- pointer 0, and a two-slot bucket array.
new :: ST s (HashTable s k v)
new = Bucket.newBucketArray 2 >>= newRef . HashTable 1 0
-- | Create an empty table presized for roughly @n@ elements.
--
-- The expected number of full buckets,
-- @ceiling (n * fillFactor / bucketSplitSize)@, is passed to @log2@ to
-- choose the level (never below 1). The bucket array then gets
-- @power2 lvl@ slots and the split pointer starts at 0.
--
-- A negative hint is clamped to 0. Otherwise a negative count would be
-- fed into @log2@ and could produce a nonsensical level or array size.
newSized :: Int -> ST s (HashTable s k v)
newSized n = do
    v <- Bucket.newBucketArray sz
    newRef $ HashTable lvl 0 v
  where
    -- A size hint below zero is meaningless; treat it like an empty hint.
    hint = max 0 n
    k = ceiling (fromIntegral hint * fillFactor / fromIntegral bucketSplitSize)
    lvl = max 1 (fromEnum $ log2 k)
    sz = power2 lvl
-- | Remove the mapping for key @k@ from the bucket it hashes to (via
-- @delete'@).
--
-- The table header (level, split pointer) is never changed by a delete,
-- so unlike @insert@ nothing is written back to the reference.
delete :: (Hashable k, Eq k) =>
          (HashTable s k v)
       -> k
       -> ST s ()
delete htRef !k = readRef htRef >>= work
  where
    work (HashTable lvl splitptr buckets) = do
        let !h0 = hashKey lvl splitptr k
        -- Separate each field with ", name=" to match the split log format.
        debug $ "delete: size=" ++ show (power2 lvl) ++ ", h0=" ++ show h0
                  ++ ", splitptr=" ++ show splitptr
        delete' buckets h0 k
{-# INLINE delete #-}
-- | Find the value stored for @k@, if any, by consulting only the bucket
-- that @hashKey@ selects for the current level and split pointer.
lookup :: (Eq k, Hashable k) => (HashTable s k v) -> k -> ST s (Maybe v)
lookup htRef !k = do
    table <- readRef htRef
    case table of
      HashTable lvl splitptr buckets -> do
          bucket <- readArray buckets (hashKey lvl splitptr k)
          Bucket.lookup bucket k
{-# INLINE lookup #-}
-- | Insert the mapping @k -> v@.
--
-- Any existing entry for @k@ is first removed from its bucket (@delete'@),
-- then the new pair is appended (@primitiveInsert'@). If the resulting
-- bucket size exceeds the split threshold (@checkOverflow@), the bucket at
-- the split pointer is split. That bucket is not necessarily the one just
-- written. The (possibly updated) header is always written back.
insert :: (Eq k, Hashable k) =>
          (HashTable s k v)
       -> k
       -> v
       -> ST s ()
insert htRef k v = readRef htRef >>= insertInto >>= writeRef htRef
  where
    insertInto table@(HashTable lvl splitptr buckets) = do
        let !slot = hashKey lvl splitptr k
        delete' buckets slot k
        occupancy <- primitiveInsert' buckets slot k v
        if not (checkOverflow occupancy)
          then do
            debug "insert: done"
            return table
          else do
            debug "insert: splitting"
            grown <- split table
            debug "insert: done splitting"
            return grown
{-# INLINE insert #-}
-- | Pure variant of @mutateST@. The update function @f@ receives the
-- current value (if any) and returns the new value (or @Nothing@) together
-- with an extra result that is handed back to the caller.
mutate :: (Eq k, Hashable k) =>
          (HashTable s k v)
       -> k
       -> (Maybe v -> (Maybe v, a))
       -> ST s a
mutate htRef k f = mutateST htRef k (\old -> return (f old))
{-# INLINE mutate #-}
-- | Monadic read-modify-write of the entry for @k@.
--
-- The bucket-level @Bucket.mutateST@ runs @f@ and reports the new bucket
-- size, an optional replacement bucket (stored back into the array when
-- present), and the user result. If the bucket size exceeds the split
-- threshold, the table is split. The header is then written back and the
-- user result returned.
--
-- NOTE(review): the header is read before @f@ runs and written back after
-- it, so @f@ should not itself modify this same table.
mutateST :: (Eq k, Hashable k) =>
            (HashTable s k v)
         -> k
         -> (Maybe v -> ST s (Maybe v, a))
         -> ST s a
mutateST htRef k f = do
    (table', result) <- readRef htRef >>= step
    writeRef htRef table'
    return result
  where
    step table@(HashTable lvl splitptr buckets) = do
        let !slot = hashKey lvl splitptr k
        bucket <- readArray buckets slot
        (!newSize, replacement, result) <- Bucket.mutateST bucket k f
        case replacement of
          Just bucket' -> writeArray buckets slot bucket'
          Nothing      -> return ()
        table' <- if checkOverflow newSize
                    then split table
                    else return table
        return (table', result)
-- | Run @f@ on every key-value pair, visiting bucket slots in index order
-- from 0 up to @power2 lvl - 1@.
mapM_ :: ((k,v) -> ST s b) -> HashTable s k v -> ST s ()
mapM_ f htRef = readRef htRef >>= visit
  where
    visit (HashTable lvl _ buckets) = loop 0
      where
        !nBuckets = power2 lvl
        loop !i
            | i < nBuckets = do
                readArray buckets i >>= Bucket.mapM_ f
                loop (i + 1)
            | otherwise = return ()
-- | Strict left fold over every key-value pair, visiting bucket slots in
-- index order. The accumulator is forced to WHNF after each bucket.
foldM :: (a -> (k,v) -> ST s a)
      -> a -> HashTable s k v
      -> ST s a
foldM f seed0 htRef = readRef htRef >>= run
  where
    run (HashTable lvl _ buckets) = loop seed0 0
      where
        !nBuckets = power2 lvl
        loop !acc !i
            | i < nBuckets = do
                bucket <- readArray buckets i
                !acc' <- Bucket.foldM f acc bucket
                loop acc' (i + 1)
            | otherwise = return acc
-- | Estimate storage overhead in machine words per stored element.
--
-- The result is (bucket-array slots + a fixed 5 words + the per-bucket
-- overhead reported by @Bucket.nelemsAndOverheadInWords@) divided by the
-- total element count. For an empty table the divisor is zero. This is
-- floating-point division, so no exception is raised.
computeOverhead :: HashTable s k v -> ST s Double
computeOverhead htRef = readRef htRef >>= measure
  where
    measure (HashTable lvl _ buckets) = do
        (elemCount, wordCount) <- tally 0 0 0
        let total = fromIntegral nBuckets + fixedWords + fromIntegral wordCount
        return $ total / fromIntegral elemCount
      where
        fixedWords = 5.0
        !nBuckets = power2 lvl
        tally !accN !accO !i
            | i >= nBuckets = return (accN, accO)
            | otherwise = do
                bucket <- readArray buckets i
                (!n, !o) <- Bucket.nelemsAndOverheadInWords bucket
                tally (n + accN) (o + accO) (i + 1)
-- | Remove @k@ from bucket number @h0@ of the given array and discard the
-- result reported by the bucket-level delete. The array slot itself is
-- not rewritten.
delete' :: Eq k =>
           MutableArray s (Bucket s k v)
        -> Int
        -> k
        -> ST s ()
delete' buckets h0 k =
    readArray buckets h0 >>= \bucket -> void (Bucket.delete bucket k)
-- | Perform one linear-hashing split step and return the new header.
--
-- The bucket at the split pointer is emptied and its entries are
-- re-inserted using the updated header, so each entry lands either back
-- in bucket @splitptr@ or in bucket @splitptr + 2^(lvl-1)@. When the split
-- pointer reaches @2^(lvl-1)@, the bucket array is expanded, the level
-- goes up by one and the split pointer resets to 0.
split :: (Hashable k) =>
         (HashTable_ s k v)
      -> ST s (HashTable_ s k v)
split ht@(HashTable lvl splitptr buckets) = do
    debug $ "split: start: nbuck=" ++ show (power2 lvl)
              ++ ", splitptr=" ++ show splitptr
    -- Take the bucket being split. Replace it with a fresh empty bucket
    -- sized for about 0.625 of the old element count, but at least
    -- Bucket.newBucketSize.
    oldBucket <- readArray buckets splitptr
    nelems <- Bucket.size oldBucket
    let !bsz = max Bucket.newBucketSize $
                   ceiling $ (0.625 :: Double) * fromIntegral nelems
    dbucket1 <- Bucket.emptyWithSize bsz
    -- Written into the current array before any expansion; presumably
    -- expandBucketArray copies existing slots -- TODO confirm.
    writeArray buckets splitptr dbucket1
    let lvl2 = power2 lvl
    let lvl1 = power2 $ lvl-1
    -- End of a round: ask for an array of 2 * power2 lvl slots, bump the
    -- level and reset the split pointer. Otherwise just advance the
    -- split pointer.
    (!buckets',!lvl',!sp') <-
        if splitptr+1 >= lvl1
          then do
            debug $ "split: resizing bucket array"
            let lvl3 = 2*lvl2
            b <- Bucket.expandBucketArray lvl3 lvl2 buckets
            debug $ "split: resizing bucket array: done"
            return (b,lvl+1,0)
          else return (buckets,lvl,splitptr+1)
    let ht' = HashTable lvl' sp' buckets'
    -- The partner bucket of the split is splitptr + 2^(lvl-1). Pre-grow
    -- it by bsz so the rehashed entries have room.
    let splitOffs = splitptr + lvl1
    db2 <- readArray buckets' splitOffs
    db2sz <- Bucket.size db2
    let db2sz' = db2sz + bsz
    db2' <- Bucket.growBucketTo db2sz' db2
    debug $ "growing bucket at " ++ show splitOffs ++ " to size "
              ++ show db2sz'
    writeArray buckets' splitOffs db2'
    debug $ "split: rehashing bucket"
    -- Re-insert every entry of the old bucket using the new header.
    -- forceSameType presumably only unifies the types of the two
    -- inserters (over ht and ht') -- TODO confirm in Internal.Utils.
    let f = uncurry $ primitiveInsert ht'
    forceSameType f (uncurry $ primitiveInsert ht)
    Bucket.mapM_ f oldBucket
    debug $ "split: done"
    return ht'
-- | True when a bucket of the given size has grown past the split
-- threshold and a split should be triggered.
checkOverflow :: Int -> Bool
checkOverflow = (> bucketSplitSize)
-- | Append @(k, v)@ to the bucket that the given header assigns to @k@,
-- without removing old entries or checking for overflow. Returns the size
-- value reported by @primitiveInsert'@.
primitiveInsert :: (Hashable k) =>
                   (HashTable_ s k v)
                -> k
                -> v
                -> ST s Int
primitiveInsert (HashTable lvl splitptr buckets) k v = do
    debug $ "primitiveInsert start: nbuckets=" ++ show (power2 lvl)
    primitiveInsert' buckets (hashKey lvl splitptr k) k v
-- | Snoc @(k, v)@ onto bucket number @h0@ of the array (strict in the slot,
-- key and value). If @Bucket.snoc@ hands back a replacement bucket, it is
-- stored in the slot. The size value reported by the snoc is returned.
primitiveInsert' :: MutableArray s (Bucket s k v)
                 -> Int
                 -> k
                 -> v
                 -> ST s Int
primitiveInsert' buckets !h0 !k !v = do
    debug $ "primitiveInsert': bucket number=" ++ show h0
    bucket <- readArray buckets h0
    debug $ "primitiveInsert': snoccing bucket"
    (!newSize, replacement) <- Bucket.snoc bucket k v
    debug $ "primitiveInsert': bucket snoc'd"
    case replacement of
      Just bucket' -> writeArray buckets h0 bucket'
      Nothing      -> return ()
    return newSize
-- | Multiplier applied to the requested element count in @newSized@ when
-- computing how many buckets to allocate up front.
fillFactor :: Double
fillFactor = 1.3

-- | Local alias for @Bucket.bucketSplitSize@. A bucket whose size exceeds
-- it triggers a split (@checkOverflow@), and @newSized@ divides by it when
-- sizing the initial array.
bucketSplitSize :: Int
bucketSplitSize = Bucket.bucketSplitSize
{-# INLINE power2 #-}
-- | @power2 i@ is 1 shifted left by @i@ bits, that is @2^i@ for in-range
-- shift amounts.
power2 :: Int -> Int
power2 i = iShiftL 1 i
{-# INLINE hashKey #-}
-- | Bucket index for key @k@ under level @lvl@ and split pointer
-- @splitptr@.
--
-- First take the low @lvl - 1@ hash bits. Buckets below the split pointer
-- have already been split in this round, so those keys are re-addressed
-- with the low @lvl@ bits.
hashKey :: (Hashable k) => Int -> Int -> k -> Int
hashKey !lvl !splitptr !k
    | lowSlot < splitptr = hashAtLvl lvl k
    | otherwise          = lowSlot
  where
    !lowSlot = hashAtLvl (lvl-1) k
{-# INLINE hashAtLvl #-}
-- | The hash of @k@ masked down to its low @lvl@ bits.
hashAtLvl :: (Hashable k) => Int -> k -> Int
hashAtLvl !lvl !k = hash k .&. (power2 lvl - 1)
-- | Allocate a fresh reference holding the given table state.
newRef :: HashTable_ s k v -> ST s (HashTable s k v)
newRef ht = fmap HT (newSTRef ht)

-- | Overwrite the table state stored behind the reference.
writeRef :: HashTable s k v -> HashTable_ s k v -> ST s ()
writeRef (HT ref) = writeSTRef ref

-- | Read the current table state from the reference.
readRef :: HashTable s k v -> ST s (HashTable_ s k v)
readRef table = case table of HT ref -> readSTRef ref
{-# INLINE debug #-}
-- | Trace hook whose behavior is selected at compile time by CPP flags.
debug :: String -> ST s ()
#ifdef DEBUG
-- DEBUG builds: print the message to stdout and flush immediately.
debug s = unsafeIOToST $ do
    putStrLn s
    hFlush stdout
#else
#ifdef TESTSUITE
-- TESTSUITE builds: force the spine of the message (via length) without
-- printing it, presumably so the message-building code is still exercised
-- by tests.
debug !s = do
    let !_ = length s
    return $! ()
#else
-- Normal builds: ignore the message without evaluating it.
debug _ = return ()
#endif
#endif
-- | Locate @k@ and return its position as a packed index (bucket number in
-- the high bits, in-bucket position in the low bits; see @encodeIndex@).
-- The encoding depends on the current level, so the index is only
-- meaningful for the table geometry at the time of the call.
lookupIndex :: (Eq k, Hashable k) => HashTable s k v -> k -> ST s (Maybe Word)
lookupIndex htRef !k = readRef htRef >>= locate
  where
    locate (HashTable lvl splitptr buckets) = do
        let slot = hashKey lvl splitptr k
        bucket <- readArray buckets slot
        found <- Bucket.lookupIndex bucket k
        case found of
          Nothing  -> return Nothing
          Just pos -> return $! Just $! encodeIndex lvl slot pos
{-# INLINE lookupIndex #-}
-- | Pack a bucket number and an in-bucket position into one @Word@. The
-- bucket number is shifted up by @indexOffset lvl@ bits and the position
-- occupies the bits below it. This assumes the bucket number is below
-- @2^lvl@ and the position fits under the offset.
encodeIndex :: Int -> Int -> Int -> Word
encodeIndex lvl bucketIx elemIx =
    (fromIntegral bucketIx `Data.Bits.shiftL` indexOffset lvl)
      .|. fromIntegral elemIx
{-# INLINE encodeIndex #-}

-- | Inverse of @encodeIndex@ for the same level: returns
-- @(bucket number, in-bucket position)@.
decodeIndex :: Int -> Word -> (Int, Int)
decodeIndex lvl ix = (fromIntegral bucketPart, fromIntegral elemPart)
  where
    offset     = indexOffset lvl
    bucketPart = ix `Data.Bits.shiftR` offset
    elemPart   = ix .&. (bit offset - 1)
{-# INLINE decodeIndex #-}

-- | Number of low bits reserved for the in-bucket position: the machine
-- word width minus the level.
indexOffset :: Int -> Int
indexOffset = (finiteBitSize (0 :: Word) -)
{-# INLINE indexOffset #-}
-- | Return the first entry at or after the packed index @k@, together with
-- that entry's own packed index.
--
-- The index is decoded for the current level. If the in-bucket position
-- holds no element, the scan moves to position 0 of the next bucket.
-- @Nothing@ is returned once the bucket number falls outside
-- @[0, power2 lvl)@.
nextByIndex :: HashTable s k v -> Word -> ST s (Maybe (Word,k,v))
nextByIndex htRef !k = readRef htRef >>= scan
  where
    scan (HashTable lvl _ buckets) = uncurry walk (decodeIndex lvl k)
      where
        nBuckets = power2 lvl
        walk h ix
            | h < 0 || h >= nBuckets = return Nothing
            | otherwise = do
                bucket <- readArray buckets h
                found <- Bucket.elemAt bucket ix
                case found of
                  Nothing -> walk (h + 1) 0
                  Just (key, val) ->
                      let !packed = encodeIndex lvl h ix
                      in return (Just (packed, key, val))
{-# INLINE nextByIndex #-}