module Hans.Utils.Checksum(
computeChecksum
, finalizeChecksum
, computePartialChecksum
, computePartialChecksumLazy
, clearChecksum
, pokeChecksum
, emptyPartialChecksum
)
where
import Control.Exception (assert)
import Data.Bits (Bits(shiftL,shiftR,complement,clearBit,(.&.),rotate))
import Data.List (foldl')
import Data.Word (Word8,Word16,Word32)
import Foreign.Storable (pokeByteOff)
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString as S
import qualified Data.ByteString.Unsafe as S
-- | Zero the two bytes of @b@ that start at byte offset @off@ (for example a
-- checksum field), then return @b@.
--
-- The write goes through 'S.unsafeUseAsCStringLen', which exposes the
-- buffer behind @b@ without copying it, so @b@ is modified in place and the
-- returned value is that same 'S.ByteString'.
--
-- The offset is validated with 'assert' only, so the check is active only
-- in builds where assertions are enabled.
clearChecksum :: S.ByteString -> Int -> IO S.ByteString
clearChecksum b off = S.unsafeUseAsCStringLen b $ \(ptr, len) -> do
  -- Both bytes, off and off + 1, must lie inside the buffer. The lower
  -- bound matters too: a negative offset would write before its start.
  assert (off >= 0 && off + 1 < len) (pokeByteOff ptr off (0 :: Word16))
  return b
-- | Store the 16-bit value @cs@ into @b@ at byte offset @off@, then return
-- @b@.
--
-- The write goes through 'S.unsafeUseAsCStringLen', which exposes the
-- buffer behind @b@ without copying it, so @b@ is modified in place and the
-- returned value is that same 'S.ByteString'.
--
-- The offset is validated with 'assert' only, so the check is active only
-- in builds where assertions are enabled.
pokeChecksum :: Word16 -> S.ByteString -> Int -> IO S.ByteString
pokeChecksum cs b off = S.unsafeUseAsCStringLen b $ \(ptr, len) -> do
  -- Two bytes are written, at off and off + 1, so both must be inside the
  -- buffer. The earlier condition (off < len + 1) also accepted
  -- off == len - 1 and off == len, which write past the end.
  --
  -- 'rotate' by 8 swaps the two bytes of cs before the Word16 store.
  -- NOTE(review): that yields high-byte-first output only when the host
  -- stores Word16 little-endian -- confirm if big-endian targets matter.
  assert (off >= 0 && off + 1 < len) (pokeByteOff ptr off (rotate cs 8))
  return b
-- | Running state of a checksum that is fed one chunk of bytes at a time.
--
-- Both fields are strict so that a fold over many chunks does not build up
-- unevaluated thunks inside the state.
data PartialChecksum = PartialChecksum
  { pcAccum :: !Word32
    -- ^ Sum of the 16-bit words consumed so far, not yet folded to 16 bits.
  , pcCarry :: !(Maybe Word8)
    -- ^ The unpaired last byte of an odd-length chunk, waiting to be
    -- combined with the first byte of the next chunk; 'Nothing' when the
    -- input so far has even length.
  }
-- | The starting state for an incremental checksum: a zero accumulator and
-- no pending byte.
emptyPartialChecksum :: PartialChecksum
emptyPartialChecksum = PartialChecksum 0 Nothing
-- | Turn a running 'PartialChecksum' into the final 16-bit checksum.
--
-- A pending odd byte is treated as the high byte of a last word whose low
-- byte is zero. The 32-bit sum is then folded twice with 'fold32',
-- truncated to 16 bits and complemented.
finalizeChecksum :: PartialChecksum -> Word16
finalizeChecksum pc = complement (fromIntegral folded)
  where
    -- Add the pending byte, if any, padded with a zero low byte.
    total = maybe (pcAccum pc) padLast (pcCarry pc)

    padLast pending = step (pcAccum pc) pending 0

    -- The first fold can itself carry into bit 16, so fold a second time.
    folded = fold32 (fold32 total)
-- | Compute the final checksum of a single strict 'S.ByteString'.
--
-- @c0@ is the initial value of the 32-bit accumulator; the bytes are summed
-- on top of it with 'computePartialChecksum' and the result is finished
-- with 'finalizeChecksum'.
computeChecksum :: Word32 -> S.ByteString -> Word16
computeChecksum c0 bytes = finalizeChecksum (computePartialChecksum seed bytes)
  where
    -- Start from the caller's accumulator with no pending byte.
    seed = PartialChecksum
      { pcAccum = c0
      , pcCarry = Nothing
      }
-- | Feed every chunk of a lazy 'L.ByteString' into a running checksum.
--
-- The chunks are consumed left to right with a strict left fold, each one
-- going through 'computePartialChecksum', starting from the state @c0@.
computePartialChecksumLazy :: PartialChecksum -> L.ByteString -> PartialChecksum
computePartialChecksumLazy c0 lbs =
  foldl' computePartialChecksum c0 (L.toChunks lbs)
-- | Add the bytes of one strict chunk @b@ to the running checksum @pc@.
--
-- Bytes are consumed in pairs as 16-bit words, first byte high (see
-- 'step'). If the previous chunk left an unpaired byte in 'pcCarry', it is
-- paired with the first byte of @b@ before the rest of @b@ is processed.
-- If @b@ then has an odd number of bytes left, its last byte becomes the
-- new 'pcCarry'. An empty chunk returns @pc@ unchanged.
--
-- NOTE(review): the 'Word32' accumulator is never folded here, so it can
-- wrap once roughly 128 KiB of data have been summed -- confirm that
-- inputs stay below that.
computePartialChecksum :: PartialChecksum -> S.ByteString -> PartialChecksum
computePartialChecksum pc b
  | S.null b  = pc
  | otherwise = case pcCarry pc of
      -- No pending byte: sum the words of this chunk directly. The carry
      -- is forced first so the result holds the byte itself rather than a
      -- thunk that references b.
      Nothing   -> carry `seq` PartialChecksum
        { pcAccum = loop (pcAccum pc) 0
        , pcCarry = carry
        }
      -- Pending byte: pair it with the first byte of this chunk, then
      -- process the remainder with no pending byte.
      Just prev -> computePartialChecksum
        (pc { pcCarry = Nothing
            , pcAccum = step (pcAccum pc) prev (S.unsafeIndex b 0)
            })
        (S.tail b)
  where
    len = S.length b

    -- Length rounded down to an even number: the part made of whole words.
    evenLen = clearBit len 0

    -- The last byte of an odd-length chunk has no partner yet.
    carry
      | odd len   = Just $! S.unsafeIndex b (len - 1)
      | otherwise = Nothing

    -- Sum the words at offsets 0, 2, 4, ... below evenLen. Every index
    -- used here is below evenLen <= len, which is what makes
    -- 'S.unsafeIndex' safe.
    loop !acc off
      | off < evenLen = loop (step acc hi lo) (off + 2)
      | otherwise     = acc
      where
        hi = S.unsafeIndex b off
        lo = S.unsafeIndex b (off + 1)
-- | Add one 16-bit word, given as its high byte @hi@ and low byte @lo@, to
-- the 32-bit accumulator @acc@. The addition wraps modulo 2^32.
step :: Word32 -> Word8 -> Word8 -> Word32
step acc hi lo = acc + word
  where
    -- hi occupies bits 8-15 and lo bits 0-7 of the word.
    word :: Word32
    word = (fromIntegral hi `shiftL` 8) + fromIntegral lo
-- | Add the upper 16 bits of @x@ to its lower 16 bits. The result can still
-- exceed 16 bits (by one carry), which a second application removes.
fold32 :: Word32 -> Word32
fold32 x = lowHalf + highHalf
  where
    lowHalf  = x .&. 0xFFFF
    highHalf = x `shiftR` 16