{-# LANGUAGE CPP, BangPatterns, ForeignFunctionInterface #-}
module Data.Vector.Compact.WordVec
(
WordVec(..)
, Shape(..)
, vecShape , vecShape'
, vecLen , vecBits , vecIsSmall
, showWordVec , showsPrecWordVec
, null , empty
, singleton , isSingleton
, fromList , fromListN , fromList'
, toList , toRevList
, unsafeIndex , safeIndex
, head , tail , cons , uncons
, last , snoc
, concat
, sum , maximum
, eqStrict , eqExtZero
, cmpStrict , cmpExtZero
, lessOrEqual , partialSumsLessOrEqual
, add , subtract
, scale
, partialSums
, fold
, naiveMap , boundedMap
, naiveZipWith , boundedZipWith , listZipWith
, bitsNeededFor , bitsNeededFor'
, roundBits
)
where
import Prelude hiding ( head , tail , init , last , null , concat , subtract , sum , maximum )
import qualified Data.List as L
import Data.Bits
import Data.Word
import Foreign.C
import Data.Vector.Compact.Blob hiding ( head , tail , last )
import qualified Data.Vector.Compact.Blob as Blob
-- Width in bits of the native machine word, chosen from the host
-- architecture macros: 64 on x86_64 and aarch64, 32 on i386 / i686.
-- NOTE(review): every architecture that is not listed falls into the final
-- branch and gets 32, including other 64-bit hosts; confirm this is intended.
#ifdef x86_64_HOST_ARCH
#define MACHINE_WORD_BITS 64
#elif i386_HOST_ARCH
#define MACHINE_WORD_BITS 32
#elif i686_HOST_ARCH
#define MACHINE_WORD_BITS 32
#elif aarch64_HOST_ARCH
#define MACHINE_WORD_BITS 64
#else
#define MACHINE_WORD_BITS 32
#endif
-- | A vector of 'Word' values stored inside a single 'Blob'.
--
-- The low bits of the first blob word form a header: bit 0 selects between
-- the \"small\" and the \"big\" layout, and the following bits hold the encoded
-- element width and the length (decoded by 'vecShape'').
newtype WordVec = WordVec Blob
-- | The shape of a 'WordVec': its length and the width of its elements.
data Shape = Shape
  { shapeLen  :: !Int   -- ^ number of elements
  , shapeBits :: !Int   -- ^ number of bits used to store each element
  }
  deriving (Eq, Show)
-- | The shape (length and element width) of a vector; this is the second
-- component of 'vecShape''.
vecShape :: WordVec -> Shape
vecShape vec = shape
  where
    (_, shape) = vecShape' vec
-- | Decode the header word of a vector.
--
-- Returns @(isSmall, shape)@. Bit 0 of the header is the layout tag
-- (0 = small, 1 = big).
--
-- * small layout: bits 1-2 hold the encoded width, bits 3-7 the length
-- * big layout: bits 1-4 hold the encoded width, bits 5-31 the length
--
-- The element width in bits is @4 * (encoded + 1)@.
vecShape' :: WordVec -> (Bool, Shape)
vecShape' (WordVec blob) = (small, shape)
  where
    !header  = Blob.head blob
    !encoded = header `shiftR` 1          -- header with the tag bit dropped
    !small   = not (testBit header 0)
    shape
      | small     = mkShape ((header `shiftR` 3) .&. 31)
                            (((encoded .&. 3) + 1) `shiftL` 2)
      | otherwise = mkShape ((header `shiftR` 5) .&. 0x07ffffff)
                            (((encoded .&. 15) + 1) `shiftL` 2)
-- | Build a 'Shape' from a length and an element width that were read out of
-- a header word as 'Word64' values.
mkShape :: Word64 -> Word64 -> Shape
mkShape !len !bits = Shape (fromIntegral len) (fromIntegral bits)
-- | 'True' when the vector uses the \"small\" layout, i.e. when bit 0 of its
-- header word is clear.
vecIsSmall :: WordVec -> Bool
vecIsSmall (WordVec !blob) = not (testBit (Blob.head blob) 0)
-- | Number of elements in the vector.
vecLen :: WordVec -> Int
vecLen vec = shapeLen (vecShape vec)
-- | Number of bits used to store each element of the vector.
vecBits :: WordVec -> Int
vecBits vec = shapeBits (vecShape vec)
-- | Vectors are shown as a @fromList'@ expression, see 'showsPrecWordVec'.
instance Show WordVec where
  showsPrec prec vec = showsPrecWordVec prec vec
-- | Render a vector as a string, using precedence 0 (no outer parentheses).
showWordVec :: WordVec -> String
showWordVec vec = showsPrecWordVec 0 vec ""
-- | Precedence-aware rendering of a vector as
-- @fromList' \<shape\> \<elements\>@; the output is parenthesised when the
-- surrounding precedence is above 10 (function application).
showsPrecWordVec :: Int -> WordVec -> ShowS
showsPrecWordVec prec vec = showParen (prec > 10) body
  where
    body = showString "fromList' "
         . showsPrec 11 (vecShape vec)
         . showChar ' '
         . shows (toList vec)
-- | Equality is 'eqStrict'.
instance Eq WordVec where
  (==) = eqStrict
-- | Ordering is 'cmpStrict'.
instance Ord WordVec where
  compare = cmpStrict
-- | The vector with no elements (built with 'fromList' from the empty list).
empty :: WordVec
empty = fromList []
-- | 'True' when the vector has no elements.
--
-- The test is done directly on the header word. @0xf9@ selects the tag bit
-- and the length bits of the small layout, so a result of 0 means
-- \"small and length 0\". @0xffffffe1@ selects the tag bit and the length
-- bits of the big layout, so a result of 1 means \"big and length 0\".
null :: WordVec -> Bool
null (WordVec !blob) = emptySmall || emptyBig
  where
    !header    = Blob.head blob
    emptySmall = (header .&. 0xf9)       == 0
    emptyBig   = (header .&. 0xffffffe1) == 1
-- | A one-element vector; the element itself is used as the bound that
-- determines the element width.
singleton :: Word -> WordVec
singleton !x = fromListN 1 x [x]
-- | @Just x@ when the vector consists of exactly the one element @x@,
-- otherwise 'Nothing'.
isSingleton :: WordVec -> Maybe Word
isSingleton !v
  | vecLen v == 1 = Just (head v)
  | otherwise     = Nothing
-- | Read the element at the given (0-based) index. The index is not checked
-- against the length of the vector.
unsafeIndex :: Int -> WordVec -> Word
unsafeIndex idx vec@(WordVec blob) =
  extractSmallWord bits blob (headerBits + bits * idx)
  where
    (small, Shape _ bits) = vecShape' vec
    -- elements start right after the header: 8 bits (small) or 32 bits (big)
    headerBits = if small then 8 else 32
-- | Read the element at the given (0-based) index, or 'Nothing' when the
-- index is negative or not smaller than the length.
safeIndex :: Int -> WordVec -> Maybe Word
safeIndex idx vec@(WordVec blob)
  | idx >= 0 && idx < len = Just (extractSmallWord bits blob (headerBits + bits * idx))
  | otherwise             = Nothing
  where
    (small, Shape len bits) = vecShape' vec
    -- elements start right after the header: 8 bits (small) or 32 bits (big)
    headerBits = if small then 8 else 32
-- | The first element of the vector. For an empty vector the result is 0.
head :: WordVec -> Word
head vec@(WordVec blob)
  | null vec       = 0
  | vecIsSmall vec = extractSmallWord bits blob 8    -- small layout: 8 header bits
  | otherwise      = extractSmallWord bits blob 32   -- big layout: 32 header bits
  where
    bits = vecBits vec
-- | The last element of the vector. For an empty vector the result is 0.
last :: WordVec -> Word
last vec@(WordVec blob)
  | len == 0  = 0
  | otherwise = extractSmallWord bits blob (headerBits + bits * (len - 1))
  where
    (small, Shape len bits) = vecShape' vec
    -- elements start right after the header: 8 bits (small) or 32 bits (big)
    headerBits = if small then 8 else 32
-- | Drop the first element (implemented by 'tail_v2').
tail :: WordVec -> WordVec
tail vec = tail_v2 vec

-- | Prepend an element (implemented by 'cons_v2').
cons :: Word -> WordVec -> WordVec
cons x vec = cons_v2 x vec

-- | Append an element (implemented by 'snoc_v2').
snoc :: WordVec -> Word -> WordVec
snoc vec x = snoc_v2 vec x

-- | Split a vector into its first element and the rest; 'Nothing' for the
-- empty vector (implemented by 'uncons_v2').
uncons :: WordVec -> Maybe (Word, WordVec)
uncons vec = uncons_v2 vec
-- | Concatenate two vectors. The result has the summed length and the wider
-- of the two element widths; it is rebuilt from the element lists.
concat :: WordVec -> WordVec -> WordVec
concat u v = fromList' shape (toList u ++ toList v)
  where
    Shape lenU bitsU = vecShape u
    Shape lenV bitsV = vecShape v
    shape            = Shape (lenU + lenV) (max bitsU bitsV)
-- Bindings to the C routines behind the list-like operations.
-- NOTE(review): the behaviour of each routine is defined in the C sources,
-- which are not visible here; the notes below only record how this module
-- uses them.

-- C routine @vec_identity@ (not referenced by the code visible in this part of
-- the module).
foreign import ccall unsafe "vec_identity" c_vec_identity :: CFun11_
-- C routine @vec_tail@, used by 'tail_v2'.
foreign import ccall unsafe "vec_tail" c_vec_tail :: CFun11_
-- C routine @vec_head_tail@, used by 'uncons_v2'; also yields a 'Word64'.
foreign import ccall unsafe "vec_head_tail" c_vec_head_tail :: CFun11 Word64
-- C routine @vec_cons@, used by 'cons_v2'; the 'Word64' is the new element.
foreign import ccall unsafe "vec_cons" c_vec_cons :: Word64 -> CFun11_
-- C routine @vec_snoc@, used by 'snoc_v2'; the 'Word64' is the new element.
foreign import ccall unsafe "vec_snoc" c_vec_snoc :: Word64 -> CFun11_
-- | Drop the first element by running the C routine @vec_tail@ on the blob.
--
-- NOTE(review): the @id@ argument is presumably the function that maps the
-- input size to the size allocated for the output; confirm against the Blob
-- module.
tail_v2 :: WordVec -> WordVec
tail_v2 (WordVec blob) = WordVec (wrapCFun11_ c_vec_tail id blob)
-- | Prepend an element by running the C routine @vec_cons@ on the blob.
--
-- NOTE(review): @outSize@ is presumably the output allocation size as a
-- function of the input size, both in 64-bit words; confirm against the Blob
-- module.
cons_v2 :: Word -> WordVec -> WordVec
cons_v2 y vec@(WordVec blob) = WordVec (wrapCFun11_ prepend outSize blob)
  where
    prepend    = c_vec_cons (fromIntegral y)
    outSize !n = max (n + 2) worstCase
    -- 64-bit words needed for a 32-bit header plus (length + 1) elements, each
    -- as wide as the new element requires, rounded up
    worstCase  = (32 + bitsNeededFor y * (vecLen vec + 1) + 63) `shiftR` 6
-- | Append an element by running the C routine @vec_snoc@ on the blob.
--
-- NOTE(review): @outSize@ is presumably the output allocation size as a
-- function of the input size, both in 64-bit words; confirm against the Blob
-- module.
snoc_v2 :: WordVec -> Word -> WordVec
snoc_v2 vec@(WordVec blob) y = WordVec (wrapCFun11_ append outSize blob)
  where
    append     = c_vec_snoc (fromIntegral y)
    outSize !n = max (n + 2) worstCase
    -- 64-bit words needed for a 32-bit header plus (length + 1) elements, each
    -- as wide as the new element requires, rounded up
    worstCase  = (32 + bitsNeededFor y * (vecLen vec + 1) + 63) `shiftR` 6
-- | Split a vector into its first element and the remaining vector using the
-- C routine @vec_head_tail@; 'Nothing' for the empty vector (the C routine is
-- not run in that case).
uncons_v2 :: WordVec -> Maybe (Word, WordVec)
uncons_v2 vec@(WordVec blob)
  | null vec  = Nothing
  | otherwise = Just (fromIntegral hd, WordVec tl)
  where
    (hd, tl) = wrapCFun11 c_vec_head_tail id blob
-- | Convert a vector to the list of its elements, in order.
--
-- The blob is read as a list of 64-bit words. The header bits (8 for the
-- small layout, 32 for the big one) are shifted out of the first word, and
-- the elements are then peeled off one at a time. An element may straddle
-- two consecutive words.
toList :: WordVec -> [Word]
toList vec@(WordVec blob)
  | small     = go 8  len (shiftR header 8  : payload)
  | otherwise = go 32 len (shiftR header 32 : payload)
  where
    small              = (header .&. 1) == 0
    (header : payload) = blobToWordList blob
    Shape len bits     = vecShape vec

    lowBits = shiftL 1 bits - 1 :: Word64

    -- keep only the low @bits@ bits of a word
    element :: Word64 -> Word
    element w = fromIntegral (w .&. lowBits)

    -- go ofs k ws: @ofs@ is the number of bits of the current word that have
    -- been consumed already, @k@ the number of elements still to produce, and
    -- the first word of @ws@ has been shifted so that its unread part starts
    -- at bit 0.
    go !ofs !0 _  = []
    go !ofs !k [] = replicate k 0
    go !ofs !k (w : ws)
      | ofs' <  64 = element w : go ofs' (k - 1) (shiftR w bits : ws)
      | ofs' == 64 = element w : go 0    (k - 1) ws
      | otherwise  =
          -- the element continues in the next word
          case ws of
            (next : ws') ->
              let !carry = ofs' - 64
                  !x     = element (w .|. shiftL next (64 - ofs))
              in  x : go carry (k - 1) (shiftR next carry : ws')
            [] -> error "WordVec/toList: FATAL ERROR! this should not happen"
      where
        ofs' = ofs + bits
-- | The elements of the vector in reverse order (last element first), each
-- read directly from the blob.
toRevList :: WordVec -> [Word]
toRevList vec@(WordVec blob) =
  [ extractSmallWord bits blob (headerBits + bits * (len - i)) | i <- [1 .. len] ]
  where
    (small, Shape len bits) = vecShape' vec
    -- elements start right after the header: 8 bits (small) or 32 bits (big)
    headerBits = if small then 8 else 32
-- | Build a vector from a list. The length and the largest element are
-- computed from the list itself; the empty list gives a length-0 vector with
-- 4-bit elements.
fromList :: [Word] -> WordVec
fromList xs = case xs of
  [] -> fromList' (Shape 0 4) []
  _  -> fromList' (Shape (length xs) (bitsNeededFor (L.maximum xs))) xs
-- | Build a vector from a list when the length and an upper bound for the
-- elements are already known, so the list need not be scanned first.
fromListN
  :: Int      -- ^ length of the vector
  -> Word     -- ^ upper bound for the elements (determines the element width)
  -> [Word]   -- ^ the elements
  -> WordVec
fromListN len bound xs = fromList' (Shape len (bitsNeededFor bound)) xs
-- | Build a vector from an explicit shape and a list of elements.
--
-- * The requested width is rounded up to a multiple of 4 and clamped to the
--   range 4..64 bits.
-- * Exactly @len@ elements are stored: a list that is too short is padded
--   with zeros, surplus elements are ignored.
-- * Every element is truncated to the final element width.
--
-- The \"small\" layout (8 header bits) is used when the width is at most 16
-- bits and the length at most 31; otherwise the \"big\" layout (32 header bits)
-- is used.
fromList' :: Shape -> [Word] -> WordVec
fromList' (Shape len bits0) xs
  | bits <= 16 && len <= 31 = WordVec $ mkBlob (mkHeader 0 2)  8 xs
  | otherwise               = WordVec $ mkBlob (mkHeader 1 4) 32 xs
  where
    -- Round up to a multiple of 4 by clearing the two low bits. Clearing with
    -- @complement 3@ (instead of masking with 0xfc) keeps the high bits, so an
    -- oversized request is clamped to 64 rather than wrapping around to a
    -- small width, and a negative request is clamped to 4.
    !bits    = max 4 $ min 64 $ (bits0 + 3) .&. complement 3
    !bitsEnc = shiftR bits 2 - 1 :: Int        -- width is stored as bits/4 - 1
    !content = bits * len :: Int               -- total payload size in bits
    !mask    = shiftL 1 bits - 1 :: Word64     -- low @bits@ bits set

    -- Header word: the layout tag (0 = small, 1 = big) in bit 0, followed by
    -- the encoded width in @resoBits@ bits, followed by the length.
    mkHeader :: Word64 -> Int -> Word64
    mkHeader !tag !resoBits =
      tag + fromIntegral (shiftL (bitsEnc + shiftL len resoBits) 1)

    -- @ofs@ is the number of header bits; the blob needs enough 64-bit words
    -- to hold header and payload, rounded up.
    mkBlob !header !ofs ws =
      blobFromWordListN (shiftR (ofs + content + 63) 6) $ worker len header ofs ws

    -- worker k current bitOfs ws: pack @k@ more elements, where @current@ is
    -- the partially filled output word and @bitOfs@ the number of bits of it
    -- that are already in use.
    worker :: Int -> Word64 -> Int -> [Word] -> [Word64]
    worker 0  !current !bitOfs _  = if bitOfs == 0 then [] else [current]
    worker !k !current !bitOfs [] = worker k current bitOfs [0]   -- pad with zeros
    worker !k !current !bitOfs (this0 : rest) =
      let !this     = fromIntegral this0 .&. mask
          !newOfs   = bitOfs + bits
          !current' = shiftL this bitOfs .|. current
      in  case compare newOfs 64 of
            LT -> worker (k - 1) current' newOfs rest
            EQ -> current' : worker (k - 1) 0 0 rest
            -- the element straddles a word boundary: its high part starts the
            -- next output word
            GT -> let !newOfs' = newOfs - 64
                  in  current' : worker (k - 1) (shiftR this (64 - bitOfs)) newOfs' rest
-- | Result of the C routine @vec_sum@ on the vector, converted to 'Word'.
sum :: WordVec -> Word
sum (WordVec blob) = fromIntegral (wrapCFun10 c_vec_sum blob)
-- | Result of the C routine @vec_max@ on the vector, converted to 'Word'.
maximum :: WordVec -> Word
maximum (WordVec blob) = fromIntegral (wrapCFun10 c_vec_max blob)
-- Bindings to the C routines for reductions and comparisons.
-- NOTE(review): the exact semantics live in the C sources, which are not
-- visible here; the notes below only record how this module uses the results.

-- Used by 'sum' and 'maximum'; each yields a single 'Word64'.
foreign import ccall unsafe "vec_sum" c_vec_sum :: CFun10 Word64
foreign import ccall unsafe "vec_max" c_vec_max :: CFun10 Word64
-- Used by 'eqStrict' and 'eqExtZero'; a nonzero result is read as 'True'.
foreign import ccall unsafe "vec_equal_strict" c_equal_strict :: CFun20 CInt
foreign import ccall unsafe "vec_equal_extzero" c_equal_extzero :: CFun20 CInt
-- Used by 'cmpStrict' and 'cmpExtZero'; the sign of the result gives the ordering.
foreign import ccall unsafe "vec_compare_strict" c_compare_strict :: CFun20 CInt
foreign import ccall unsafe "vec_compare_extzero" c_compare_extzero :: CFun20 CInt
-- Used by 'lessOrEqual' and 'partialSumsLessOrEqual'; a nonzero result is read as 'True'.
foreign import ccall unsafe "vec_less_or_equal" c_less_or_equal :: CFun20 CInt
foreign import ccall unsafe "vec_partial_sums_less_or_equal" c_partial_sums_less_or_equal :: CFun20 CInt
-- | 'True' when the C routine @vec_equal_strict@ returns a nonzero value for
-- the two vectors.
eqStrict :: WordVec -> WordVec -> Bool
eqStrict (WordVec a) (WordVec b) = wrapCFun20 c_equal_strict a b /= 0
-- | 'True' when the C routine @vec_equal_extzero@ returns a nonzero value for
-- the two vectors.
--
-- NOTE(review): judging by the name, the shorter vector is presumably treated
-- as if extended with zeros; confirm against the C sources.
eqExtZero :: WordVec -> WordVec -> Bool
eqExtZero (WordVec a) (WordVec b) = wrapCFun20 c_equal_extzero a b /= 0
-- | Turn a C-style comparison result into an 'Ordering': negative is 'LT',
-- positive is 'GT', zero is 'EQ'.
cintToOrdering :: CInt -> Ordering
cintToOrdering k = compare k 0
-- | Compare two vectors with the C routine @vec_compare_strict@; the sign of
-- its result determines the 'Ordering'.
cmpStrict :: WordVec -> WordVec -> Ordering
cmpStrict (WordVec a) (WordVec b) = cintToOrdering (wrapCFun20 c_compare_strict a b)
-- | Compare two vectors with the C routine @vec_compare_extzero@; the sign of
-- its result determines the 'Ordering'.
--
-- NOTE(review): judging by the name, the shorter vector is presumably treated
-- as if extended with zeros; confirm against the C sources.
cmpExtZero :: WordVec -> WordVec -> Ordering
cmpExtZero (WordVec a) (WordVec b) = cintToOrdering (wrapCFun20 c_compare_extzero a b)
-- | 'True' when the C routine @vec_less_or_equal@ returns a nonzero value for
-- the two vectors.
--
-- NOTE(review): presumably an element-wise comparison; confirm against the C
-- sources.
lessOrEqual :: WordVec -> WordVec -> Bool
lessOrEqual (WordVec a) (WordVec b) = wrapCFun20 c_less_or_equal a b /= 0
-- | 'True' when the C routine @vec_partial_sums_less_or_equal@ returns a
-- nonzero value for the two vectors.
--
-- NOTE(review): presumably compares the running sums of the two vectors;
-- confirm against the C sources.
partialSumsLessOrEqual :: WordVec -> WordVec -> Bool
partialSumsLessOrEqual (WordVec a) (WordVec b) =
  wrapCFun20 c_partial_sums_less_or_equal a b /= 0
-- C routine @vec_add@, used by 'add'.
foreign import ccall unsafe "vec_add" c_vec_add :: CFun21_
-- C routine @vec_sub_overflow@, used by 'subtract'; the 'CInt' it yields is
-- read there as a status flag (0 means the result is kept).
foreign import ccall unsafe "vec_sub_overflow" c_vec_sub_overflow :: CFun21 CInt
-- | Combine two vectors with the C routine @vec_add@.
--
-- NOTE(review): @outSize@ is presumably the output allocation size in 64-bit
-- words; confirm against the Blob module.
add :: WordVec -> WordVec -> WordVec
add vec1@(WordVec blob1) vec2@(WordVec blob2) =
  WordVec (wrapCFun21_ c_vec_add outSize blob1 blob2)
  where
    Shape !l1 !b1 = vecShape vec1
    Shape !l2 !b2 = vecShape vec2
    -- one extra word, plus room for the longer length at the wider element
    -- width increased by 4 bits, rounded up to whole 64-bit words
    outSize _ _ = 1 + shiftR ((max b1 b2 + 4) * max l1 l2 + 63) 6
-- | Combine two vectors with the C routine @vec_sub_overflow@. The result is
-- returned only when the routine reports status 0; any other status gives
-- 'Nothing'.
--
-- NOTE(review): a nonzero status presumably signals that some element of the
-- difference would be negative; confirm against the C sources.
subtract :: WordVec -> WordVec -> Maybe WordVec
subtract vec1@(WordVec blob1) vec2@(WordVec blob2) =
  case wrapCFun21 c_vec_sub_overflow outSize blob1 blob2 of
    (0, result) -> Just (WordVec result)
    _           -> Nothing
  where
    Shape !l1 !b1 = vecShape vec1
    Shape !l2 !b2 = vecShape vec2
    -- one extra word, plus room for the longer length at the wider element
    -- width increased by 4 bits, rounded up to whole 64-bit words
    outSize _ _ = 1 + shiftR ((max b1 b2 + 4) * max l1 l2 + 63) 6
-- C routine @vec_scale@, used by 'scale'; the 'Word64' is the scale factor.
foreign import ccall unsafe "vec_scale" c_vec_scale :: Word64 -> CFun11_
-- | Apply the C routine @vec_scale@ with the given factor to a vector.
--
-- The output size is derived from an upper bound on the scaled elements:
-- the largest value that fits the current width times the factor, or the
-- all-ones word when that product could overflow.
scale :: Word -> WordVec -> WordVec
scale s vec@(WordVec blob) =
  WordVec (wrapCFun11_ (c_vec_scale (fromIntegral s)) outSize blob)
  where
    Shape !len !bits = vecShape vec
    bound
      | s <= shiftL 1 (64 - bits) = (2 ^ bits - 1) * s
      | otherwise                 = 2 ^ 64 - 1
    newbits   = bitsNeededFor bound
    -- 32 header bits plus the payload at the new width, in 64-bit words
    outSize _ = shiftR (32 + len * newbits + 63) 6
-- C routine @vec_partial_sums@, used by 'partialSums'; it also yields a
-- 'Word64', which 'partialSums' discards.
foreign import ccall unsafe "vec_partial_sums" c_vec_partial_sums :: CFun11 Word64
-- | Apply the C routine @vec_partial_sums@ to a vector and keep the resulting
-- blob (the 'Word64' returned alongside it is dropped).
--
-- The output size is derived from an upper bound on the total: the largest
-- value that fits the current width times the length, or the all-ones word
-- when that product could overflow.
partialSums :: WordVec -> WordVec
partialSums vec@(WordVec blob) = WordVec sums
  where
    (_, sums)        = wrapCFun11 c_vec_partial_sums outSize blob
    Shape !len !bits = vecShape vec
    bound
      | len <= shiftL 1 (64 - bits) = (2 ^ bits - 1) * (fromIntegral len :: Word)
      | otherwise                   = 2 ^ 64 - 1
    newbits   = bitsNeededFor bound
    -- 32 header bits plus the payload at the new width, in 64-bit words
    outSize _ = shiftR (32 + len * newbits + 63) 6
-- | Strict left fold over the elements of a vector (via 'toList').
fold :: (a -> Word -> a) -> a -> WordVec -> a
fold step acc vec = L.foldl' step acc (toList vec)
-- | Map a function over the elements by converting to a list and back; the
-- element width of the result is recomputed by 'fromList'.
naiveMap :: (Word -> Word) -> WordVec -> WordVec
naiveMap f = fromList . map f . toList
-- | Map a function over the elements when an upper bound for the results is
-- known in advance.
--
-- The first argument is the bound; it alone determines the element width of
-- the result. Results that do not fit that width are truncated by
-- 'fromList''.
boundedMap :: Word -> (Word -> Word) -> WordVec -> WordVec
boundedMap bound f vec = fromList' (Shape len bits) (map f (toList vec))
  where
    len  = vecLen vec
    bits = bitsNeededFor bound
-- | Zip two vectors with a function by going through lists; the result is as
-- long as the shorter input and its width is recomputed by 'fromList'.
naiveZipWith :: (Word -> Word -> Word) -> WordVec -> WordVec -> WordVec
naiveZipWith f u v = fromList (L.zipWith f (toList u) (toList v))
-- | Zip two vectors with a function when an upper bound for the results is
-- known in advance. The bound determines the element width; the result is as
-- long as the shorter input.
boundedZipWith :: Word -> (Word -> Word -> Word) -> WordVec -> WordVec -> WordVec
boundedZipWith bound f vec1 vec2 = fromList' shape zipped
  where
    shape  = Shape (min (vecLen vec1) (vecLen vec2)) (bitsNeededFor bound)
    zipped = L.zipWith f (toList vec1) (toList vec2)
-- | Zip the elements of two vectors with a function, producing an ordinary
-- list (as long as the shorter input).
listZipWith :: (Word -> Word -> a) -> WordVec -> WordVec -> [a]
listZipWith f u v = [ f x y | (x, y) <- zip (toList u) (toList v) ]
-- | Number of bits needed to store values up to the given bound, rounded up
-- to a multiple of 4 (see 'bitsNeededForHs').
bitsNeededFor :: Word -> Int
bitsNeededFor bound = bitsNeededForHs bound
-- | Number of bits needed to store values up to the given bound, without
-- rounding (see 'bitsNeededForHs'').
bitsNeededFor' :: Word -> Int
bitsNeededFor' bound = bitsNeededForHs' bound
-- | 'bitsNeededForHs'' followed by 'roundBits': the exact bit count rounded
-- up to a multiple of 4.
bitsNeededForHs :: Word -> Int
bitsNeededForHs bound = roundBits (bitsNeededForHs' bound)
-- | Exact number of bits needed to represent every value in @0 .. bound@.
--
-- This is the position of the highest set bit of @bound@ (counting from 1),
-- except that a bound of 0 still needs 1 bit. The word size comes from
-- 'finiteBitSize', so the result for 'maxBound' is right on every platform
-- and does not depend on a per-architecture table.
bitsNeededForHs' :: Word -> Int
bitsNeededForHs' bound
  | bound == 0 = 1
  | otherwise  = finiteBitSize bound - countLeadingZeros bound
-- | Round a bit count up to the next multiple of 4; a count of 0 is mapped
-- to 4.
roundBits :: Int -> Int
roundBits k
  | k == 0    = 4
  | otherwise = (k + 3) .&. complement 3   -- add 3, then clear the two low bits