{-# LANGUAGE BangPatterns
           , RankNTypes
           , ScopedTypeVariables
           , UnboxedTuples #-}

module Data.ByteArray.NonEmpty
  ( Step (..)

  , fromStep
  , toNonEmpty
  , toList

  , dropByteArray

  , appendByteArray
  , dropAppendByteArray
  , fromStepAppend

  , splitByteArray
  ) where

import           Control.Monad.ST
import           Data.Primitive.ByteArray
import           Data.List.NonEmpty (NonEmpty (..))
import           Data.Word



-- | One step of taking a key apart, in the manner of an unfold: either one
--   more element together with the state to continue from, or the end.
data Step a b
  = -- | The next element and the state that produces the rest.
    More a b
  | -- | No elements remain.
    Done

{-# INLINE fromStep #-}
-- | Materialise a key into a freshly allocated 'ByteArray'.
--
--   The first byte is supplied directly. Further bytes come from applying the
--   step function to the state repeatedly until it yields 'Done'. The result
--   therefore always holds at least one byte.
--
--   No intermediate list is built. Each step extends a chain of write actions,
--   and the array is allocated once the final length is known.
fromStep :: forall x. (x -> Step Word8 x) -> Word8 -> x -> ByteArray
fromStep next = \first -> loop 1 (\dst -> writeByteArray dst 0 first)
  where
    -- @len@ bytes have been collected so far; @filled@ writes all of them
    -- (indices @0 .. len - 1@) into a destination array.
    loop :: Int -> (forall s. MutableByteArray s -> ST s ()) -> x -> ByteArray
    loop !len filled st =
      case next st of
        Done ->
          runST $ do
            dst <- newByteArray len
            filled dst
            unsafeFreezeByteArray dst

        More byte st' ->
          let filled' :: MutableByteArray s -> ST s ()
              filled' dst = filled dst >> writeByteArray dst len byte
          in loop (len + 1) filled' st'



{-# INLINE toNonEmpty #-}
-- | Unpack an array into a non-empty list of its bytes, in index order.
--
--   The array must contain at least one byte. An empty array is reported
--   with an explicit error. Without this check, @indexByteArray arr 0@ would
--   perform an unchecked read past the end of the array.
toNonEmpty :: ByteArray -> NonEmpty Word8
toNonEmpty arr
  | sizeofByteArray arr <= 0 =
      error "Data.ByteArray.NonEmpty.toNonEmpty: empty ByteArray"
  | otherwise = indexByteArray arr 0 :| toListFrom 1 arr

{-# INLINE toList #-}
-- | Every byte of the array as a list, from index 0 upwards.
--   An empty array yields the empty list.
toList :: ByteArray -> [Word8]
toList arr = [indexByteArray arr i | i <- [0 .. sizeofByteArray arr - 1]]

{-# INLINE toListFrom #-}
-- | Bytes of the array from the given index to the end, in index order.
--   A starting index at or beyond the end yields the empty list.
--   The list is produced lazily, one element per index.
toListFrom :: Int -> ByteArray -> [Word8]
toListFrom start arr = walk start
  where
    size = sizeofByteArray arr

    walk :: Int -> [Word8]
    walk i
      | i < size  = indexByteArray arr i : walk (i + 1)
      | otherwise = []



-- | A fresh copy of the array without its first @n@ bytes.
--
--   Like 'Data.List.drop', the count is clamped to the array bounds. A
--   negative @n@ copies the whole array, and an @n@ at or beyond the length
--   yields an empty array. Clamping ensures an out-of-range count never
--   produces a negative allocation size or an out-of-bounds copy.
dropByteArray :: Int -> ByteArray -> ByteArray
dropByteArray n arr =
  runST $ do
    let size = sizeofByteArray arr
        off  = max 0 (min size n)
        len  = size - off
    mbrr <- newByteArray len
    copyByteArray mbrr 0 arr off len
    unsafeFreezeByteArray mbrr



-- | Concatenate two arrays into a single freshly allocated one: all bytes of
--   the first array, followed by all bytes of the second.
appendByteArray :: ByteArray -> ByteArray -> ByteArray
appendByteArray arr brr = runST build
  where
    lenA = sizeofByteArray arr
    lenB = sizeofByteArray brr

    build :: ST s ByteArray
    build = do
      out <- newByteArray (lenA + lenB)
      copyByteArray out 0    arr 0 lenA
      copyByteArray out lenA brr 0 lenB
      unsafeFreezeByteArray out



-- | @dropAppendByteArray n arr brr@ is @arr@ without its first @n@ bytes,
--   followed by all of @brr@, built with a single allocation.
--
--   The drop count is clamped to the bounds of @arr@, as in 'Data.List.drop'.
--   An out-of-range @n@ therefore never produces a negative allocation size
--   or an out-of-bounds copy.
dropAppendByteArray :: Int -> ByteArray -> ByteArray -> ByteArray
dropAppendByteArray n arr brr =
  runST $ do
    let asize = sizeofByteArray arr
        off   = max 0 (min asize n)
        alen  = asize - off
        blen  = sizeofByteArray brr
    mcrr <- newByteArray (alen + blen)
    copyByteArray mcrr 0    arr off alen
    copyByteArray mcrr alen brr 0   blen
    unsafeFreezeByteArray mcrr



{-# INLINE fromStepAppend #-}
-- | Like 'fromStep', but the produced bytes are followed by a copy of the
--   given suffix array.
--
--   Everything goes into a single allocation whose size is the number of
--   produced bytes (at least one) plus the length of the suffix.
fromStepAppend
  :: forall x. (x -> Step Word8 x) -> Word8 -> x -> ByteArray -> ByteArray
fromStepAppend next = \first st0 suffix ->
  let -- @len@ bytes have been collected so far; @filled@ writes all of them
      -- (indices @0 .. len - 1@) into a destination array.
      loop :: Int -> (forall s. MutableByteArray s -> ST s ()) -> x -> ByteArray
      loop !len filled st =
        case next st of
          Done ->
            runST $ do
              let slen = sizeofByteArray suffix
              dst <- newByteArray (len + slen)
              filled dst
              copyByteArray dst len suffix 0 slen
              unsafeFreezeByteArray dst

          More byte st' ->
            let filled' :: MutableByteArray s -> ST s ()
                filled' dst = writeByteArray dst len byte >> filled dst
            in loop (len + 1) filled' st'

  in loop 1 (\dst -> writeByteArray dst 0 first) st0



-- | Strict boxed pair of arrays. It carries both halves out of 'runST';
--   they are then returned as an unboxed tuple.
data Wrap = Wrap {-# UNPACK #-} !ByteArray {-# UNPACK #-} !ByteArray

-- | @splitByteArray offset n arr@ copies two pieces of @arr@ into fresh arrays:
--
--   * the first holds the @n@ bytes starting at index @offset@;
--   * the second holds every byte from index @n@ to the end of @arr@.
--
--   NOTE(review): the second piece starts at @n@, not @offset + n@. The two
--   pieces therefore overlap when both @offset@ and @n@ are positive, and they
--   form a clean split only for @offset == 0@. Confirm against the call sites
--   whether a non-zero offset is ever passed.
splitByteArray :: Int -> Int -> ByteArray -> (# ByteArray, ByteArray #)
splitByteArray offset n arr =
  case runST halves of
    Wrap front back -> (# front, back #)
  where
    halves :: ST s Wrap
    halves = do
      front <- slice offset n
      back  <- slice n (sizeofByteArray arr - n)
      pure (Wrap front back)

    -- Freshly allocated copy of @len@ bytes of @arr@, starting at @from@.
    slice :: Int -> Int -> ST s ByteArray
    slice from len = do
      dst <- newByteArray len
      copyByteArray dst 0 arr from len
      unsafeFreezeByteArray dst