{-# LANGUAGE ScopedTypeVariables #-}
module Hedgehog.Classes.Storable (storableLaws) where
import Hedgehog
import Hedgehog.Classes.Common
import Hedgehog.Internal.Gen (sample)
import qualified Data.List as List
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Foreign.Marshal.Alloc
import Foreign.Marshal.Array
import GHC.Ptr (Ptr(..), plusPtr)
import Foreign.Storable (Storable(..))
import System.IO.Unsafe (unsafePerformIO)
-- | The 'Storable' laws, each checked against values drawn from @gen@.
--
-- Besides the set/get, get/set and list round-trip laws, four laws compare
-- the offset accessors ('peekElemOff', 'pokeElemOff', 'peekByteOff',
-- 'pokeByteOff') against plain 'peek'/'poke' at the equivalent address
-- computed with 'plusPtr'.
storableLaws :: (Eq a, Show a, Storable a) => Gen a -> Laws
storableLaws gen = Laws "Storable"
  [ ("Set-Get (you get back what you put in)", storableSetGet gen)
  , ("Get-Set (putting back what you got out has no effect)", storableGetSet gen)
  , ("List Conversion Roundtrips", storableList gen)
  , ("peekElemOff a i ≡ peek (plusPtr a (i * sizeOf undefined))", storablePeekElem gen)
  -- The two poke laws are labelled with the poke operation they check.
  , ("pokeElemOff a i x ≡ poke (plusPtr a (i * sizeOf undefined)) x ≡ id", storablePokeElem gen)
  , ("peekByteOff a i ≡ peek (plusPtr a i)", storablePeekByte gen)
  , ("pokeByteOff a i x ≡ poke (plusPtr a i) x ≡ id", storablePokeByte gen)
  ]
-- | Allocate an array of @len@ elements with 'newArray', filled with values
-- drawn from @gen@ via 'sample'.
--
-- The values are sampled outside of 'forAll', so they are neither shown in
-- a failure report nor shrunk. Only use this where the array contents
-- cannot influence the outcome of a property. The caller owns the returned
-- buffer and must release it with 'free'. A non-positive @len@ produces an
-- empty list of elements (no sampling happens).
genArray :: forall a. (Storable a) => Gen a -> Int -> IO (Ptr a)
genArray gen len = do
  -- Traversing a finite range always terminates; a counter compared with
  -- (==) would run forever when len is negative.
  xs <- traverse (const (sample gen)) [1 .. len]
  newArray xs
-- | Law: @peekElemOff a i ≡ peek (plusPtr a (i * sizeOf undefined))@.
--
-- Marshals a generated non-empty list into a temporary array. It then
-- checks that reading slot @ix@ with 'peekElemOff' matches a 'peek' at the
-- equivalent byte address.
storablePeekElem :: forall a. (Eq a, Show a, Storable a) => Gen a -> Property
storablePeekElem gen = property $ do
  as <- forAll $ genSmallNonEmptyList gen
  let len = List.length as
  ix <- forAll $ Gen.int (Range.linear 0 (len - 1))
  -- The array holds exactly the reported list, so a failing case is
  -- reproducible and shrinks. 'withArray' releases the buffer even if an
  -- exception escapes, and 'evalIO' reports such an exception as a failure.
  (viaElemOff, viaPeek) <- evalIO $ withArray as $ \addr -> do
    x <- peekElemOff addr ix
    y <- peek (addr `plusPtr` (ix * sizeOf (undefined :: a)))
    pure (x, y)
  viaElemOff === viaPeek
-- | Law: @pokeElemOff a i x@ behaves like
-- @poke (plusPtr a (i * sizeOf x)) x@.
--
-- Writes @x@ into slot @ix@ of an array built from a generated list. The
-- first write uses 'pokeElemOff' and the second uses 'poke' at the
-- equivalent byte address. The slot is read back after each write, and the
-- two reads must agree.
storablePokeElem :: forall a. (Eq a, Show a, Storable a) => Gen a -> Property
storablePokeElem gen = property $ do
  as <- forAll $ genSmallNonEmptyList gen
  x <- forAll gen
  let len = List.length as
  ix <- forAll $ Gen.int (Range.linear 0 (len - 1))
  -- The array holds exactly the reported list, so a failing case is
  -- reproducible and shrinks. 'withArray' releases the buffer even if an
  -- exception escapes.
  (afterElemOff, afterPoke) <- evalIO $ withArray as $ \addr -> do
    pokeElemOff addr ix x
    u <- peekElemOff addr ix
    poke (addr `plusPtr` (ix * sizeOf x)) x
    v <- peekElemOff addr ix
    pure (u, v)
  afterElemOff === afterPoke
-- | Law: @peekByteOff a i ≡ peek (plusPtr a i)@.
--
-- Reads a value at byte offset @off@ of an array built from a generated
-- list, once with 'peekByteOff' and once with 'peek' on the offset pointer.
-- The two reads must agree.
storablePeekByte :: forall a. (Eq a, Show a, Storable a) => Gen a -> Property
storablePeekByte gen = property $ do
  as <- forAll $ genSmallNonEmptyList gen
  let len = List.length as
  -- Offsets in [0, len-1] keep the read in bounds. Its last byte is at
  -- off + s - 1 <= len + s - 2, and the buffer has len * s bytes, where
  -- s = sizeOf a. This holds because (len - 1) * (s - 1) >= 0.
  off <- forAll $ Gen.int (Range.linear 0 (len - 1))
  -- The array holds exactly the reported list, so a failing case is
  -- reproducible and shrinks. 'withArray' releases the buffer even if an
  -- exception escapes.
  (viaByteOff, viaPeek) <- evalIO $ withArray as $ \addr -> do
    x :: a <- peekByteOff addr off
    y :: a <- peek (addr `plusPtr` off)
    pure (x, y)
  viaByteOff === viaPeek
-- | Law: @pokeByteOff a i x@ behaves like @poke (plusPtr a i) x@.
--
-- Writes @x@ at byte offset @off@ of an array built from a generated list.
-- The first write uses 'pokeByteOff' and the second uses 'poke' on the
-- offset pointer. The value at @off@ is read back after each write, and the
-- two reads must agree.
storablePokeByte :: forall a. (Eq a, Show a, Storable a) => Gen a -> Property
storablePokeByte gen = property $ do
  as <- forAll $ genSmallNonEmptyList gen
  x <- forAll gen
  let len = List.length as
  -- Offsets in [0, len-1] keep the access inside the len * sizeOf x byte
  -- buffer, since (len - 1) * (sizeOf x - 1) >= 0.
  off <- forAll $ Gen.int (Range.linear 0 (len - 1))
  -- The array holds exactly the reported list, so a failing case is
  -- reproducible and shrinks. 'withArray' releases the buffer even if an
  -- exception escapes.
  (afterByteOff, afterPoke) <- evalIO $ withArray as $ \addr -> do
    pokeByteOff addr off x
    u :: a <- peekByteOff addr off
    poke (addr `plusPtr` off) x
    v :: a <- peekByteOff addr off
    pure (u, v)
  afterByteOff === afterPoke
-- | Law (set-get): after @pokeElemOff ptr ix a@, @peekElemOff ptr ix@
-- returns @a@.
--
-- The surrounding array contents come from 'genArray' and do not affect
-- the result, because only slot @ix@ is written and read.
storableSetGet :: forall a. (Eq a, Show a, Storable a) => Gen a -> Property
storableSetGet gen = property $ do
  a <- forAll gen
  -- The array must hold at least one element. With len = 0 the index range
  -- below would be [0, -1], and the write would land outside the empty
  -- allocation.
  len <- forAll $ Gen.int (Range.linear 1 20)
  ix <- forAll $ Gen.int (Range.linear 0 (len - 1))
  a' <- evalIO $ do
    ptr <- genArray gen len
    pokeElemOff ptr ix a
    got <- peekElemOff ptr ix
    free ptr
    pure got
  a === a'
-- | Law (get-set): writing back the value just read from a slot leaves the
-- array unchanged.
--
-- The generated list is marshalled into @ptrA@, and a snapshot of it is
-- copied into @ptrB@. Slot @ix@ of @ptrA@ is then read and written back, and
-- the two arrays are compared element-wise with 'arrayEq'.
storableGetSet :: forall a. (Eq a, Show a, Storable a) => Gen a -> Property
storableGetSet gen = property $ do
  as <- forAll $ genSmallNonEmptyList gen
  let len = List.length as
  ix <- forAll $ Gen.int (Range.linear 0 (len - 1))
  -- ptrB only needs space for the snapshot, which copyArray fills
  -- completely, so it is allocated without initial contents. Both buffers
  -- are released even if an exception escapes.
  res <- evalIO $ withArray as $ \ptrA -> allocaArray len $ \ptrB -> do
    copyArray ptrB ptrA len
    a <- peekElemOff ptrA ix
    pokeElemOff ptrA ix a
    arrayEq ptrA ptrB len
  res === True
-- | Law: marshalling a list into an array and reading every element back
-- with 'peekElemOff' (indices @0 .. len-1@, in order) reproduces the list.
storableList :: forall a. (Eq a, Show a, Storable a) => Gen a -> Property
storableList gen = property $ do
  as <- forAll $ genSmallNonEmptyList gen
  let len = List.length as
  -- 'withArray' releases the buffer even if an exception escapes.
  asNew <- evalIO $ withArray as $ \ptr ->
    traverse (peekElemOff ptr) [0 .. len - 1]
  as === asNew
-- | Compare the first @len@ elements of two arrays, index by index from 0.
--
-- Returns 'False' at the first index whose elements differ (later indices
-- are not read). Returns 'True' if every compared pair is equal, which
-- includes the case @len <= 0@ where nothing is read.
arrayEq :: forall a. (Eq a, Storable a) => Ptr a -> Ptr a -> Int -> IO Bool
arrayEq ptrA ptrB len = step 0
  where
    step :: Int -> IO Bool
    step i
      | i >= len = pure True
      | otherwise = do
          -- Reads ptrA's element before ptrB's, matching a left-to-right scan.
          same <- (==) <$> peekElemOff ptrA i <*> peekElemOff ptrB i
          if same then step (i + 1) else pure False