module Foreign.Marshal.Array.Guarded.Debug (
   create,
   alloca,
   ) where

import qualified Foreign.Marshal.Array.Guarded.Plain as Plain

import Foreign.Marshal.Array
         (mallocArray, allocaArray, pokeArray, copyArray, advancePtr)
import Foreign.Marshal.Alloc (free)
import Foreign.Storable (Storable, peekByteOff, sizeOf)
import Foreign.Concurrent (newForeignPtr)
import Foreign.ForeignPtr (ForeignPtr)
import Foreign.Ptr (Ptr, castPtr)

import Control.Exception (finally, onException)
import Control.Monad (when)

import Data.Foldable (for_)
import Data.Word (Word8)


{- |
Array creation with an additional immutability check, electric fences
and pollution of uninitialized and released memory.

The payload is surrounded by fences of 64 elements on each side.
Before the initialization action runs, the payload is filled with a junk pattern.
After the action returns, the fences are checked
and a private copy of the payload is taken.
When the resulting 'ForeignPtr' is finalized:

* The array is compared byte by byte with that copy.
  An error is raised if it was altered after creation.
* The fences are checked again.
* The guarded buffer is polluted and both buffers are freed.
  This happens even if one of the checks fails.

If the initialization action or the creation-time fence check throws,
the guarded buffer is released before the exception propagates.

NOTE: errors raised by the finalizer occur in the finalizer,
not in the code that called 'create'.
-}
create :: (Storable a) => Int -> (Ptr a -> IO b) -> IO (ForeignPtr a, b)
create = flip asTypeOf Plain.create $ \size f -> do
   let border = 64
   let fullSize = size + 2*border
   ptrApre <- mallocArray fullSize
   -- Pollute the whole guarded buffer (fences and payload) before freeing it,
   -- so that dangling reads see garbage instead of plausible data.
   let releaseA = trash fullSize ptrApre >> free ptrApre
   -- Until the finalizer is installed nothing else owns the guarded buffer,
   -- so release it if the action, the fence check
   -- or the allocation of the reference copy throws.
   (ptrsA@(_ptrApre, ptrA, _ptrApost), ptrB, result) <-
      flip onException releaseA $ do
         ptrs@(_, ptr, _) <- fillAll border size ptrApre
         res <- f ptr
         checkAll border ptrs
         copy <- mallocArray size
         copyArray copy ptr size
         return (ptrs, copy, res)
   let checkUnaltered =
          for_ (take (arraySize ptrA size) [0..]) $ \i -> do
             a <- peekByteOff ptrA i
             b <- peekByteOff ptrB i
             when (a/=(b::Word8)) $
                error $ "immutable array was altered at byte position " ++ show i
   -- An error raised by a check must not skip the 'free' calls,
   -- hence 'finally' instead of plain sequencing.
   fptrA <- newForeignPtr ptrA $
      (checkUnaltered >> checkAll border ptrsA)
         `finally` (releaseA >> free ptrB)
   return (fptrA, result)


{- |
Temporary array allocation with electric fences and memory pollution.

The payload handed to the action is surrounded by fences of 64 elements
on each side and is pre-filled with a junk pattern.
When the action returns, both fences are checked.
The whole buffer is then polluted before 'allocaArray' releases it.
If the action throws, these final steps are skipped.
-}
alloca :: (Storable a) => Int -> (Ptr a -> IO b) -> IO b
alloca = flip asTypeOf Plain.alloca $ \size action ->
   let fence = 64
       total = size + 2*fence
   in  allocaArray total $ \start -> do
          layout@(_leading, payload, _trailing) <- fillAll fence size start
          answer <- action payload
          checkAll fence layout
          trash total start
          return answer


-- | Lay out a buffer of @border + size + border@ elements starting at @ptrPre@.
-- Both fences get the fence pattern and the payload gets a junk pattern.
-- Returns the pointers to the leading fence, the payload and the trailing fence.
fillAll :: (Storable a) => Int -> Int -> Ptr a -> IO (Ptr a, Ptr a, Ptr a)
fillAll border size ptrPre = do
   fill ptrPre border fencePattern
   fill payload size [0xDE,0xAD,0xF0,0x0D]
   fill trailer border fencePattern
   return (ptrPre, payload, trailer)
  where
    payload = advancePtr ptrPre border
    trailer = advancePtr payload size
    fencePattern = [0xAB,0xAD,0xCA,0xFE]

-- | Verify that neither fence around the payload was overwritten.
-- The leading fence is checked first, then the trailing one.
-- The payload pointer is ignored.
checkAll :: (Storable a) => Int -> (Ptr a, Ptr a, Ptr a) -> IO ()
checkAll border (leading, _payload, trailing) =
   for_ [("leading", leading), ("trailing", trailing)] $ \(name, fence) ->
      check name fence border [0xAB,0xAD,0xCA,0xFE]

-- | Overwrite @count@ elements starting at @start@ with a garbage pattern.
-- Used on the whole guarded buffer right before it is released.
trash :: (Storable a) => Int -> Ptr a -> IO ()
trash count start = fill start count garbage
  where
    garbage = [0xDE,0xAD,0xBE,0xEF]


{-# INLINE fill #-}
-- | Write @bytes@, repeated cyclically, over the memory
-- occupied by @n@ elements starting at @ptr@.
fill :: (Storable a) => Ptr a -> Int -> [Word8] -> IO ()
fill ptr n bytes =
   let byteCount = arraySize ptr n
       stream = take byteCount (cycle bytes)
   in  pokeArray (castPtr ptr) stream

{-# INLINE check #-}
-- | Verify that the memory occupied by @n@ elements at @ptr@ holds
-- @bytes@ repeated cyclically.
-- The first mismatch raises an 'error' naming the fence (@name@)
-- and the byte offset relative to @ptr@.
check :: (Storable a) => String -> Ptr a -> Int -> [Word8] -> IO ()
check name ptr n bytes =
   mapM_ verify (zip [0 ..] (take (arraySize ptr n) (cycle bytes)))
  where
    verify (offset, expected) = do
      actual <- peekByteOff ptr offset
      when (actual /= (expected :: Word8)) $
        error $ "damaged " ++ name ++ " fence at position " ++ show offset

-- | Byte size of @n@ elements of the pointer's element type.
-- The placeholder element only carries the type to 'arraySizeAux'.
-- 'sizeOf' must not inspect its argument, so the placeholder is never forced.
arraySize :: (Storable a) => Ptr a -> Int -> Int
arraySize ptr n =
   let placeholder = error "arraySize: undefined element"
   in  arraySizeAux ptr n placeholder

{- |
Byte size of @count@ elements: the element size times the count.
The pointer and the element argument only fix the element type.

A fully correct computation might also account for alignment padding.
However, 'mallocArray' uses this same simple arithmetic,
so the sizes computed here agree with the allocations.
-}
arraySizeAux :: (Storable a) => Ptr a -> Int -> a -> Int
arraySizeAux _ count element = sizeOf element * count