{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeApplications #-}

module Crc32c
  ( bytes
  , mutableBytes
  , chunks
  ) where

import Control.Monad.Primitive (PrimMonad, PrimState)
import Crc32c.Table (table)
import Data.Bits (shiftR, xor)
import Data.Bytes.Chunks (Chunks (ChunksCons, ChunksNil))
import Data.Bytes.Types (Bytes (Bytes), MutableBytes (MutableBytes))
import qualified Data.Primitive.ByteArray as PM
import qualified Data.Primitive.Ptr as PM
import Data.Word (Word32, Word8)

-- | Compute the checksum of a slice of bytes.
--
-- The accumulator is XOR-ed with @0xFFFFFFFF@ before the pass, and the result
-- is XOR-ed with @0xFFFFFFFF@ again afterwards. Because of this, the value
-- returned for one slice can be passed back in as the accumulator for the
-- next slice, which is exactly how 'chunks' threads it.
bytes :: Word32 -> Bytes -> Word32
bytes !acc0 (Bytes arr off len) =
  xor 0xFFFFFFFF (loop (xor acc0 0xFFFFFFFF) off)
  where
    -- One past the last index of the slice.
    stop = off + len
    -- Feed every byte at indices [off, stop) through 'step', left to right.
    loop !crc !i
      | i >= stop = crc
      | otherwise = loop (step crc (PM.indexByteArray arr i)) (i + 1)

-- | Compute the checksum of a sequence of byte chunks.
--
-- Each chunk is processed in order with 'bytes', and the running checksum
-- produced by one chunk becomes the accumulator for the next. The
-- accumulator is forced on every call.
chunks :: Word32 -> Chunks -> Word32
chunks !acc cs = case cs of
  ChunksNil -> acc
  ChunksCons x xs -> chunks (bytes acc x) xs

-- | Compute the checksum of a slice of mutable bytes.
--
-- The bytes are read one at a time with 'PM.readByteArray' in the ambient
-- 'PrimMonad'. The slice is not copied.
--
-- As in 'bytes', the accumulator is XOR-ed with @0xFFFFFFFF@ before the pass
-- and the result is XOR-ed with it again afterwards. A previous result can
-- therefore be supplied as the accumulator to continue a checksum.
mutableBytes ::
  (PrimMonad m) =>
  Word32 ->
  MutableBytes (PrimState m) ->
  m Word32
{-# INLINEABLE mutableBytes #-}
mutableBytes acc0 (MutableBytes arr off len) =
  fmap (xor 0xFFFFFFFF) (loop (xor acc0 0xFFFFFFFF) off)
  where
    -- One past the last index of the slice.
    stop = off + len
    -- Read and fold each byte at indices [off, stop), left to right.
    loop !crc !i
      | i >= stop = pure crc
      | otherwise = do
          w <- PM.readByteArray arr i
          loop (step crc w) (i + 1)

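-- The checksum functions below operate over vectors of byte arrays. They are
-- commented out for now, but might be revived one day.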
--
-- x -- | Compute the checksum of a slice into an array of unsliced byte arrays.
-- x byteArrays :: Word32 -> UnliftedVector ByteArray -> Word32
-- x byteArrays !acc0 (UnliftedVector arr off len) =
-- x   let go !acc !ix !end = if ix < end
-- x         then
-- x           let b = PM.indexUnliftedArray arr ix
-- x            in go (bytes acc (Bytes b 0 (PM.sizeofByteArray b))) (ix + 1) end
-- x         else acc
-- x    in go acc0 off (off + len)
-- x
-- x -- | Compute the checksum of a slice into a mutable array of
-- x -- unsliced byte arrays.
-- x mutableByteArrays :: PrimMonad m
-- x   => Word32
-- x   -> MutableUnliftedVector (PrimState m) ByteArray
-- x   -> m Word32
-- x {-# inlineable mutableByteArrays #-}
-- x mutableByteArrays acc0 (MutableUnliftedVector arr off len) =
-- x   let go !acc !ix !end = if ix < end
-- x         then do
-- x           b <- PM.readUnliftedArray arr ix
-- x           go (bytes acc (Bytes b 0 (PM.sizeofByteArray b))) (ix + 1) end
-- x         else pure acc
-- x    in go acc0 off (off + len)

-- | Advance the running CRC register by one input byte.
--
-- This is the table-driven, least-significant-byte-first update:
--
-- * the low byte of the register is XOR-ed with the input byte;
-- * the combined byte selects a table entry via 'scramble';
-- * that entry is XOR-ed with the register shifted right by 8 bits.
step :: Word32 -> Word8 -> Word32
step !acc !w = scramble idx `xor` (acc `shiftR` 8)
  where
    -- Truncate the register to its low byte and mix in the input byte.
    idx = fromIntegral @Word32 @Word8 acc `xor` w

-- | Look up the 'Word32' entry of 'table' selected by a byte value.
--
-- The byte is widened to an 'Int' element offset, so the index is always in
-- @0..255@.
--
-- NOTE(review): this assumes 'table' (from "Crc32c.Table") holds at least
-- 256 entries. That module is not visible here, so confirm it there.
scramble :: Word8 -> Word32
scramble = PM.indexOffPtr table . fromIntegral @Word8 @Int