{-# LANGUAGE UnboxedTuples #-}

-- |
-- Module      : Streamly.Internal.Data.MutByteArray.Type
-- Copyright   : (c) 2023 Composewell Technologies
-- License     : BSD3-3-Clause
-- Maintainer  : streamly@composewell.com
-- Stability   : experimental
-- Portability : GHC
--
module Streamly.Internal.Data.MutByteArray.Type
    (
    -- ** MutByteArray
      MutByteArray(..)
    , MutableByteArray
    , getMutableByteArray#

    -- ** Pinning
    , PinnedState(..)
    , isPinned
    , pin
    , unpin

    -- ** Allocation
    , nil
    , newBytesAs
    , new
    , pinnedNew
    , pinnedNewAlignedBytes

    -- ** Access
    , sizeOfMutableByteArray
    , putSliceUnsafe
    , cloneSliceUnsafeAs
    , cloneSliceUnsafe
    , pinnedCloneSliceUnsafe
    , asPtrUnsafe
    ) where

import Control.Monad.IO.Class (MonadIO(..))
#ifdef DEBUG
import Control.Monad (when)
import Debug.Trace (trace)
#endif
import GHC.Base (IO(..))
import System.IO.Unsafe (unsafePerformIO)

import GHC.Exts

--------------------------------------------------------------------------------
-- The ArrayContents type
--------------------------------------------------------------------------------

-- | Where an array's memory should come from: 'Pinned' memory (see
-- 'pinnedNew') or ordinary unpinned memory (see 'new').
data PinnedState = Pinned | Unpinned

-- XXX can use UnliftedNewtypes

