{-# language BangPatterns #-}
{-# language LambdaCase #-}
module Data.Bytes.Mutable
(
MutableBytes
, takeWhile
, dropWhile
, unsafeTake
, unsafeDrop
, fromMutableByteArray
) where
import Prelude hiding (takeWhile,dropWhile)
import Data.Bytes.Types (MutableBytes(MutableBytes))
import Data.Primitive (MutableByteArray)
import Data.Word (Word8)
import Control.Monad.Primitive (PrimMonad,PrimState)
import qualified Data.Primitive as PM
-- | Keep the longest prefix of the slice whose bytes all satisfy the
-- monadic predicate. The result refers to the same underlying array
-- as the argument; nothing is copied.
takeWhile :: PrimMonad m
  => (Word8 -> m Bool)
  -> MutableBytes (PrimState m)
  -> m (MutableBytes (PrimState m))
{-# inline takeWhile #-}
takeWhile predicate bytes = keepPrefix <$> countWhile predicate bytes
  where
    -- Narrow the original slice to the number of matching bytes.
    keepPrefix matched = unsafeTake matched bytes
-- | Discard the longest prefix of the slice whose bytes all satisfy
-- the monadic predicate. The result refers to the same underlying
-- array as the argument; nothing is copied.
dropWhile :: PrimMonad m
  => (Word8 -> m Bool)
  -> MutableBytes (PrimState m)
  -> m (MutableBytes (PrimState m))
{-# inline dropWhile #-}
dropWhile predicate bytes = skipPrefix <$> countWhile predicate bytes
  where
    -- Advance the original slice past the matching bytes.
    skipPrefix matched = unsafeDrop matched bytes
-- | Restrict the slice to its first @n@ bytes. The offset is kept and
-- the length is overwritten with @n@. This is unsafe because @n@ is
-- not checked against the slice's current length.
unsafeTake :: Int -> MutableBytes s -> MutableBytes s
{-# inline unsafeTake #-}
unsafeTake n slice = case slice of
  MutableBytes arr off _ -> MutableBytes arr off n
-- | Skip the first @n@ bytes of the slice: the offset moves forward by
-- @n@ and the length shrinks by @n@. This is unsafe because @n@ is
-- not checked against the slice's current length.
unsafeDrop :: Int -> MutableBytes s -> MutableBytes s
{-# inline unsafeDrop #-}
unsafeDrop n (MutableBytes arr off len) = MutableBytes arr newOff newLen
  where
    newOff = off + n
    newLen = len - n
-- | View an entire mutable byte array as a slice. The slice starts at
-- offset zero and its length is the array's size at the time of the
-- call, which is why the result is monadic.
fromMutableByteArray :: PrimMonad m
  => MutableByteArray (PrimState m)
  -> m (MutableBytes (PrimState m))
{-# inline fromMutableByteArray #-}
fromMutableByteArray mba =
  MutableBytes mba 0 <$> PM.getSizeofMutableByteArray mba
-- Internal helper: count how many leading bytes of the slice satisfy
-- the monadic predicate. Bytes are examined in order, and the
-- predicate is run once per byte examined. Scanning stops at the
-- first byte rejected by the predicate or at the end of the slice,
-- whichever comes first.
countWhile :: PrimMonad m
  => (Word8 -> m Bool)
  -> MutableBytes (PrimState m)
  -> m Int
{-# inline countWhile #-}
countWhile k (MutableBytes arr off0 len0) = loop off0 len0 0
  where
    -- ix: index of the next byte to read
    -- remaining: number of bytes of the slice not yet examined
    -- acc: number of bytes accepted so far
    loop !ix !remaining !acc
      | remaining > 0 = do
          w <- PM.readByteArray arr ix
          accepted <- k w
          if accepted
            then loop (ix + 1) (remaining - 1) (acc + 1)
            else pure acc
      | otherwise = pure acc