{-# LANGUAGE BangPatterns, CPP, MagicHash, UnboxedTuples, UnliftedFFITypes, DeriveDataTypeable #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Primitive.ByteArray (
ByteArray(..), MutableByteArray(..), ByteArray#, MutableByteArray#,
newByteArray, newPinnedByteArray, newAlignedPinnedByteArray,
resizeMutableByteArray,
readByteArray, writeByteArray, indexByteArray,
byteArrayFromList, byteArrayFromListN,
foldrByteArray,
unsafeFreezeByteArray, unsafeThawByteArray,
copyByteArray, copyMutableByteArray,
#if __GLASGOW_HASKELL__ >= 708
copyByteArrayToAddr, copyMutableByteArrayToAddr,
#endif
moveByteArray,
setByteArray, fillByteArray,
sizeofByteArray,
sizeofMutableByteArray, getSizeofMutableByteArray, sameMutableByteArray,
#if __GLASGOW_HASKELL__ >= 802
isByteArrayPinned, isMutableByteArrayPinned,
#endif
byteArrayContents, mutableByteArrayContents
) where
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Primitive.Types
import Foreign.C.Types
import Data.Word ( Word8 )
import GHC.Base ( Int(..) )
#if __GLASGOW_HASKELL__ >= 708
import qualified GHC.Exts as Exts ( IsList(..) )
#endif
import GHC.Prim
#if __GLASGOW_HASKELL__ >= 706
hiding (setByteArray#)
#endif
import Data.Typeable ( Typeable )
import Data.Data ( Data(..) )
import Data.Primitive.Internal.Compat ( isTrue#, mkNoRepType )
import Numeric
#if MIN_VERSION_base(4,9,0)
import qualified Data.Semigroup as SG
import qualified Data.Foldable as F
#endif
#if !(MIN_VERSION_base(4,8,0))
import Data.Monoid (Monoid(..))
#endif
#if __GLASGOW_HASKELL__ >= 802
import GHC.Exts as Exts (isByteArrayPinned#,isMutableByteArrayPinned#)
#endif
#if __GLASGOW_HASKELL__ >= 804
import GHC.Exts (compareByteArrays#)
#else
import System.IO.Unsafe (unsafeDupablePerformIO)
#endif
-- | An immutable array of raw bytes: a boxed wrapper around GHC's unlifted
-- 'ByteArray#'.
data ByteArray = ByteArray ByteArray# deriving ( Typeable )

-- | A mutable array of raw bytes: a boxed wrapper around GHC's unlifted
-- 'MutableByteArray#'. The type parameter @s@ is the state-token type; the
-- monadic operations in this module instantiate it to @PrimState m@.
data MutableByteArray s = MutableByteArray (MutableByteArray# s)
  deriving( Typeable )
-- | Allocate a mutable byte array of the given size in bytes using
-- 'newByteArray#' (i.e. without explicitly requesting pinned storage).
-- This wrapper does not initialise the contents.
newByteArray :: PrimMonad m => Int -> m (MutableByteArray (PrimState m))
{-# INLINE newByteArray #-}
newByteArray (I# n#) = primitive allocate
  where
    allocate s0 = case newByteArray# n# s0 of
      (# s1, mba# #) -> (# s1, MutableByteArray mba# #)
-- | Allocate a pinned mutable byte array of the given size in bytes via
-- 'newPinnedByteArray#'. This wrapper does not initialise the contents.
newPinnedByteArray :: PrimMonad m => Int -> m (MutableByteArray (PrimState m))
{-# INLINE newPinnedByteArray #-}
newPinnedByteArray (I# n#) =
  primitive $ \st -> case newPinnedByteArray# n# st of
    (# st', mba# #) -> (# st', MutableByteArray mba# #)
-- | Allocate a pinned mutable byte array with a requested alignment via
-- 'newAlignedPinnedByteArray#'. The primop's documentation requires the
-- alignment to be a power of two; this wrapper does not check it.
newAlignedPinnedByteArray
  :: PrimMonad m
  => Int -- ^ size in bytes
  -> Int -- ^ alignment in bytes
  -> m (MutableByteArray (PrimState m))
{-# INLINE newAlignedPinnedByteArray #-}
newAlignedPinnedByteArray (I# size#) (I# align#) = primitive go
  where
    go s = case newAlignedPinnedByteArray# size# align# s of
      (# s', mba# #) -> (# s', MutableByteArray mba# #)
-- | Address of the first byte of the array's payload.
--
-- NOTE(review): presumably only stable for pinned arrays (an unpinned array
-- may be relocated by the GC) — confirm before retaining the 'Addr'.
byteArrayContents :: ByteArray -> Addr
{-# INLINE byteArrayContents #-}
byteArrayContents ba = case ba of
  ByteArray arr# -> Addr (byteArrayContents# arr#)
-- | Address of the first byte of a mutable array's payload.
--
-- NOTE(review): as with 'byteArrayContents', presumably only stable for
-- pinned arrays — confirm.
mutableByteArrayContents :: MutableByteArray s -> Addr
{-# INLINE mutableByteArrayContents #-}
mutableByteArrayContents mba = case mba of
  MutableByteArray marr# ->
    -- 'byteArrayContents#' is applied to a 'ByteArray#', so the mutable
    -- array is reinterpreted as an immutable one first.
    Addr (byteArrayContents# (unsafeCoerce# marr#))
-- | 'True' when both arguments refer to the same underlying array object.
-- This is identity, not a comparison of contents.
sameMutableByteArray :: MutableByteArray s -> MutableByteArray s -> Bool
{-# INLINE sameMutableByteArray #-}
sameMutableByteArray x y = case (x, y) of
  (MutableByteArray x#, MutableByteArray y#) ->
    isTrue# (sameMutableByteArray# x# y#)
-- | Resize a mutable byte array to the given number of bytes and return the
-- resulting array; callers should continue with the returned array.
--
-- On GHC >= 7.10 this delegates to 'resizeMutableByteArray#'. Otherwise a
-- fresh array is always allocated and the first @min oldSize n@ bytes are
-- copied over from the original.
resizeMutableByteArray
  :: PrimMonad m => MutableByteArray (PrimState m) -> Int
  -> m (MutableByteArray (PrimState m))
{-# INLINE resizeMutableByteArray #-}
#if __GLASGOW_HASKELL__ >= 710
resizeMutableByteArray (MutableByteArray arr#) (I# n#) =
  primitive $ \s0 -> case resizeMutableByteArray# arr# n# s0 of
    (# s1, resized# #) -> (# s1, MutableByteArray resized# #)
#else
resizeMutableByteArray src newSize = do
  dst <- newByteArray newSize
  -- Keep only the prefix that fits in both the old and the new array.
  copyMutableByteArray dst 0 src 0 (min (sizeofMutableByteArray src) newSize)
  return dst
#endif
-- | Query the current size, in bytes, of a mutable byte array as a monadic
-- action. On newer GHCs this reads the size through the state thread with
-- 'getSizeofMutableByteArray#'; otherwise it falls back to the pure
-- 'sizeofMutableByteArray'.
getSizeofMutableByteArray
  :: PrimMonad m => MutableByteArray (PrimState m) -> m Int
{-# INLINE getSizeofMutableByteArray #-}
#if __GLASGOW_HASKELL__ >= 801
getSizeofMutableByteArray (MutableByteArray arr#) = primitive query
  where
    query s0 = case getSizeofMutableByteArray# arr# s0 of
      (# s1, n# #) -> (# s1, I# n# #)
#else
getSizeofMutableByteArray = return . sizeofMutableByteArray
#endif
-- | View a mutable byte array as an immutable one without copying. The
-- result shares storage with the argument, so the mutable array should not
-- be written to afterwards.
unsafeFreezeByteArray
  :: PrimMonad m => MutableByteArray (PrimState m) -> m ByteArray
{-# INLINE unsafeFreezeByteArray #-}
unsafeFreezeByteArray (MutableByteArray marr#) = primitive freeze
  where
    freeze s0 = case unsafeFreezeByteArray# marr# s0 of
      (# s1, arr# #) -> (# s1, ByteArray arr# #)
-- | View an immutable byte array as a mutable one without copying. The
-- same storage is reused (the pointer is merely reinterpreted), so writes
-- through the result are visible through the original 'ByteArray'.
unsafeThawByteArray
  :: PrimMonad m => ByteArray -> m (MutableByteArray (PrimState m))
{-# INLINE unsafeThawByteArray #-}
unsafeThawByteArray (ByteArray arr#) = primitive thaw
  where
    thaw s = (# s, MutableByteArray (unsafeCoerce# arr#) #)
-- | Size of the array in bytes.
sizeofByteArray :: ByteArray -> Int
{-# INLINE sizeofByteArray #-}
sizeofByteArray ba = case ba of
  ByteArray arr# -> I# (sizeofByteArray# arr#)
-- | Size of the mutable array in bytes, read without going through the state
-- thread.
--
-- NOTE(review): if the array may be resized in place (see
-- 'resizeMutableByteArray'), 'getSizeofMutableByteArray' is presumably the
-- safer query — confirm.
sizeofMutableByteArray :: MutableByteArray s -> Int
{-# INLINE sizeofMutableByteArray #-}
sizeofMutableByteArray mba = case mba of
  MutableByteArray marr# -> I# (sizeofMutableByteArray# marr#)
#if __GLASGOW_HASKELL__ >= 802
-- | Whether the array's storage is pinned, as reported by
-- 'Exts.isByteArrayPinned#'.
isByteArrayPinned :: ByteArray -> Bool
{-# INLINE isByteArrayPinned #-}
isByteArrayPinned ba = case ba of
  ByteArray arr# -> isTrue# (Exts.isByteArrayPinned# arr#)

-- | Mutable counterpart of 'isByteArrayPinned', using
-- 'Exts.isMutableByteArrayPinned#'.
isMutableByteArrayPinned :: MutableByteArray s -> Bool
{-# INLINE isMutableByteArrayPinned #-}
isMutableByteArrayPinned mba = case mba of
  MutableByteArray marr# -> isTrue# (Exts.isMutableByteArrayPinned# marr#)
#endif
-- | Read the element at the given index. The index counts elements of type
-- @a@, not bytes. This wrapper performs no bounds check.
indexByteArray :: Prim a => ByteArray -> Int -> a
{-# INLINE indexByteArray #-}
indexByteArray (ByteArray arr#) ix = case ix of
  I# i# -> indexByteArray# arr# i#
-- | Read the element at the given index (in elements of type @a@) from a
-- mutable array. This wrapper performs no bounds check.
readByteArray
  :: (Prim a, PrimMonad m) => MutableByteArray (PrimState m) -> Int -> m a
{-# INLINE readByteArray #-}
readByteArray (MutableByteArray marr#) (I# i#) =
  primitive (\s -> readByteArray# marr# i# s)
-- | Write a value at the given index (in elements of type @a@) of a mutable
-- array. This wrapper performs no bounds check.
writeByteArray
  :: (Prim a, PrimMonad m) => MutableByteArray (PrimState m) -> Int -> a -> m ()
{-# INLINE writeByteArray #-}
writeByteArray (MutableByteArray marr#) (I# i#) val =
  primitive_ (\s -> writeByteArray# marr# i# val s)
-- | Lazy right fold over the array, interpreting its bytes as consecutive
-- elements of type @a@. Only whole elements are visited: trailing bytes that
-- do not make up a complete @a@ are ignored.
foldrByteArray :: forall a b. (Prim a) => (a -> b -> b) -> b -> ByteArray -> b
foldrByteArray f z arr = go 0
  where
    -- Number of complete elements. Truncating division guarantees that every
    -- index we read lies entirely within the array; the previous test
    -- (@size > i * sz@) admitted a final, partially out-of-bounds element
    -- whenever the size was not a multiple of @sz@.
    count = sizeofByteArray arr `quot` sz
    sz = sizeOf (undefined :: a)
    go i
      | i < count = f (indexByteArray arr i) (go (i + 1))
      | otherwise = z
-- | Build a 'ByteArray' from a list of elements. The list is traversed once
-- to measure it and once more to write it.
byteArrayFromList :: Prim a => [a] -> ByteArray
byteArrayFromList xs =
  let count = length xs
  in byteArrayFromListN count xs
-- | Build a 'ByteArray' holding exactly @n@ elements taken from the list.
-- Calls 'die' if the list is shorter or longer than @n@.
byteArrayFromListN :: forall a. Prim a => Int -> [a] -> ByteArray
byteArrayFromListN n ys = runST $ do
    -- 'sizeOf' only needs the element type, so an annotated 'undefined'
    -- stands in for a value (as in 'foldrByteArray'), instead of the partial
    -- 'head ys', which would be wrong on an empty list if ever forced.
    marr <- newByteArray (n * sizeOf (undefined :: a))
    let go !ix []
          | ix == n = return ()
          | otherwise = die "byteArrayFromListN" "list length less than specified size"
        go !ix (x : xs)
          | ix < n = do
              writeByteArray marr ix x
              go (ix + 1) xs
          | otherwise = die "byteArrayFromListN" "list length greater than specified size"
    go 0 ys
    unsafeFreezeByteArray marr
-- | Unbox an 'Int' to its raw 'Int#' payload.
unI# :: Int -> Int#
unI# n = case n of
  I# n# -> n#
-- | Copy @sz@ bytes from the immutable array @src@, starting at byte offset
-- @soff@, into the mutable array @dst@ at byte offset @doff@. This wrapper
-- performs no bounds check.
copyByteArray
  :: PrimMonad m => MutableByteArray (PrimState m) -- ^ destination
  -> Int       -- ^ destination offset
  -> ByteArray -- ^ source
  -> Int       -- ^ source offset
  -> Int       -- ^ number of bytes
  -> m ()
{-# INLINE copyByteArray #-}
copyByteArray (MutableByteArray dst#) doff (ByteArray src#) soff sz =
  primitive_ $ \s ->
    copyByteArray# src# (unI# soff) dst# (unI# doff) (unI# sz) s
-- | Copy @sz@ bytes from the mutable array @src@, starting at byte offset
-- @soff@, into @dst@ at byte offset @doff@. This wrapper performs no bounds
-- check. For possibly overlapping ranges, see 'moveByteArray', which goes
-- through a C memmove helper.
copyMutableByteArray
  :: PrimMonad m => MutableByteArray (PrimState m) -- ^ destination
  -> Int                            -- ^ destination offset
  -> MutableByteArray (PrimState m) -- ^ source
  -> Int                            -- ^ source offset
  -> Int                            -- ^ number of bytes
  -> m ()
{-# INLINE copyMutableByteArray #-}
copyMutableByteArray (MutableByteArray dst#) doff (MutableByteArray src#) soff sz =
  primitive_ $ \s ->
    copyMutableByteArray# src# (unI# soff) dst# (unI# doff) (unI# sz) s
#if __GLASGOW_HASKELL__ >= 708
-- | Copy @sz@ bytes of an immutable array, starting at byte offset @soff@,
-- to the memory at the given address. The destination must have room for
-- @sz@ bytes; nothing here checks that.
copyByteArrayToAddr
  :: PrimMonad m
  => Addr      -- ^ destination
  -> ByteArray -- ^ source
  -> Int       -- ^ source offset
  -> Int       -- ^ number of bytes
  -> m ()
{-# INLINE copyByteArrayToAddr #-}
copyByteArrayToAddr (Addr dst#) (ByteArray src#) soff sz =
  primitive_ $ \s -> copyByteArrayToAddr# src# (unI# soff) dst# (unI# sz) s

-- | Mutable-source counterpart of 'copyByteArrayToAddr'.
copyMutableByteArrayToAddr
  :: PrimMonad m
  => Addr                           -- ^ destination
  -> MutableByteArray (PrimState m) -- ^ source
  -> Int                            -- ^ source offset
  -> Int                            -- ^ number of bytes
  -> m ()
{-# INLINE copyMutableByteArrayToAddr #-}
copyMutableByteArrayToAddr (Addr dst#) (MutableByteArray src#) soff sz =
  primitive_ $ \s -> copyMutableByteArrayToAddr# src# (unI# soff) dst# (unI# sz) s
#endif
-- | Copy @sz@ bytes from @src@ (at byte offset @soff@) to @dst@ (at byte
-- offset @doff@) through the C helper @hsprimitive_memmove@. The argument
-- order mirrors 'copyMutableByteArray', and both arrays may be the same one.
--
-- NOTE(review): the offsets are narrowed to 'CInt' with 'fromIntegral' to
-- match the 'memmove_mba' import, so offsets outside the 'CInt' range would
-- wrap silently — confirm against the C prototype for very large arrays.
moveByteArray
  :: PrimMonad m => MutableByteArray (PrimState m) -- ^ destination
  -> Int                            -- ^ destination offset
  -> MutableByteArray (PrimState m) -- ^ source
  -> Int                            -- ^ source offset
  -> Int                            -- ^ number of bytes
  -> m ()
{-# INLINE moveByteArray #-}
moveByteArray (MutableByteArray dst#) doff (MutableByteArray src#) soff sz =
  unsafePrimToPrim (memmove_mba dst# (toCInt doff) src# (toCInt soff) (toCSize sz))
  where
    toCInt :: Int -> CInt
    toCInt = fromIntegral
    toCSize :: Int -> CSize
    toCSize = fromIntegral
-- | Fill a range of a mutable array with copies of a value. The offset and
-- count are passed unchanged to the 'Prim' class's 'setByteArray#'.
--
-- NOTE(review): presumably both are measured in elements of type @a@ — confirm
-- against "Data.Primitive.Types".
setByteArray
  :: (Prim a, PrimMonad m) => MutableByteArray (PrimState m) -- ^ array
  -> Int -- ^ offset
  -> Int -- ^ number of values
  -> a   -- ^ value to fill with
  -> m ()
{-# INLINE setByteArray #-}
setByteArray (MutableByteArray dst#) (I# doff#) (I# sz#) x =
  primitive_ (\s -> setByteArray# dst# doff# sz# x s)
-- | Fill a range of bytes with one byte value: 'setByteArray' specialised to
-- 'Word8'.
fillByteArray
  :: PrimMonad m => MutableByteArray (PrimState m) -- ^ array
  -> Int   -- ^ offset
  -> Int   -- ^ number of bytes
  -> Word8 -- ^ byte to fill with
  -> m ()
{-# INLINE fillByteArray #-}
fillByteArray marr off len byte = setByteArray marr off len byte
-- | Binding to the C helper @hsprimitive_memmove@ declared in
-- @primitive-memops.h@. Arguments: destination array, destination offset,
-- source array, source offset, byte count. The unlifted arrays are passed
-- directly to an @unsafe@ call.
--
-- NOTE(review): the Haskell types here ('CInt' offsets, 'CSize' length) must
-- match the C prototype exactly; presumably the helper has memmove
-- (overlap-safe) semantics — confirm against the C source.
foreign import ccall unsafe "primitive-memops.h hsprimitive_memmove"
  memmove_mba :: MutableByteArray# s -> CInt
              -> MutableByteArray# s -> CInt
              -> CSize -> IO ()
-- | Opaque 'Data' instance. The type is described with 'mkNoRepType' (no
-- constructors), and 'toConstr' / 'gunfold' raise an error if used.
instance Data ByteArray where
  dataTypeOf = const (mkNoRepType "Data.Primitive.ByteArray.ByteArray")
  toConstr = const (error "toConstr")
  gunfold _ _ = error "gunfold"
-- | Opaque 'Data' instance for mutable byte arrays. The type is described
-- with 'mkNoRepType', and 'toConstr' / 'gunfold' raise an error if used.
instance Typeable s => Data (MutableByteArray s) where
  dataTypeOf = const (mkNoRepType "Data.Primitive.ByteArray.MutableByteArray")
  toConstr = const (error "toConstr")
  gunfold _ _ = error "gunfold"
-- | Renders the bytes as a bracketed, comma-separated list of hexadecimal
-- literals, e.g. @[0x1, 0xff]@. Digits are not zero-padded, and the
-- precedence argument is ignored.
instance Show ByteArray where
  showsPrec _ ba = showString "[" . render 0
    where
      len = sizeofByteArray ba
      render ix
        | ix >= len = showChar ']'
        | otherwise =
            separator ix
              . showString "0x"
              . showHex (indexByteArray ba ix :: Word8)
              . render (ix + 1)
      -- No separator before the first byte.
      separator ix
        | ix == 0 = id
        | otherwise = showString ", "
-- | Compare the first @n@ bytes of two arrays and turn the sign of the
-- underlying byte comparison into an 'Ordering'. Both arrays are assumed to
-- hold at least @n@ bytes; this function does not check.
compareByteArrays :: ByteArray -> ByteArray -> Int -> Ordering
{-# INLINE compareByteArrays #-}
#if __GLASGOW_HASKELL__ >= 804
compareByteArrays (ByteArray ba1#) (ByteArray ba2#) (I# n#) =
  case compareByteArrays# ba1# 0# ba2# 0# n# of
    r# -> compare (I# r#) 0
#else
compareByteArrays (ByteArray ba1#) (ByteArray ba2#) (I# n#) =
  compare (cIntToInt rawResult) 0
  where
    -- Both arguments are immutable arrays, so running the FFI comparison
    -- outside IO is treated as pure here.
    rawResult = unsafeDupablePerformIO (memcmp_ba ba1# ba2# len)
    len = fromIntegral (I# n#) :: CSize
    cIntToInt = fromIntegral :: CInt -> Int

-- | Binding to the C helper @hsprimitive_memcmp@ (pre-8.4 fallback).
foreign import ccall unsafe "primitive-memops.h hsprimitive_memcmp"
  memcmp_ba :: ByteArray# -> ByteArray# -> CSize -> IO CInt
#endif
-- | Identity test on the underlying heap objects. 'True' means both arguments
-- are the very same array, so their contents are trivially equal. 'False'
-- says nothing about the contents. The 'Eq' and 'Ord' instances use this as
-- a fast path.
sameByteArray :: ByteArray# -> ByteArray# -> Bool
sameByteArray ba1 ba2 =
  -- 'ByteArray#' and 'MutableByteArray#' are both unlifted array pointers, and
  -- this module already converts between them (see 'unsafeThawByteArray' and
  -- 'mutableByteArrayContents'). Reinterpreting as 'MutableByteArray#' keeps
  -- the runtime representation unchanged, unlike the former cast to the
  -- *lifted* type (), and lets 'sameMutableByteArray#' compare the pointers.
  -- 'isTrue#' is the same adapter used by 'sameMutableByteArray', so no
  -- per-GHC-version branches are needed.
  isTrue# (sameMutableByteArray# (unsafeCoerce# ba1) (unsafeCoerce# ba2))
-- | Byte-wise equality. The same array object is equal to itself without
-- inspecting its contents; arrays of different sizes are never equal.
instance Eq ByteArray where
  ba1@(ByteArray ba1#) == ba2@(ByteArray ba2#) =
    sameByteArray ba1# ba2#
      || (len1 == len2 && compareByteArrays ba1 ba2 len1 == EQ)
    where
      len1 = sizeofByteArray ba1
      len2 = sizeofByteArray ba2
-- | Orders arrays by size first, then by 'compareByteArrays' when the sizes
-- match. This is not plain lexicographic byte-string order: a shorter array
-- always sorts before a longer one.
instance Ord ByteArray where
  compare ba1@(ByteArray ba1#) ba2@(ByteArray ba2#)
    | sameByteArray ba1# ba2# = EQ
    | otherwise =
        case compare len1 len2 of
          EQ -> compareByteArrays ba1 ba2 len1
          unequal -> unequal
    where
      len1 = sizeofByteArray ba1
      len2 = sizeofByteArray ba2
-- | Concatenate two arrays into a freshly allocated one.
appendByteArray :: ByteArray -> ByteArray -> ByteArray
appendByteArray a b = runST $ do
  let lenA = sizeofByteArray a
      lenB = sizeofByteArray b
  dst <- newByteArray (lenA + lenB)
  copyByteArray dst 0 a 0 lenA
  copyByteArray dst lenA b 0 lenB
  unsafeFreezeByteArray dst
-- | Concatenate a list of arrays with a single allocation of the total size.
concatByteArray :: [ByteArray] -> ByteArray
concatByteArray arrs = runST $ do
  dst <- newByteArray (calcLength arrs 0)
  pasteByteArrays dst 0 arrs
  unsafeFreezeByteArray dst
-- | Copy each array in turn into the destination, starting at the given byte
-- offset and advancing by each array's size. The destination is assumed large
-- enough (as 'concatByteArray' arranges).
pasteByteArrays :: MutableByteArray s -> Int -> [ByteArray] -> ST s ()
pasteByteArrays !_ !_ [] = return ()
pasteByteArrays !dst !off (src : rest) = do
  let len = sizeofByteArray src
  copyByteArray dst off src 0 len
  pasteByteArrays dst (off + len) rest
-- | Add the sizes of all arrays in the list to the accumulator, keeping the
-- accumulator evaluated at every step.
calcLength :: [ByteArray] -> Int -> Int
calcLength arrs total = case arrs of
  [] -> total `seq` total
  arr : rest -> total `seq` calcLength rest (sizeofByteArray arr + total)
-- | A zero-length array, shared as the 'mempty' of 'ByteArray'.
emptyByteArray :: ByteArray
emptyByteArray = runST $ do
  marr <- newByteArray 0
  unsafeFreezeByteArray marr
-- | Concatenate @n@ copies of an array. The total size @n * size@ is computed
-- in 'Int' without an overflow check, so callers must keep it in range.
replicateByteArray :: Int -> ByteArray -> ByteArray
replicateByteArray n arr = runST $ do
  let chunk = sizeofByteArray arr
  dst <- newByteArray (n * chunk)
  let fill i
        | i >= n = return ()
        | otherwise = do
            copyByteArray dst (i * chunk) arr 0 chunk
            fill (i + 1)
  fill 0
  unsafeFreezeByteArray dst
#if MIN_VERSION_base(4,9,0)
-- | Concatenation. @'stimes' k arr@ repeats the bytes of @arr@ @k@ times;
-- non-positive counts give the empty array. It fails with an error if the
-- total size would not fit in an 'Int'.
instance SG.Semigroup ByteArray where
  (<>) = appendByteArray
  sconcat = mconcat . F.toList
  stimes i arr
    | itgr < 1 = emptyByteArray
      -- Repeating zero bytes yields zero bytes for any count. Answering
      -- directly avoids looping up to @i@ times doing empty copies, and
      -- avoids erroring on huge counts of an empty array.
    | sz == 0 = emptyByteArray
      -- Bound the total byte count, not just the repetition count.
      -- 'replicateByteArray' computes @n * size@ in 'Int'. If that overflows,
      -- it would allocate too little and then copy out of bounds.
    | itgr * toInteger sz <= toInteger (maxBound :: Int) =
        replicateByteArray (fromIntegral itgr) arr
    | otherwise = error "Data.Primitive.ByteArray#stimes: cannot allocate the requested amount of memory"
    where
      itgr = toInteger i :: Integer
      sz = sizeofByteArray arr
#endif
-- | The empty array is the identity, and 'mconcat' concatenates with a
-- single allocation. 'mappend' is defined explicitly only for base versions
-- before 4.11.
instance Monoid ByteArray where
  mempty = emptyByteArray
  mconcat = concatByteArray
#if !(MIN_VERSION_base(4,11,0))
  mappend = appendByteArray
#endif
#if __GLASGOW_HASKELL__ >= 708
-- | Treats a 'ByteArray' as a list of bytes ('Word8').
instance Exts.IsList ByteArray where
  type Item ByteArray = Word8
  fromListN = byteArrayFromListN
  fromList = byteArrayFromList
  toList = foldrByteArray (:) []
#endif
-- | Raise an 'error' whose message is prefixed with this module's name and
-- the name of the failing function.
die :: String -> String -> a
die fun problem = error (concat ["Data.Primitive.ByteArray.", fun, ": ", problem])