{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures   #-}
{-# LANGUAGE DataKinds        #-}
{-# LANGUAGE MonoLocalBinds   #-}
module Buffer
       ( BufferPtr
       , Buffer
       , withBufferPtr
       , unsafeWithBufferPtr
       , memsetBuffer
       -- ** Some unsafe functions
       , unsafeGetBufferPointer
       , bufferSize
       ) where

import GHC.TypeLits

import Raaz.Core
import Raaz.Core.Memory (Access(..))
import Implementation

-- | A buffer @buf :: Buffer n@ is a memory element that has enough
-- space for the @n@ blocks of the primitive and, if required by the
-- implementation, any additional blocks that might be used for
-- padding of the last chunk of message.
--
-- The wrapped pointer is the raw start of the allocation made by the
-- 'Memory' instance; 'unsafeGetBufferPointer' aligns it (via
-- 'nextAlignedPtr') before it is handed out as a 'BufferPtr'.
newtype Buffer (n :: Nat) = Buffer { unBuffer :: Ptr Word8 }


-- | This takes a buffer pointer action and runs it with the aligned
-- pointer underlying the buffer. The action receives no size
-- information, so it must itself stay within the buffer's bounds
-- (see 'bufferSize'); prefer 'withBufferPtr', which also passes the
-- usable size.
--
-- NOTE: the 'KnownNat' constraint is not needed by the body; it is
-- kept so the exported signature is unchanged.
unsafeWithBufferPtr :: KnownNat n
                    => (BufferPtr -> a)
                    -> Buffer n
                    -> a
unsafeWithBufferPtr action = action . unsafeGetBufferPointer

-- | Run the action on the aligned buffer pointer, also passing it the
-- number of blocks ('bufferSize') that can be safely processed inside
-- the buffer.
withBufferPtr :: KnownNat n
              => (BufferPtr -> BlockCount Prim -> a)
              -> Buffer n
              -> a
withBufferPtr action buf = unsafeWithBufferPtr act buf
  where
    -- @pure buf@ is a @Proxy (Buffer n)@; it only carries the
    -- type-level size @n@ to 'bufferSize'.
    act ptr = action ptr (bufferSize (pure buf))

-- | Memset the given buffer: every byte of the first 'bufferSize'
-- blocks (starting at the aligned buffer pointer) is set to the given
-- byte.
--
-- NOTE(review): blocks counted only by the internal
-- @actualBufferSize@ (the implementation's additional padding blocks)
-- are not touched; confirm this is intended.
memsetBuffer :: KnownNat n => Word8 -> Buffer n -> IO ()
memsetBuffer byte = withBufferPtr $ \ ptr nblocks -> memset ptr byte nblocks


-- WARNING: Not to be exposed, else it can be confused with
-- `bufferSize`. Internal function used by allocation.
--
-- The number of blocks that must actually be allocated for the
-- buffer: the @n@ usable blocks ('bufferSize') plus the
-- implementation's `additionalBlocks` (space for padding the last
-- chunk of a message).
actualBufferSize :: KnownNat n => Proxy (Buffer n) -> BlockCount Prim
actualBufferSize bproxy = bufferSize bproxy <> additionalBlocks

{-# INLINE bufferSize #-}
-- | The size of data (measured in blocks) that can be safely
-- processed inside this buffer. This is exactly the type-level @n@;
-- it does not include any additional padding blocks the
-- implementation reserves at allocation time.
bufferSize :: KnownNat n => Proxy (Buffer n) -> BlockCount Prim
bufferSize = flip blocksOf Proxy . fromIntegral . natVal . nProxy
  where -- Strip the 'Buffer' wrapper so 'natVal' can see @n@.
        nProxy :: Proxy (Buffer n) -> Proxy n
        nProxy _ = Proxy

-- | Get the underlying pointer for the buffer. The raw pointer is
-- aligned with 'nextAlignedPtr' and then cast to a 'BufferPtr'. The
-- result carries no bounds information, which is why this is unsafe.
unsafeGetBufferPointer :: Buffer n -> BufferPtr
unsafeGetBufferPointer = castPointer . nextAlignedPtr . unBuffer


instance KnownNat n => Memory (Buffer n) where
  -- Allocate room for 'actualBufferSize' blocks (usable blocks plus
  -- 'additionalBlocks'), sized with 'atLeastAligned' for the alignment
  -- of 'BufferPtr'. NOTE(review): this presumably reserves enough slack
  -- that the pointer aligned by 'unsafeGetBufferPointer' still has
  -- 'actualBufferSize' blocks behind it; confirm against
  -- 'atLeastAligned' in Raaz.Core.
  memoryAlloc = allocThisBuffer
    where allocThisBuffer = Buffer <$> pointerAlloc sz
          -- Used only to recover the type-level size from the
          -- allocator's own type.
          bufferProxy     :: Alloc (Buffer n) -> Proxy (Buffer n)
          bufferProxy _   = Proxy
          algn            = ptrAlignment (Proxy :: Proxy BufferPtr)
          sz              = atLeastAligned
                              (actualBufferSize $ bufferProxy allocThisBuffer)
                              algn

  -- The raw (not yet aligned) start of the allocation.
  unsafeToPointer = unBuffer


instance KnownNat n => ReadAccessible (Buffer n) where
  -- A single region starting at the aligned buffer pointer and spanning
  -- 'bufferSize' blocks; the implementation's padding blocks are not
  -- part of it.
  readAccess buf = unsafeWithPointerCast makeAccess $ unsafeGetBufferPointer buf
    where makeAccess bptr = [ Access bptr $ inBytes $ bufferSize $ pure buf ]

  -- Run 'adjustEndian' over the @n@ blocks at the aligned buffer
  -- pointer, i.e. the same blocks exposed by 'readAccess'.
  beforeReadAdjustment buf = unsafeWithPointer (adjust (Proxy :: Proxy Prim) nelems)
                                 $ unsafeGetBufferPointer buf
    where getProxy :: Buffer n -> Proxy n
          getProxy _ = Proxy
          -- Same Integer -> Int conversion as used by 'bufferSize'.
          nelems     = fromIntegral $ natVal $ getProxy buf
          -- The proxy argument pins the primitive whose block pointer
          -- type the raw pointer is read at.
          adjust     :: Primitive prim => Proxy prim -> Int -> BlockPtr prim -> IO ()
          adjust _ count ptr = adjustEndian ptr count


instance KnownNat n => WriteAccessible (Buffer n) where
  -- The writable region is the same as the readable one.
  writeAccess = readAccess
  -- The same endian adjustment is reused after writes.
  -- NOTE(review): this relies on 'adjustEndian' being its own inverse
  -- (a byte swap or a no-op); confirm against Raaz.Core.
  afterWriteAdjustment = beforeReadAdjustment