{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE Haskell2010 #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE Trustworthy #-}
module Data.XOR
(
xor32StrictByteString
, xor32StrictByteString'
, xor32LazyByteString
, xor32ShortByteString
, xor32CStringLen
, xor8StrictByteString
, xor8LazyByteString
, xor8ShortByteString
, xor8CStringLen
) where
import Control.Exception (assert)
import Control.Monad (void)
import Control.Monad.ST (ST, runST)
import Data.Bits
import Data.Tuple (swap)
import Endianness (ByteOrder (..), Word32, Word8, byteSwap32,
targetByteOrder)
import Foreign.C (CStringLen)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, alignPtr, castPtr, minusPtr, plusPtr)
import Foreign.Storable (peek, poke)
import System.IO.Unsafe (unsafeDupablePerformIO)
import qualified GHC.Exts as X
import qualified GHC.ST as X
import qualified GHC.Word as X
import qualified Data.ByteString as BS
import Data.ByteString.Internal (mallocByteString, memcpy)
import qualified Data.ByteString.Internal as BS (ByteString (..))
import qualified Data.ByteString.Lazy.Internal as BL (ByteString (..))
import qualified Data.ByteString.Short as SBS
import Data.ByteString.Short.Internal (ShortByteString (SBS))
-- | XOR a strict 'BS.ByteString' with a repeating 4-byte mask.
--
-- The mask is applied most-significant byte first: input byte @i@ is XORed
-- with mask byte @i `mod` 4@ counted from the most significant end. A zero
-- mask or an empty input is returned as-is without copying.
xor32StrictByteString :: Word32 -> BS.ByteString -> BS.ByteString
xor32StrictByteString msk bs
  | msk == 0 || BS.null bs = bs
  | otherwise = fst (xor32StrictByteString'' msk bs)
-- | Like 'xor32StrictByteString', but also returns the mask to continue
-- with. That mask is the input mask rotated left by 8 bits for every byte
-- processed, so a long stream can be masked one piece at a time.
xor32StrictByteString' :: Word32 -> BS.ByteString -> (Word32,BS.ByteString)
xor32StrictByteString' msk bs
  | msk == 0 = (0, bs)
  | BS.null bs = (msk, bs)
  | otherwise = case xor32StrictByteString'' msk bs of
      (masked, nextMsk) -> (nextMsk, masked)
