--------------------------------------------------------------------------------
-- | Masking of frames using a simple XOR algorithm
{-# LANGUAGE BangPatterns             #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE OverloadedStrings        #-}
{-# LANGUAGE ScopedTypeVariables      #-}
module Network.WebSockets.Hybi13.Mask
    ( Mask
    , parseMask
    , encodeMask
    , randomMask

    , maskPayload
    ) where


--------------------------------------------------------------------------------
import qualified Data.ByteString.Builder       as Builder
import qualified Data.ByteString.Builder.Extra as Builder
import           Data.Binary.Get               (Get, getWord32host)
import qualified Data.ByteString.Internal      as B
import qualified Data.ByteString.Lazy          as BL
import qualified Data.ByteString.Lazy.Internal as BL
import           Data.Word                     (Word32, Word8)
import           Foreign.C.Types               (CChar (..), CInt (..),
                                                CSize (..))
import           Foreign.ForeignPtr            (withForeignPtr)
import           Foreign.Ptr                   (Ptr, plusPtr)
import           System.Random                 (RandomGen, random)


--------------------------------------------------------------------------------
-- | Binding to the C routine @_hs_mask_chunk@ (implemented outside this
-- module), which masks one strict chunk into a destination buffer.
--
-- Arguments, in the order 'maskPayload' passes them:
--
--   1. the 4-byte mask, stored in host byte order;
--   2. the offset into the mask at which this chunk starts (always 0..3);
--   3. pointer to the first source byte of the chunk;
--   4. number of bytes to process;
--   5. pointer to a destination buffer of that same length.
--
-- Imported @unsafe@, so the C code must not call back into Haskell.
-- NOTE(review): the C side presumably XORs source byte @i@ with mask byte
-- @(offset + i) `mod` 4@ — confirm against the C source.
foreign import ccall unsafe "_hs_mask_chunk" c_mask_chunk
    :: Word32 -> CInt -> Ptr CChar -> CSize -> Ptr Word8 -> IO ()


--------------------------------------------------------------------------------
-- | A mask is a sequence of 4 bytes.  We store this in a 'Word32' in the
-- host's native byte ordering, so the in-memory representation holds the
-- mask bytes in the same order in which they were read from the input (see
-- 'parseMask' and 'encodeMask', which both use host byte order).
newtype Mask = Mask {unMask :: Word32}


--------------------------------------------------------------------------------
-- | Parse a mask: read the next 4 bytes of input as a 'Word32' in host byte
-- order.  Using host order means the four mask bytes keep their input order
-- in memory, which is what 'encodeMask' writes back out.
parseMask :: Get Mask
parseMask = Mask <$> getWord32host


--------------------------------------------------------------------------------
-- | Encode a mask as 4 bytes, writing its 'Word32' in host byte order.
-- 'parseMask' also reads in host byte order, so a parsed mask is encoded
-- back to exactly the bytes it was parsed from.
encodeMask :: Mask -> Builder.Builder
encodeMask = Builder.word32Host . unMask


--------------------------------------------------------------------------------
-- | Create a random mask by drawing a 'Word32' from the given generator.
-- Returns the mask together with the advanced generator.
randomMask :: forall g. RandomGen g => g -> (Mask, g)
randomMask gen = (Mask word, gen')
  where
    -- The annotation selects the 'Word32' instance of 'random'; referring to
    -- @g@ here relies on ScopedTypeVariables and the explicit @forall@ above.
    -- The bangs make demanding either result force both the drawn word and
    -- the new generator.
    (!word, !gen') = random gen :: (Word32, g)


--------------------------------------------------------------------------------
-- | Mask a lazy bytestring.  Uses 'c_mask_chunk' under the hood.
--
-- Every strict chunk of the input is processed by the C routine into a
-- freshly allocated buffer of the same length.  The current position within
-- the 4-byte mask is carried across chunk boundaries (modulo 4), so each
-- chunk continues the mask where the previous one stopped.
--
-- 'Nothing' and an all-zero mask both return the input unchanged without
-- copying; XOR with a zero mask would not alter any byte.
maskPayload :: Maybe Mask -> BL.ByteString -> BL.ByteString
maskPayload Nothing            = id
maskPayload (Just (Mask 0))    = id
maskPayload (Just (Mask mask)) = go 0
  where
    -- go offset chunks: mask the remaining chunks, where offset (0..3) is the
    -- index into the mask of the first byte of the first remaining chunk.
    go :: Int -> BL.ByteString -> BL.ByteString
    go _           BL.Empty                               = BL.Empty
    go !maskOffset (BL.Chunk (B.PS payload off len) rest) =
        -- The input chunk is non-empty (lazy bytestring invariant) and the
        -- output chunk has the same length, so using the raw 'BL.Chunk'
        -- constructor keeps the invariant.
        BL.Chunk maskedChunk (go ((maskOffset + len) `rem` 4) rest)
      where
        -- A new strict bytestring of length len, filled by the C routine
        -- from this chunk's bytes, which begin off bytes into the shared
        -- buffer behind payload.
        maskedChunk :: B.ByteString
        maskedChunk =
            B.unsafeCreate len $ \dst ->
            withForeignPtr payload $ \src ->
                c_mask_chunk mask
                    (fromIntegral maskOffset)
                    (src `plusPtr` off)
                    (fromIntegral len)
                    dst