{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Flat.Decoder.Prim (
dBool,
dWord8,
dBE8,
dBE16,
dBE32,
dBE64,
dBEBits8,
dBEBits16,
dBEBits32,
dBEBits64,
dropBits,
dFloat,
dDouble,
getChunksInfo,
dByteString_,
dLazyByteString_,
dByteArray_,
ConsState(..),consOpen,consClose,consBool,consBits
) where
import Control.Monad
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Flat.Decoder.Types
import Flat.Endian
import Flat.Memory
import Data.FloatCast
import Data.Word
import Foreign
-- | State of a run of @cons*@ bit reads: the buffered input word and the
-- number of bits already taken from it.
data ConsState = ConsState
  {-# UNPACK #-} !Word
  -- ^ input bits as loaded by 'consOpen', first unread bit at the top
  !Int
  -- ^ bits taken from the word so far (starts at 0, grown by 'consBits')
-- | Read ahead up to two bytes into a 'ConsState' without advancing the
-- decoder state.
--
-- The loaded bits are shifted so that the first unread bit (after the
-- @usedBits@ already consumed in the current byte) sits at the most
-- significant end of the word. The returned state starts with 0 bits used.
-- The decoder state is left untouched; 'consClose' advances it.
--
-- Only bytes that actually lie before @endPtr@ are read: two when at least
-- two remain, one when exactly one remains. With no bytes left this fails
-- with 'notEnoughSpace'.
consOpen :: Get ConsState
consOpen = Get $ \endPtr s -> do
  let u = usedBits s
      -- readable bytes from the current position (endPtr is exclusive)
      bytesLeft = endPtr `minusPtr` currPtr s
  w <- case compare bytesLeft 1 of
    GT -> do
      -- at least two bytes: load them as one big-endian 16-bit value
      w16 :: Word16 <- toBE16 <$> peek (castPtr $ currPtr s)
      return $ fromIntegral w16 `unsafeShiftL` (u + (wordSize - 16))
    EQ -> do
      -- exactly one byte left: a 16-bit load would overrun the buffer
      w8 :: Word8 <- peek (currPtr s)
      return $ fromIntegral w8 `unsafeShiftL` (u + (wordSize - 8))
    LT -> notEnoughSpace endPtr s
  return $ GetResult s (ConsState w 0)
-- | Advance the decoder state by the @n@ bits consumed from a 'ConsState'.
-- This is the counterpart of 'consOpen', which reads ahead without
-- advancing.
--
-- Fails with 'notEnoughSpace' unless all @n@ bits lie before @endPtr@.
-- Checking only that the current byte exists would accept bits beyond the
-- end of the input and leave an invalid state (pointer at the end of the
-- input with bits still marked as used).
--
-- NOTE(review): at most one byte boundary is crossed here, so callers are
-- expected to keep @n + usedBits@ below 16.
consClose :: Int -> Get ()
consClose n = Get $ \endPtr s -> do
  ensureBits endPtr s n
  let u' = n + usedBits s
  return $
    if u' < 8
      then GetResult (s {usedBits = u'}) ()
      -- the current byte is finished: move on to the next one
      else GetResult (s {currPtr = currPtr s `plusPtr` 1, usedBits = u' - 8}) ()
-- | Take the next bit from a 'ConsState' as a 'Bool' (a set bit is 'True').
consBool :: ConsState -> (ConsState,Bool)
consBool cs = case consBits cs 1 of
  (cs', bitVal) -> (cs', bitVal /= 0)
-- | Take the next @n@ bits (1 to 3) from a 'ConsState', together with the
-- updated state. Any other width is rejected with an error.
consBits :: ConsState -> Int -> (ConsState, Word)
consBits cs n
  | n >= 1 && n <= 3 = consBits_ cs n (bit n - 1)
  | otherwise = error "unsupported"
-- | Take the next @numBits@ bits from a 'ConsState'. Returns the updated
-- state and the extracted bits masked with @mask@.
--
-- The callers in 'consBits' pass @mask = 2^numBits - 1@. No check is made
-- here that the requested bits were actually loaded by 'consOpen'.
--
-- Three alternative implementations follow. The CPP symbol defined on the
-- next line (CONS_STA) selects the one that is compiled.
--
-- NOTE: the pattern variable usedBits shadows the usedBits accessor that is
-- applied to S elsewhere in this module.
consBits_ :: ConsState -> Int -> Word -> (ConsState, Word)
#define CONS_STA
#ifdef CONS_ROT
-- Rotate variant: rotate the word left so the next bits land at the low
-- end, then mask them off.
consBits_ (ConsState w usedBits) numBits mask =
  let usedBits' = numBits+usedBits
      w' = w `rotateL` numBits
  in (ConsState w' usedBits',w' .&. mask)
#endif
#ifdef CONS_SHL
-- Shift variant: extract the top numBits bits, then shift them out of the
-- word.
consBits_ (ConsState w usedBits) numBits mask =
  let usedBits' = numBits+usedBits
      w' = w `unsafeShiftL` numBits
  in (ConsState w' usedBits', (w `shR` (wordSize - numBits)) .&. mask)
#endif
#ifdef CONS_STA
-- Static variant (the active one): the word is never modified. The running
-- count of consumed bits gives the right shift that brings the requested
-- bits, counted from the most significant end, down to the low end.
consBits_ (ConsState w usedBits) numBits mask =
  let usedBits' = numBits+usedBits
  in (ConsState w usedBits', (w `shR` (wordSize - usedBits')) .&. mask)
#endif
-- | Width of a machine 'Word' in bits (for example 64 on a 64-bit target).
wordSize :: Int
wordSize = finiteBitSize (maxBound :: Word)
{-# INLINE ensureBits #-}
-- | Fail with 'notEnoughSpace' unless at least @n@ unread bits remain
-- between the current position and @endPtr@ (exclusive).
ensureBits :: Ptr Word8 -> S -> Int -> IO ()
ensureBits endPtr s n =
  let availableBits = 8 * (endPtr `minusPtr` currPtr s) - usedBits s
  in unless (availableBits >= n) (notEnoughSpace endPtr s)
{-# INLINE dropBits #-}
-- | Skip @n@ bits of input.
--
-- Fails with 'notEnoughSpace' if fewer than @n@ bits remain. A count of
-- zero is a no-op, and a negative count raises an error.
dropBits :: Int -> Get ()
dropBits n = case compare n 0 of
  GT -> Get $ \endPtr s -> do
    ensureBits endPtr s n
    return $ GetResult (dropBits_ s n) ()
  EQ -> return ()
  LT -> error $ unwords ["dropBits",show n]
{-# INLINE dropBits_ #-}
-- | Advance a decoder state by @n@ bits, carrying whole bytes into
-- 'currPtr'. No bounds check is made here ('dropBits' calls 'ensureBits'
-- first).
--
-- Like every other state transition in this module, this uses record
-- update rather than rebuilding 'S' from scratch. The function therefore
-- does not depend on 'S' having exactly these two fields.
dropBits_ :: S -> Int -> S
dropBits_ s n =
  let (bytes, bits) = (n + usedBits s) `divMod` 8
  in s {currPtr = currPtr s `plusPtr` bytes, usedBits = bits}
{-# INLINE dBool #-}
-- | Decode a single bit as a 'Bool' (a set bit is 'True').
--
-- Advances one bit, and moves to the next byte after the last bit of the
-- current one. Fails with 'notEnoughSpace' when no byte is left.
dBool :: Get Bool
dBool = Get $ \endPtr s ->
  if currPtr s < endPtr
    then do
      let u = usedBits s
      !byte <- peek (currPtr s)
      let !bitSet = byte .&. (128 `shR` u) /= 0
          !s' = case u of
            7 -> s {currPtr = currPtr s `plusPtr` 1, usedBits = 0}
            _ -> s {usedBits = u + 1}
      return $ GetResult s' bitSet
    else notEnoughSpace endPtr s
{-# INLINE dBEBits8 #-}
-- | Decode the next @n@ bits (0 to 8) as a big-endian 'Word8'. Fails with
-- 'notEnoughSpace' if fewer than @n@ bits remain.
dBEBits8 :: Int -> Get Word8
dBEBits8 n = Get $ \endPtr s -> ensureBits endPtr s n >> take8 s n
{-# INLINE dBEBits16 #-}
-- | Decode the next @n@ bits as a big-endian 'Word16'. Fails with
-- 'notEnoughSpace' if fewer than @n@ bits remain.
dBEBits16 :: Int -> Get Word16
dBEBits16 n = Get $ \endPtr s -> ensureBits endPtr s n >> takeN n s
{-# INLINE dBEBits32 #-}
-- | Decode the next @n@ bits as a big-endian 'Word32'. Fails with
-- 'notEnoughSpace' if fewer than @n@ bits remain.
dBEBits32 :: Int -> Get Word32
dBEBits32 n = Get $ \endPtr s -> ensureBits endPtr s n >> takeN n s
{-# INLINE dBEBits64 #-}
-- | Decode the next @n@ bits as a big-endian 'Word64'. Fails with
-- 'notEnoughSpace' if fewer than @n@ bits remain.
dBEBits64 :: Int -> Get Word64
dBEBits64 n = Get $ \endPtr s -> ensureBits endPtr s n >> takeN n s
{-# INLINE take8 #-}
-- | Read the next @n@ bits (0 to 8) as a big-endian 'Word8' and advance the
-- state past them. Any other @n@ is an error.
--
-- No bounds check is made here; the public decoders call 'ensureBits'
-- first. Reading zero bits returns 0 without touching memory. The current
-- pointer may already equal @endPtr@ in that case, and shifting a 'Word8'
-- by 8 is outside the range allowed for 'unsafeShiftR'.
take8 :: S -> Int -> IO (GetResult Word8)
take8 s n = GetResult (dropBits8 s n) <$> read8 s n
  where
    read8 :: S -> Int -> IO Word8
    read8 st k
      | k == 0 = return 0
      | k > 0 && k <= 8 =
          if k <= 8 - usedBits st
            then do
              -- all k bits lie in the current byte
              w <- peek (currPtr st)
              return $ (w `unsafeShiftL` usedBits st) `shR` (8 - k)
            else do
              -- the bits straddle the current and the next byte
              w :: Word16 <- toBE16 <$> peek (castPtr $ currPtr st)
              return $ fromIntegral $ (w `unsafeShiftL` usedBits st) `shR` (16 - k)
      | otherwise = error $ unwords ["read8: cannot read",show k,"bits"]

    -- advance by k bits (k <= 8, so at most one byte boundary is crossed)
    dropBits8 :: S -> Int -> S
    dropBits8 st k =
      let u' = k + usedBits st
      in if u' < 8
           then st {usedBits = u'}
           else st {currPtr = currPtr st `plusPtr` 1, usedBits = u' - 8}
{-# INLINE takeN #-}
-- | Read @n@ bits as a big-endian number and return it with the advanced
-- state. The bits are read in chunks of eight via 'take8', with a final
-- shorter chunk for any remainder, and each chunk is appended at the low
-- end of the accumulator. No bounds check is made here.
takeN :: (Num a, Bits a) => Int -> S -> IO (GetResult a)
takeN n s = go s 0 n
  where
    go st acc remaining
      | remaining <= 0 = return $ GetResult st acc
      | otherwise = do
          let chunk = remaining `min` 8
          GetResult st' b <- take8 st chunk
          go st' ((acc `unsafeShiftL` chunk) .|. fromIntegral b) (remaining - chunk)
-- | Decode one 'Word8' from the next 8 bits of input. This is the same
-- decoder as 'dBE8'; a single byte has no byte order.
dWord8 :: Get Word8
dWord8 = dBE8
{-# INLINE dBE8 #-}
-- | Decode 8 bits as a 'Word8', starting at the current (possibly
-- unaligned) bit position. When unaligned, the value is stitched together
-- from the current and the next byte.
dBE8 :: Get Word8
dBE8 = Get $ \endPtr s -> do
  ensureBits endPtr s 8
  let ptr = currPtr s
      u = usedBits s
  !hi <- peek ptr
  !w <- case u of
    0 -> return hi
    _ -> do
      !lo <- peek (ptr `plusPtr` 1)
      return $ (hi `unsafeShiftL` u) .|. (lo `shR` (8 - u))
  return $ GetResult (s {currPtr = ptr `plusPtr` 1}) w
{-# INLINE dBE16 #-}
-- | Decode 16 bits as a big-endian 'Word16', starting at the current
-- (possibly unaligned) bit position. When unaligned, the low bits come from
-- the byte following the two loaded ones.
dBE16 :: Get Word16
dBE16 = Get $ \endPtr s -> do
  ensureBits endPtr s 16
  let ptr = currPtr s
      u = usedBits s
  !hi <- toBE16 <$> peek (castPtr ptr)
  !w <- case u of
    0 -> return hi
    _ -> do
      !(lo :: Word8) <- peek (ptr `plusPtr` 2)
      return $ (hi `unsafeShiftL` u) .|. fromIntegral (lo `shR` (8 - u))
  return $ GetResult (s {currPtr = ptr `plusPtr` 2}) w
{-# INLINE dBE32 #-}
-- | Decode 32 bits as a big-endian 'Word32', starting at the current
-- (possibly unaligned) bit position. When unaligned, the low bits come from
-- the byte following the four loaded ones.
dBE32 :: Get Word32
dBE32 = Get $ \endPtr s -> do
  ensureBits endPtr s 32
  let ptr = currPtr s
      u = usedBits s
  !hi <- toBE32 <$> peek (castPtr ptr)
  !w <- case u of
    0 -> return hi
    _ -> do
      !(lo :: Word8) <- peek (ptr `plusPtr` 4)
      return $ (hi `unsafeShiftL` u) .|. fromIntegral (lo `shR` (8 - u))
  return $ GetResult (s {currPtr = ptr `plusPtr` 4}) w
{-# INLINE dBE64 #-}
-- | Decode 64 bits as a big-endian 'Word64', starting at the current
-- (possibly unaligned) bit position. When unaligned, the low bits come from
-- the byte following the eight loaded ones.
dBE64 :: Get Word64
dBE64 = Get $ \endPtr s -> do
  ensureBits endPtr s 64
  let ptr = currPtr s
      u = usedBits s
  !hi <- toBE64 <$> peek (castPtr ptr :: Ptr Word64)
  !w <- case u of
    0 -> return hi
    _ -> do
      !(lo :: Word8) <- peek (ptr `plusPtr` 8)
      return $ (hi `unsafeShiftL` u) .|. fromIntegral (lo `shR` (8 - u))
  return $ GetResult (s {currPtr = ptr `plusPtr` 8}) w
{-# INLINE dFloat #-}
-- | Decode 32 big-endian bits and reinterpret them as a 'Float' via
-- 'wordToFloat'.
dFloat :: Get Float
dFloat = fmap wordToFloat dBE32
{-# INLINE dDouble #-}
-- | Decode 64 big-endian bits and reinterpret them as a 'Double' via
-- 'wordToDouble'.
dDouble :: Get Double
dDouble = fmap wordToDouble dBE64
-- | Decode a chunked byte string, in the same format as 'dByteString_', as a
-- lazy 'L.ByteString'.
dLazyByteString_ :: Get L.ByteString
dLazyByteString_ = fmap L.fromStrict dByteString_
-- | Decode a sequence of length-prefixed chunks into a strict
-- 'B.ByteString'. The chunks are located with 'getChunksInfo' and passed to
-- 'chunksToByteString'.
dByteString_ :: Get B.ByteString
dByteString_ = fmap chunksToByteString getChunksInfo
-- | Decode a sequence of length-prefixed chunks into a 'ByteArray'. The
-- chunks are located with 'getChunksInfo' and passed to
-- 'chunksToByteArray'.
--
-- NOTE(review): the 'Int' is whatever 'chunksToByteArray' returns,
-- presumably the total byte count; confirm.
dByteArray_ :: Get (ByteArray,Int)
dByteArray_ = fmap chunksToByteArray getChunksInfo
-- | Locate a sequence of length-prefixed chunks without copying it.
--
-- Format and result:
--
-- * The current position must be byte-aligned (@usedBits == 0@);
--   otherwise 'badEncoding' is raised.
-- * Each chunk is a length byte @n@ followed by @n@ payload bytes, and a
--   zero length byte ends the sequence.
-- * Returns a pointer to the byte after the first length byte, together
--   with the chunk lengths in order.
-- * Advances the state past the terminating zero byte.
--
-- Every bounds check is made relative to the chunk being read, not to the
-- start of the sequence. A length byte that points past @endPtr@ therefore
-- fails with 'notEnoughSpace' instead of reading beyond the buffer.
getChunksInfo :: Get (Ptr Word8, [Int])
getChunksInfo = Get $ \endPtr s -> do
  let -- decoder state moved to the given byte
      stateAt p = s {currPtr = p}
      -- walk the chunks, collecting their lengths in a difference list
      getChunks srcPtr l = do
        ensureBits endPtr (stateAt srcPtr) 8
        !n <- fromIntegral <$> peek srcPtr
        if n == 0
          then return (srcPtr `plusPtr` 1, l [])
          else do
            -- the length byte plus its n payload bytes must all be present
            ensureBits endPtr (stateAt srcPtr) ((n + 1) * 8)
            getChunks (srcPtr `plusPtr` (n + 1)) (l . (n :))
  when (usedBits s /= 0) $ badEncoding endPtr s "usedBits /= 0"
  (currPtr', ns) <- getChunks (currPtr s) id
  return $ GetResult (s {currPtr = currPtr'}) (currPtr s `plusPtr` 1, ns)
{-# INLINE shR #-}
-- | Logical right shift. Under GHCJS this goes through 'shift' with a
-- negated amount, returning the value unchanged for a zero shift.
-- Elsewhere it is 'unsafeShiftR'.
shR :: Bits a => a -> Int -> a
#ifdef ghcjs_HOST_OS
shR val n
  | n == 0 = val
  | otherwise = shift val (negate n)
#else
shR val n = unsafeShiftR val n
#endif