-- | A lifted mutable byte array type wrapping @MutableByteArray# RealWorld@.
-- This is a low level array used to back high level unboxed arrays and
-- serialized data.
--
-- Because this is a lifted (boxed) wrapper, values of this type can be
-- stored in ordinary data structures, which the unlifted
-- @MutableByteArray#@ itself cannot. The underlying memory may be pinned or
-- unpinned; use 'isPinned' to check.
data MutByteArray = MutByteArray (MutableByteArray# RealWorld)

{-# DEPRECATED MutableByteArray "Please use MutByteArray instead" #-}
-- | Deprecated alias for 'MutByteArray', kept so that existing code using
-- the old name continues to compile.
type MutableByteArray = MutByteArray

-- | Unwrap a 'MutByteArray' to the raw unlifted
-- @MutableByteArray# RealWorld@ it contains.
{-# INLINE getMutableByteArray# #-}
getMutableByteArray# :: MutByteArray -> MutableByteArray# RealWorld
getMutableByteArray# (MutByteArray mbarr) = mbarr

-- | Return the size of the array in bytes.
--
-- The size is read with 'getSizeofMutableByteArray#', which is
-- state-threaded, so this is an 'IO' action rather than a pure function.
{-# INLINE sizeOfMutableByteArray #-}
sizeOfMutableByteArray :: MutByteArray -> IO Int
sizeOfMutableByteArray (MutByteArray arr) =
    IO $ \s ->
        case getSizeofMutableByteArray# arr s of
            (# s1, i #) -> (# s1, I# i #)

-- | Keep the array reachable by the garbage collector at least until this
-- action runs. 'asPtrUnsafe' calls it after the user action so that the
-- array backing the raw pointer stays alive while the pointer is in use.
{-# INLINE touch #-}
touch :: MutByteArray -> IO ()
touch (MutByteArray contents) =
    IO $ \s -> case touch# contents s of s' -> (# s', () #)

-- XXX We can provide another API for "unsafe" FFI calls passing an unlifted
-- pointer to the FFI call. For unsafe calls we do not need to pin the array.
-- We can pass an unlifted pointer to the FFI routine to avoid GC kicking in
-- before the pointer is wrapped.
--
-- From the GHC manual:
--
-- GHC, since version 8.4, guarantees that garbage collection will never occur
-- during an unsafe call, even in the bytecode interpreter, and further
-- guarantees that unsafe calls will be performed in the calling thread. Making
-- it safe to pass heap-allocated objects to unsafe functions.

-- | Use a @MutByteArray@ as @Ptr a@. This is useful when we want to pass
-- an array as a pointer to some operating system call or to a "safe" FFI call.
--
-- If the array is not pinned it is copied to pinned memory before passing it
-- to the monadic action.
--
-- /Performance Notes:/ Forces a copy if the array is not pinned. It is advised
-- that the programmer keeps this in mind and creates a pinned array
-- opportunistically before this operation occurs, to avoid the cost of a copy
-- if possible.
--
-- /Unsafe/ because of direct pointer operations. The user must ensure that
-- they are writing within the legal bounds of the array.
--
-- The (possibly copied) array is kept alive with 'touch' after the action
-- returns. Note that if the action throws an exception, that final 'touch'
-- is not reached.
--
-- /Pre-release/
--
{-# INLINE asPtrUnsafe #-}
asPtrUnsafe :: MonadIO m => MutByteArray -> (Ptr a -> m b) -> m b
asPtrUnsafe arr f = do
    -- Make sure the memory will not be moved by the GC while the pointer is
    -- in use.
    contents <- liftIO $ pin arr
    -- byteArrayContents# takes a ByteArray#; the coercion only changes the
    -- type of the (same) underlying array so we can take its address.
    let !ptr = Ptr (byteArrayContents#
                     (unsafeCoerce# (getMutableByteArray# contents)))
    r <- f ptr
    -- Keep the pinned array alive until the pointer consumer has finished.
    liftIO $ touch contents
    return r

--------------------------------------------------------------------------------
-- Creation
--------------------------------------------------------------------------------

-- | A zero-length 'MutByteArray' allocated with 'new'.
--
-- NOINLINE keeps the 'unsafePerformIO' allocation as a single top-level
-- value that is evaluated once and shared, instead of being inlined and
-- re-allocated at use sites.
{-# NOINLINE nil #-}
nil :: MutByteArray
nil = unsafePerformIO $ new 0

-- | Allocate a new array of the given size in bytes using 'newByteArray#'
-- (i.e. not explicitly pinned). The contents are not initialized.
--
-- Throws an error (without a stack trace) if the size is negative.
{-# INLINE new #-}
new :: Int -> IO MutByteArray
new nbytes | nbytes < 0 =
  errorWithoutStackTrace "newByteArray: size must be >= 0"
new (I# nbytes) = IO $ \s ->
    case newByteArray# nbytes s of
        (# s', mbarr# #) -> (# s', MutByteArray mbarr# #)

-- | Allocate a new array of the given size in bytes in pinned memory, using
-- 'newPinnedByteArray#'. The contents are not initialized.
--
-- Throws an error (without a stack trace) if the size is negative.
{-# INLINE pinnedNew #-}
pinnedNew :: Int -> IO MutByteArray
pinnedNew nbytes | nbytes < 0 =
  errorWithoutStackTrace "pinnedNewByteArray: size must be >= 0"
pinnedNew (I# nbytes) = IO $ \s ->
    case newPinnedByteArray# nbytes s of
        (# s', mbarr# #) -> (# s', MutByteArray mbarr# #)

-- | @pinnedNewAlignedBytes size align@ allocates a new pinned array of
-- @size@ bytes whose start address has the alignment @align@, using
-- 'newAlignedPinnedByteArray#'. The contents are not initialized.
--
-- Throws an error (without a stack trace) if the size is negative. The
-- alignment is passed through unchecked; presumably it should be a power of
-- two (NOTE(review): confirm against the GHC primop documentation).
{-# INLINE pinnedNewAlignedBytes #-}
pinnedNewAlignedBytes :: Int -> Int -> IO MutByteArray
pinnedNewAlignedBytes nbytes _align | nbytes < 0 =
  errorWithoutStackTrace "pinnedNewAlignedBytes: size must be >= 0"
pinnedNewAlignedBytes (I# nbytes) (I# align) = IO $ \s ->
    case newAlignedPinnedByteArray# nbytes align s of
        (# s', mbarr# #) -> (# s', MutByteArray mbarr# #)

-- | Allocate an uninitialized array of the given size in bytes:
-- 'Unpinned' uses 'new', 'Pinned' uses 'pinnedNew'.
{-# INLINE newBytesAs #-}
newBytesAs :: PinnedState -> Int -> IO MutByteArray
newBytesAs Unpinned = new
newBytesAs Pinned = pinnedNew

-------------------------------------------------------------------------------
-- Copying
-------------------------------------------------------------------------------

-- | Put a sub range of a source array into a subrange of a destination array.
-- This is not safe: it checks the bounds of neither the source array nor
-- the destination array (the checks below are only compiled in when
-- @DEBUG@ is defined).
--
-- Arguments, in order: source array, start offset in the source (bytes),
-- destination array, start offset in the destination (bytes), and the number
-- of bytes to copy.
{-# INLINE putSliceUnsafe #-}
putSliceUnsafe ::
       MonadIO m
    => MutByteArray
    -> Int
    -> MutByteArray
    -> Int
    -> Int
    -> m ()
putSliceUnsafe src srcStartBytes dst dstStartBytes lenBytes = liftIO $ do
#ifdef DEBUG
    srcLen <- sizeOfMutableByteArray src
    dstLen <- sizeOfMutableByteArray dst
    when (srcLen - srcStartBytes < lenBytes)
        $ error $ "putSliceUnsafe: src overflow: start " ++ show srcStartBytes
            ++ " end " ++ show srcLen ++ " len " ++ show lenBytes
    when (dstLen - dstStartBytes < lenBytes)
        $ error $ "putSliceUnsafe: dst overflow: start " ++ show dstStartBytes
            ++ " end " ++ show dstLen ++ " len " ++ show lenBytes
#endif
    -- Unbox the offsets and length for the primop.
    let !(I# srcStartBytes#) = srcStartBytes
        !(I# dstStartBytes#) = dstStartBytes
        !(I# lenBytes#) = lenBytes
    let arrS# = getMutableByteArray# src
        arrD# = getMutableByteArray# dst
    IO $ \s# -> (# copyMutableByteArray#
                    arrS# srcStartBytes# arrD# dstStartBytes# lenBytes# s#
                , () #)

-- | @cloneSliceUnsafeAs ps offset len arr@ allocates a new array of @len@
-- bytes (pinned or unpinned according to @ps@, see 'newBytesAs') and copies
-- the @len@ bytes of @arr@ starting at byte @offset@ into it.
--
-- Unsafe as it does not check whether the start offset and length supplied
-- are valid inside the array.
{-# INLINE cloneSliceUnsafeAs #-}
cloneSliceUnsafeAs :: MonadIO m =>
    PinnedState -> Int -> Int -> MutByteArray -> m MutByteArray
cloneSliceUnsafeAs ps srcOff srcLen src =
    liftIO $ do
        mba <- newBytesAs ps srcLen
        putSliceUnsafe src srcOff mba 0 srcLen
        return mba

-- | @cloneSliceUnsafe offset len arr@ clones a slice of the supplied array
-- starting at the given offset and equal to the given length. The clone is
-- allocated in unpinned memory.
--
-- Unsafe as the offset and length are not validated against the array.
{-# INLINE cloneSliceUnsafe #-}
cloneSliceUnsafe :: MonadIO m => Int -> Int -> MutByteArray -> m MutByteArray
cloneSliceUnsafe = cloneSliceUnsafeAs Unpinned
Unpinned

-- | @pinnedCloneSliceUnsafe offset len arr@ clones a slice of the supplied
-- array starting at the given offset and equal to the given length, like
-- 'cloneSliceUnsafe', but allocates the clone in pinned memory.
--
-- Unsafe as the offset and length are not validated against the array.
{-# INLINE pinnedCloneSliceUnsafe #-}
pinnedCloneSliceUnsafe :: MonadIO m =>
    Int -> Int -> MutByteArray -> m MutByteArray
pinnedCloneSliceUnsafe = cloneSliceUnsafeAs Pinned

-------------------------------------------------------------------------------
-- Pinning & Unpinning
-------------------------------------------------------------------------------

-- | Return 'True' if the array is allocated in pinned memory, as reported by
-- 'isMutableByteArrayPinned#' (any non-zero result counts as pinned).
{-# INLINE isPinned #-}
isPinned :: MutByteArray -> Bool
isPinned (MutByteArray arr#) = isTrue# (isMutableByteArrayPinned# arr# /=# 0#)


-- | @cloneMutableArrayWith# alloc# arr# s#@ allocates a fresh array of the
-- same size as @arr#@ with the supplied allocator, copies the whole of
-- @arr#@ into it and returns the new array.
--
-- The allocator decides what kind of memory the copy lives in; 'pin' passes
-- 'newPinnedByteArray#' and 'unpin' passes 'newByteArray#'.
{-# INLINE cloneMutableArrayWith# #-}
cloneMutableArrayWith#
    :: (Int# -> State# RealWorld -> (# State# RealWorld
                                     , MutableByteArray# RealWorld #))
    -> MutableByteArray# RealWorld
    -> State# RealWorld
    -> (# State# RealWorld, MutableByteArray# RealWorld #)
cloneMutableArrayWith# alloc# arr# s# =
    case getSizeofMutableByteArray# arr# s# of
        (# s1#, i# #) ->
            case alloc# i# s1# of
                (# s2#, arr1# #) ->
                    case copyMutableByteArray# arr# 0# arr1# 0# i# s2# of
                        s3# -> (# s3#, arr1# #)

-- | Return a copy of the array in pinned memory if unpinned, else return the
-- original array.
--
-- The copy is made with 'cloneMutableArrayWith#' using
-- 'newPinnedByteArray#', so it has the same size and contents.
{-# INLINE pin #-}
pin :: MutByteArray -> IO MutByteArray
pin arr@(MutByteArray marr#) =
    if isPinned arr
    then return arr
    else
#ifdef DEBUG
      do
        -- XXX dump stack trace
        trace ("pin: Copying array") (return ())
#endif
        IO
             $ \s# ->
                   case cloneMutableArrayWith# newPinnedByteArray# marr# s# of
                       (# s1#, marr1# #) -> (# s1#, MutByteArray marr1# #)

-- | Return a copy of the array in unpinned memory if pinned, else return the
-- original array.
--
-- The copy is made with 'cloneMutableArrayWith#' using 'newByteArray#', so
-- it has the same size and contents.
{-# INLINE unpin #-}
unpin :: MutByteArray -> IO MutByteArray
unpin arr@(MutByteArray marr#) =
    if not (isPinned arr)
    then return arr
    else
#ifdef DEBUG
      do
        -- XXX dump stack trace
        trace ("unpin: Copying array") (return ())
#endif
        IO
             $ \s# ->
                   case cloneMutableArrayWith# newByteArray# marr# s# of
                       (# s1#, marr1# #) -> (# s1#, MutByteArray marr1# #)