-- | XOR a lazy 'BL.ByteString' with a repeating 4-byte mask
-- (most significant byte first).
--
-- Chunk boundaries are preserved. The mask left over after each chunk is
-- carried into the next one, so the result equals masking the concatenated
-- bytes. A zero mask is the identity.
xor32LazyByteString :: Word32 -> BL.ByteString -> BL.ByteString
xor32LazyByteString msk0
  | msk0 == 0 = id
  | otherwise = walk msk0
  where
    walk _ BL.Empty = BL.Empty
    walk m (BL.Chunk c rest) =
      case xor32StrictByteString'' m c of
        (c', m') -> BL.Chunk c' (walk m' rest)
{-# INLINE xor32StrictByteString'' #-}
-- | Copy a strict 'BS.ByteString' into a freshly allocated buffer and XOR
-- the copy with a repeating 4-byte mask (most significant byte first).
-- Returns the masked copy and the mask to continue with for the next byte.
--
-- If the new buffer is 4-byte aligned, the word-at-a-time path is used
-- directly. Otherwise the general 'xor32Ptr' splits off the unaligned parts.
xor32StrictByteString'' :: Word32 -> BS.ByteString -> (BS.ByteString,Word32)
xor32StrictByteString'' msk0 (BS.PS fptr off len) =
  unsafeCreate' len $ \dst ->
    withForeignPtr fptr $ \src -> do
      memcpy dst (src `plusPtr` off) (fromIntegral len)
      if remPtr dst 4 /= 0
        then xor32Ptr msk0 dst len
        else do
          -- Whole words first (the mask is unchanged after a multiple of
          -- 4 bytes), then the 0-3 byte tail.
          let (whole, tailLen) = len `quotRem` 4
              bodyLen = 4 * whole
          xor32PtrAligned msk0 (castPtr dst) bodyLen
          xor32PtrNonAligned msk0 (dst `plusPtr` bodyLen) tailLen
-- | XOR a 'SBS.ShortByteString' with a repeating 4-byte mask, applying the
-- mask's most significant byte to the first byte.
--
-- A zero mask or an empty input is returned as-is. Otherwise a fresh copy
-- is built.
xor32ShortByteString :: Word32 -> SBS.ShortByteString -> SBS.ShortByteString
xor32ShortByteString mask0be sbs
  | mask0be == 0 || SBS.null sbs = sbs
  | otherwise = runST $ do
      dst <- newSBS total
      -- Bulk: every complete 32-bit word, XORed with the mask in native order.
      let wordsFrom w
            | w == fullWords = return ()
            | otherwise = do
                writeWord32Array dst w (indexWord32Array sbs w `xor` nativeMask)
                wordsFrom (w + 1)
      wordsFrom 0
      -- Tail: the 0-3 leftover bytes. Leftover byte k uses mask byte k,
      -- counted from the most significant end.
      let tailByte k =
            let ofs = 4 * fullWords + k
                m8 = fromIntegral (mask0be `shiftR` (24 - 8 * k))
            in writeWord8Array dst ofs (indexWord8Array sbs ofs `xor` m8)
      mapM_ tailByte [0 .. leftover - 1]
      unsafeFreezeSBS dst
  where
    total = SBS.length sbs
    (fullWords, leftover) = total `quotRem` 4
    -- Byte-swap on little-endian targets so the mask's in-memory byte order
    -- is most-significant-first, matching the byte order of the input.
    nativeMask = case targetByteOrder of
      LittleEndian -> byteSwap32 mask0be
      BigEndian -> mask0be
{-# INLINEABLE xor32CStringLen #-}
-- | XOR a C buffer in place with a repeating 4-byte mask (most significant
-- byte first). Returns the mask to continue with for the byte that follows
-- the buffer.
xor32CStringLen :: Word32 -> CStringLen -> IO Word32
xor32CStringLen m = uncurry (xor32Ptr m . castPtr)
{-# INLINEABLE xor32Ptr #-}
-- | XOR @n@ bytes starting at @p0@ in place with a repeating 4-byte mask.
-- Returns the mask to use for the byte after the last one processed.
--
-- The mask's most significant byte is applied to the first byte. Each byte
-- consumed rotates the mask left by 8 bits. The range is split into:
--
-- * an unaligned head,
-- * a run of 4-byte-aligned words, processed a word at a time,
-- * an unaligned tail.
--
-- A zero mask is a no-op that returns @0@. A zero size returns the mask
-- unchanged. A negative size throws an 'IOError' (via 'fail') and does not
-- touch memory.
xor32Ptr :: Word32 -> Ptr Word8 -> Int -> IO Word32
xor32Ptr 0 !_ !_ = return 0
xor32Ptr !mask0 !_ 0 = return mask0
xor32Ptr !mask0 !p0 !n
  -- The sign check must come before the short-input case. Testing @n < 4@
  -- first also matches every negative @n@, which made this guard
  -- unreachable and sent negative sizes into an unbounded byte loop.
  | n < 0 = fail "xor32Ptr: negative size argument not supported"
  | n < 4 = xor32PtrNonAligned mask0 p0 n
xor32Ptr !mask0 !p0 !n
  | assert (p0 <= p1 && p1 <= p2 && p2 <= p3 && n0 < 4 && n2 < 4) False = undefined
  -- No whole aligned word fits inside the range: process it bytewise.
  | n1 == 0 = xor32PtrNonAligned mask0 p0 n
  -- Start is already aligned. The word run covers a multiple of 4 bytes,
  -- so it leaves the mask unchanged and the tail starts with mask0.
  | n0 == 0 = do
      xor32PtrAligned mask0 p1 n1
      xor32PtrNonAligned mask0 p2 n2
  | otherwise = do
      mask1 <- xor32PtrNonAligned mask0 p0 n0
      xor32PtrAligned mask1 p1 n1
      xor32PtrNonAligned mask1 p2 n2
  where
    p1 = castPtr (alignPtr p0 d)   -- first 4-byte-aligned address >= p0
    p2 = alignPtrDown p3 d         -- last 4-byte-aligned address <= p3
    p3 = plusPtr p0 n              -- one past the end of the range
    d = 4
    n0 = p1 `minusPtr` p0          -- unaligned head length (0..3)
    n1 = p2 `minusPtr` p1          -- aligned body length (multiple of 4)
    n2 = p3 `minusPtr` p2          -- unaligned tail length (0..3)
-- | XOR the @n@ bytes at @p0@ one byte at a time, taking mask bytes from
-- the most significant end of @mask0@.
--
-- Returns the mask rotated left by 8 bits per byte consumed, so its most
-- significant byte is the one that applies to the next byte. Sizes 0-3 are
-- unrolled; larger sizes use a pointer loop.
xor32PtrNonAligned :: Word32 -> Ptr Word8 -> Int -> IO Word32
xor32PtrNonAligned mask0 p0 n = case n of
    0 -> return mask0
    1 -> do
      xorAt 0
      return (rotateL mask0 8)
    2 -> do
      xorAt 0
      xorAt 1
      return (rotateL mask0 16)
    3 -> do
      xorAt 0
      xorAt 1
      xorAt 2
      return (rotateL mask0 24)
    _ -> byteLoop mask0 p0
  where
    -- XOR offset k with mask byte k (0 = most significant byte).
    xorAt k = xor8Ptr1 (fromIntegral (mask0 `shiftR` (24 - 8 * k))) (p0 `plusPtr` k)
    end = p0 `plusPtr` n
    byteLoop m p
      | p == end = return m
      | otherwise = do
          let m' = rotateL m 8
          xor8Ptr1 (fromIntegral m') p
          byteLoop m' (p `plusPtr` 1)
-- | XOR @n@ bytes at the 4-byte-aligned address @p0@ in place, one 'Word32'
-- at a time.
--
-- @n@ is expected to be a multiple of 4 (checked by 'assert'). That many
-- bytes is a whole number of mask periods, so the mask needs no rotation
-- and nothing is returned.
--
-- @mask0be@ holds the mask with its first byte most significant. It is
-- byte-swapped on little-endian targets so that its in-memory layout
-- matches that order.
xor32PtrAligned :: Word32 -> Ptr Word32 -> Int -> IO ()
xor32PtrAligned mask0be p0 n
  | n == 0 = return ()
  | otherwise = assert (p0 `remPtr` 4 == 0 && n `rem` 4 == 0) (step 0)
  where
    step off
      | off == n = return ()
      | otherwise = xor32Ptr1 nativeMask (p0 `plusPtr` off) >> step (off + 4)
    nativeMask = case targetByteOrder of
      LittleEndian -> byteSwap32 mask0be
      BigEndian -> mask0be
-- | The address held by a pointer, modulo @d@. It is computed directly on
-- the unboxed 'X.Addr#' with 'X.remAddr#'. Every caller in this module
-- passes @d = 4@, i.e. it measures misalignment from a 4-byte boundary.
remPtr :: Ptr a -> Int -> Int
remPtr (X.Ptr x) (X.I# d) = X.I# (X.remAddr# x d)
-- | Round a pointer down to the nearest address that is a multiple of @i@.
-- This is the downward counterpart of 'alignPtr'. An already-aligned
-- pointer keeps its address.
alignPtrDown :: Ptr a -> Int -> Ptr a
alignPtrDown p i = p `plusPtr` negate (remPtr p i)
-- | XOR the single byte at @ptr@ in place with @msk@.
xor8Ptr1 :: Word8 -> Ptr Word8 -> IO ()
xor8Ptr1 msk ptr = peek ptr >>= poke ptr . xor msk
-- | XOR the single 'Word32' at @ptr@ in place with @msk@. Both are taken in
-- native byte order.
xor32Ptr1 :: Word32 -> Ptr Word32 -> IO ()
xor32Ptr1 msk ptr = peek ptr >>= \old -> poke ptr (old `xor` msk)
{-# INLINE unsafeCreate' #-}
-- | Allocate a fresh @l0@-byte 'BS.ByteString' and let @f0@ fill it.
-- Returns the buffer together with @f0@'s result.
--
-- Runs through 'unsafeDupablePerformIO', so, as that function documents,
-- the action may execute more than once if several threads force the
-- result concurrently. @f0@ should therefore only write to the buffer it
-- is given.
unsafeCreate' :: Int -> (Ptr Word8 -> IO a) -> (BS.ByteString, a)
unsafeCreate' l0 f0 = unsafeDupablePerformIO $ do
  fptr <- mallocByteString l0
  result <- withForeignPtr fptr f0
  return (BS.PS fptr 0 l0, result)
-- | Replicate a byte into all four bytes of a 'Word32' (for example @0xAB@
-- becomes @0xABABABAB@). This turns an 8-bit mask into the equivalent
-- repeating 32-bit mask. The product never exceeds @0xFFFFFFFF@, so no
-- overflow occurs.
expandW8ToW32 :: Word8 -> Word32
expandW8ToW32 x = fromIntegral x * 0x01010101
-- | XOR every byte of a strict 'BS.ByteString' with the same 8-bit mask.
xor8StrictByteString :: Word8 -> BS.ByteString -> BS.ByteString
xor8StrictByteString = xor32StrictByteString . expandW8ToW32
-- | XOR every byte of a lazy 'BL.ByteString' with the same 8-bit mask.
xor8LazyByteString :: Word8 -> BL.ByteString -> BL.ByteString
xor8LazyByteString = xor32LazyByteString . expandW8ToW32
-- | XOR every byte of a 'SBS.ShortByteString' with the same 8-bit mask.
xor8ShortByteString :: Word8 -> SBS.ShortByteString -> SBS.ShortByteString
xor8ShortByteString = xor32ShortByteString . expandW8ToW32
-- | XOR every byte of a C buffer in place with the same 8-bit mask.
xor8CStringLen :: Word8 -> CStringLen -> IO ()
xor8CStringLen m8 (cstr, n) =
  xor32Ptr (expandW8ToW32 m8) (castPtr cstr) n >> return ()
-- | A mutable byte array that is filled in 'ST' and later frozen into a
-- 'ShortByteString' with 'unsafeFreezeSBS'. The @s@ parameter ties it to
-- the 'ST' state thread that created it.
data MShortByteString s = MSBS (X.MutableByteArray# s)
-- | Allocate a mutable byte array of the given length in bytes
-- (via 'X.newByteArray#'). Its contents are not initialised, so every byte
-- must be written before the array is frozen.
newSBS :: Int -> ST s (MShortByteString s)
newSBS (X.I# len#) = X.ST $ \s0 -> case X.newByteArray# len# s0 of (# s, mba# #) -> (# s, MSBS mba# #)
-- | Read the byte at byte offset @i@ of a 'ShortByteString'.
-- There is no bounds check; the caller must keep @i@ in range.
indexWord8Array :: ShortByteString -> Int -> Word8
indexWord8Array (SBS ba#) (X.I# i#) = X.W8# (X.indexWord8Array# ba# i#)
-- | Write a byte at byte offset @i@ of a mutable array.
-- There is no bounds check; the caller must keep @i@ in range.
writeWord8Array :: MShortByteString s -> Int -> Word8 -> ST s ()
writeWord8Array (MSBS mba#) (X.I# i#) (X.W8# w#) = X.ST $ \s0 -> case X.writeWord8Array# mba# i# w# s0 of s -> (# s, () #)
-- | Read the 'Word32' at index @i@, counted in 4-byte units (bytes @4*i@ to
-- @4*i+3@), in the machine's native byte order.
-- There is no bounds check; the caller must keep @i@ in range.
indexWord32Array :: ShortByteString -> Int -> Word32
indexWord32Array (SBS ba#) (X.I# i#) = X.W32# (X.indexWord32Array# ba# i#)
-- | Write a 'Word32' at index @i@, counted in 4-byte units, in the
-- machine's native byte order.
-- There is no bounds check; the caller must keep @i@ in range.
writeWord32Array :: MShortByteString s -> Int -> Word32 -> ST s ()
writeWord32Array (MSBS mba#) (X.I# i#) (X.W32# w#) = X.ST $ \s0 -> case X.writeWord32Array# mba# i# w# s0 of s -> (# s, () #)
-- | Freeze a mutable array into a 'ShortByteString' in place, without
-- copying. The mutable array must not be written after this call.
unsafeFreezeSBS :: MShortByteString s -> ST s ShortByteString
unsafeFreezeSBS (MSBS mba#) = X.ST $ \s0 -> case X.unsafeFreezeByteArray# mba# s0 of (# s, ba# #) -> (# s, SBS ba# #)