{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE BangPatterns #-}
#if defined(__GLASGOW_HASKELL__) && !defined(__HADDOCK__)
#include "MachDeps.h"
#endif
module Data.Serialize.Get (
Get
, runGet
, runGetLazy
, runGetState
, runGetLazyState
, Result(..)
, runGetPartial
, runGetChunk
, ensure
, isolate
, label
, skip
, uncheckedSkip
, lookAhead
, lookAheadM
, lookAheadE
, uncheckedLookAhead
, getBytes
, remaining
, isEmpty
, getWord8
, getInt8
, getByteString
, getLazyByteString
, getShortByteString
, getWord16be
, getWord32be
, getWord64be
, getInt16be
, getInt32be
, getInt64be
, getWord16le
, getWord32le
, getWord64le
, getInt16le
, getInt32le
, getInt64le
, getWordhost
, getWord16host
, getWord32host
, getWord64host
, getTwoOf
, getListOf
, getIArrayOf
, getTreeOf
, getSeqOf
, getMapOf
, getIntMapOf
, getSetOf
, getIntSetOf
, getMaybeOf
, getEitherOf
, getNested
) where
import qualified Control.Applicative as A
import qualified Control.Monad as M
import Control.Monad (unless)
import qualified Control.Monad.Fail as Fail
import Data.Array.IArray (IArray,listArray)
import Data.Ix (Ix)
import Data.List (intercalate)
import Data.Maybe (isNothing,fromMaybe)
import Foreign
import System.IO.Unsafe (unsafeDupablePerformIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as B
import qualified Data.ByteString.Unsafe as B
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Short as BS
import qualified Data.IntMap as IntMap
import qualified Data.IntSet as IntSet
import qualified Data.Map as Map
import qualified Data.Sequence as Seq
import qualified Data.Set as Set
import qualified Data.Tree as T
#if defined(__GLASGOW_HASKELL__) && !defined(__HADDOCK__)
import GHC.Base
import GHC.Word
#endif
-- | The outcome of running a 'Get' parser.
data Result r = Fail String B.ByteString
              -- ^ The parse failed. Carries the error message and a
              -- 'B.ByteString' of input handed to the failure continuation.
              | Partial (B.ByteString -> Result r)
              -- ^ More input is needed. Apply the continuation to the next
              -- chunk to resume parsing.
              | Done r B.ByteString
              -- ^ The parse succeeded. Carries the parsed value and the
              -- input that was left over.
-- | Renders a 'Result' for debugging. The leftover input of a 'Fail' is not
-- shown, and the continuation of a 'Partial' is printed as @_@.
instance Show r => Show (Result r) where
    show result = case result of
        Fail err _        -> "Fail " ++ show err
        Partial _         -> "Partial _"
        Done val leftover -> concat ["Done ", show val, " ", show leftover]
-- | Maps over the value of a 'Done', either immediately or after a 'Partial'
-- has been resumed. A 'Fail' is passed through unchanged.
instance Functor Result where
    fmap f result = case result of
        Fail err leftover -> Fail err leftover
        Partial resume    -> Partial (\ chunk -> fmap f (resume chunk))
        Done val leftover -> Done (f val) leftover
-- | The Get monad, written in continuation-passing style. A parser receives
-- the current 'Input', the 'Buffer', the 'More' state, a failure continuation
-- and a success continuation, and produces a 'Result'.
newtype Get a = Get
  { unGet :: forall r. Input -> Buffer -> More
                    -> Failure r -> Success a r
                    -> Result r }
-- | The input that is currently available to the parser.
type Input = B.ByteString
-- | Chunks received from the caller while parsing. 'lookAhead' and 'mplus'
-- start their sub-parser with 'emptyBuffer' so they can replay what was
-- received; the top-level runners start with 'Nothing', which 'extendBuffer'
-- leaves as 'Nothing'.
type Buffer = Maybe B.ByteString
-- | A buffer that is present but has not recorded any bytes yet.
emptyBuffer :: Buffer
emptyBuffer = A.pure B.empty
-- | Append a chunk to a buffer. A 'Nothing' buffer stays 'Nothing'; otherwise
-- the concatenation is forced when the resulting 'Just' is forced.
extendBuffer :: Buffer -> B.ByteString -> Buffer
extendBuffer buf chunk = case buf of
  Nothing -> Nothing
  Just bs -> Just $! B.append bs chunk
{-# INLINE extendBuffer #-}
-- | Concatenate two buffers. The result is 'Nothing' unless both are present.
append :: Buffer -> Buffer -> Buffer
append l r = case (l, r) of
  (Just front, Just back) -> Just (B.append front back)
  _                       -> Nothing
{-# INLINE append #-}
-- | The bytes held by a buffer, or the empty string when there is no buffer.
bufferBytes :: Buffer -> B.ByteString
bufferBytes buf = case buf of
  Just bs -> bs
  Nothing -> B.empty
{-# INLINE bufferBytes #-}
-- | Failure continuation. Receives the input, buffer and 'More' state at the
-- point of failure, a list of context strings (extended by 'label') and the
-- error message.
type Failure r = Input -> Buffer -> More -> [String] -> String -> Result r
-- | Success continuation. Receives the remaining input, the buffer, the
-- 'More' state and the parsed value.
type Success a r = Input -> Buffer -> More -> a -> Result r
-- | Whether the parser may ask the caller for further input.
data More
  = Complete
    -- ^ No further input can be requested; 'ensure' fails instead of
    -- returning a 'Partial'.
  | Incomplete (Maybe Int)
    -- ^ Further input may be requested. The optional count is decremented by
    -- 'ensure' by the length of each chunk it receives, and is added to the
    -- current input length by 'remaining'.
    deriving (Eq)
-- | The byte count recorded in a 'More' value, or 0 when it is 'Complete' or
-- carries no count.
moreLength :: More -> Int
moreLength Complete        = 0
moreLength (Incomplete mb) = maybe 0 id mb
-- | Applies a function to the parsed value before it reaches the success
-- continuation.
instance Functor Get where
    fmap f parser = Get $ \ input buf more onFail onOk ->
      let onOk' input' buf' more' val = onOk input' buf' more' (f val)
      in unGet parser input buf more onFail onOk'
-- | Sequences parsers left to right, threading input, buffer and 'More' state
-- through the success continuations. All steps share one failure continuation.
instance A.Applicative Get where
    pure val = Get $ \ input buf more _ onOk -> onOk input buf more val
    {-# INLINE pure #-}

    pf <*> px = Get $ \ input buf more onFail onOk ->
      let -- Run the argument parser once the function has been parsed.
          withFn inF bufF moreF fn =
            unGet px inF bufF moreF onFail (withArg fn)
          withArg fn inX bufX moreX arg = onOk inX bufX moreX (fn arg)
      in unGet pf input buf more onFail withFn
    {-# INLINE (<*>) #-}

    pa *> pb = Get $ \ input buf more onFail onOk ->
      let next input' buf' more' _ = unGet pb input' buf' more' onFail onOk
      in unGet pa input buf more onFail next
    {-# INLINE (*>) #-}
-- | 'A.empty' always fails with the description @"empty"@; choice is
-- delegated to 'M.mplus'.
instance A.Alternative Get where
    empty = failDesc "empty"
    {-# INLINE empty #-}

    lhs <|> rhs = M.mplus lhs rhs
    {-# INLINE (<|>) #-}
-- | Monadic sequencing for 'Get'. The value produced by the first parser
-- selects the parser to run next; both share the same failure continuation.
instance Monad Get where
    return = A.pure
    {-# INLINE return #-}

    m >>= g = Get $ \ s0 b0 m0 kf ks ->
      unGet m s0 b0 m0 kf $ \ s1 b1 m1 a -> unGet (g a) s1 b1 m1 kf ks
    {-# INLINE (>>=) #-}

    (>>) = (A.*>)
    {-# INLINE (>>) #-}

    -- 'fail' stopped being a method of 'Monad' in base 4.13 (GHC 8.8), so
    -- defining it there is a compile error. On those compilers failure is
    -- provided solely by the 'Fail.MonadFail' instance.
#if !defined(__GLASGOW_HASKELL__) || __GLASGOW_HASKELL__ < 808
    fail = Fail.fail
    {-# INLINE fail #-}
#endif
-- | Failing in 'Get' invokes the failure continuation through 'failDesc'.
instance Fail.MonadFail Get where
    fail msg = failDesc msg
    {-# INLINE fail #-}
-- | Left-biased choice with backtracking. The first parser runs with a fresh
-- buffer; if it fails, the second parser is run on the original input
-- followed by everything the first parser buffered.
instance M.MonadPlus Get where
    mzero = failDesc "mzero"
    {-# INLINE mzero #-}

    mplus lhs rhs =
      Get $ \ input outerBuf more onFail onOk ->
        let -- Original input followed by the bytes buffered so far.
            rewound seen = input `B.append` bufferBytes seen
            -- The caller's buffer extended with the bytes buffered so far.
            merged seen  = outerBuf `append` seen
            onOk' input' seen      = onOk input' (merged seen)
            onFail' _ seen more'   = onFail (rewound seen) (merged seen) more'
            retry _ seen more' _ _ =
              unGet rhs (rewound seen) seen more' onFail' onOk'
        in unGet lhs input emptyBuffer more retry onOk'
    {-# INLINE mplus #-}
-- | Render the list of context strings attached to a failure, one per line
-- and tab-indented, or a fixed message when the list is empty.
formatTrace :: [String] -> String
formatTrace frames
  | null frames = "Empty call stack"
  | otherwise   = concat ["From:\t", intercalate "\n\t" frames, "\n"]
-- | Return the current input without consuming it.
get :: Get B.ByteString
get = Get $ \ input buf more _ onOk -> onOk input buf more input
{-# INLINE get #-}
-- | Replace the current input with the given bytes. The buffer and the 'More'
-- state are left as they were.
put :: B.ByteString -> Get ()
put s = Get $ \ _ buf more _ onOk -> onOk s buf more ()
{-# INLINE put #-}
-- | Run a parser, prepending the given label to the context list handed to
-- the failure continuation if the parser fails.
label :: String -> Get a -> Get a
label l m =
  Get $ \ input buf more onFail onOk ->
    let onFail' input' buf' more' trace = onFail input' buf' more' (l : trace)
    in unGet m input buf more onFail' onOk
-- | Top-level success continuation: wrap the value and the remaining input in
-- 'Done'.
finalK :: Success a a
finalK leftover _ _ val = Done val leftover
-- | Top-level failure continuation: build a 'Fail' whose message is the error
-- followed by the formatted context list, and whose bytes are the input
-- followed by the buffer contents.
failK :: Failure a
failK input buf _ trace msg = Fail report unconsumed
  where
    report     = unlines [msg, formatTrace trace]
    unconsumed = B.append input (bufferBytes buf)
-- | Run a parser over a complete strict 'B.ByteString', returning the parsed
-- value or an error message. Leftover input is discarded.
runGet :: Get a -> B.ByteString -> Either String a
runGet m str = toEither (unGet m str Nothing Complete failK finalK)
  where
    toEither (Fail err _) = Left err
    toEither (Done val _) = Right val
    -- With 'Complete' input the parser never requests more input.
    toEither (Partial _)  =
      Left "Failed reading: Internal error: unexpected Partial."
{-# INLINE runGet #-}
-- | Run a parser on a first chunk of input, allowing it to ask for more. The
-- optional 'Int' becomes the count stored in the initial 'Incomplete' state.
runGetChunk :: Get a -> Maybe Int -> B.ByteString -> Result a
runGetChunk m mbLen str =
  let more = Incomplete mbLen
  in unGet m str Nothing more failK finalK
{-# INLINE runGetChunk #-}
-- | Run a parser incrementally when the amount of remaining input is not
-- known up front.
runGetPartial :: Get a -> B.ByteString -> Result a
runGetPartial m str = runGetChunk m Nothing str
{-# INLINE runGetPartial #-}
-- | Run a parser on a strict 'B.ByteString' starting at the given offset. On
-- success, returns the value together with the unconsumed input.
runGetState :: Get a -> B.ByteString -> Int
            -> Either String (a, B.ByteString)
runGetState m str off = case runGetState' m str off of
  (outcome, leftover) -> case outcome of
    Right val -> Right (val, leftover)
    Left err  -> Left err
{-# INLINE runGetState #-}
-- | Like 'runGetState', but the leftover bytes are returned on failure too.
runGetState' :: Get a -> B.ByteString -> Int
             -> (Either String a, B.ByteString)
runGetState' m str off = summarize (unGet m start Nothing Complete failK finalK)
  where
    -- Parsing begins after the first @off@ bytes.
    start = B.drop off str

    summarize (Fail err leftover) = (Left err, leftover)
    summarize (Done val leftover) = (Right val, leftover)
    summarize (Partial _)         =
      (Left "Failed reading: Internal error: unexpected Partial.", B.empty)
{-# INLINE runGetState' #-}
-- | Run a parser over a lazy 'L.ByteString', feeding it one chunk at a time.
-- Returns the outcome together with the input that was not consumed.
runGetLazy' :: Get a -> L.ByteString -> (Either String a,L.ByteString)
runGetLazy' m lstr = start (L.toChunks lstr)
  where
    -- Zero or one chunk: the whole input is strict, so no chunk feeding.
    start []                 = strictRun B.empty
    start [only]             = strictRun only
    start (headChunk : more) =
      feed (runGetChunk m (Just (totalLen - B.length headChunk)) headChunk) more

    totalLen = fromIntegral (L.length lstr)

    strictRun chunk = case runGetState' m chunk 0 of
      (outcome, leftover) -> (outcome, L.fromChunks [leftover])

    feed (Fail err leftover) pending =
      (Left err, L.fromChunks (leftover : pending))
    feed (Done val leftover) pending =
      (Right val, L.fromChunks (leftover : pending))
    feed (Partial resume) (next : later) = feed (resume next) later
    -- No chunks left: an empty chunk tells the parser the input has ended.
    feed (Partial resume) []             = feed (resume B.empty) []
{-# INLINE runGetLazy' #-}
-- | Run a parser over a lazy 'L.ByteString', discarding any leftover input.
runGetLazy :: Get a -> L.ByteString -> Either String a
runGetLazy m lstr =
  let (outcome, _) = runGetLazy' m lstr
  in outcome
{-# INLINE runGetLazy #-}
-- | Run a parser over a lazy 'L.ByteString'. On success, returns the value
-- together with the unconsumed input.
runGetLazyState :: Get a -> L.ByteString -> Either String (a,L.ByteString)
runGetLazyState m lstr = case runGetLazy' m lstr of
  (outcome, leftover) -> case outcome of
    Right val -> Right (val, leftover)
    Left err  -> Left err
{-# INLINE runGetLazyState #-}
{-# INLINE ensure #-}
-- | Make sure at least @n0@ bytes of input are available, requesting further
-- chunks through 'Partial' when the 'More' state allows it. Nothing is
-- consumed: on success the whole available input (possibly longer than @n0@)
-- is both installed as the current input and returned. Fails with
-- @"too few bytes"@ when the input is 'Complete' or an empty chunk is
-- supplied.
ensure :: Int -> Get B.ByteString
ensure n0 = n0 `seq` Get $ \ s0 b0 m0 kf ks -> let
    -- Number of bytes still missing from the current input.
    n' = n0 - B.length s0
    in if n' <= 0
        -- Enough bytes already: succeed with the input unchanged.
        then ks s0 b0 m0 s0
        else getMore n' s0 [] b0 m0 kf ks
    where
        -- Received chunks are kept newest-first in a list (@s0@ is the most
        -- recent, the last element is the original input) and concatenated
        -- once at the end instead of being appended on every chunk.

        -- All chunks, oldest first, as one string.
        finalInput s0 ss = B.concat (reverse (s0 : ss))
        -- 'init' drops the oldest element, i.e. the original input, so only
        -- the newly received chunks are added to the buffer. The list is
        -- never empty, so 'init' is safe here.
        finalBuffer b0 s0 ss = extendBuffer b0 (B.concat (reverse (init (s0 : ss))))
        -- Still @n@ bytes short: fail, or ask the caller for another chunk.
        getMore !n s0 ss b0 m0 kf ks = let
            tooFewBytes = let
                !s = finalInput s0 ss
                !b = finalBuffer b0 s0 ss
                in kf s b m0 ["demandInput"] "too few bytes"
            in case m0 of
                Complete -> tooFewBytes
                Incomplete mb -> Partial $ \s ->
                    -- An empty chunk signals that no more input is coming.
                    if B.null s
                        then tooFewBytes
                        else let
                            -- Decrease the recorded count, when there is one,
                            -- by the size of the chunk just received.
                            !mb' = case mb of
                                Just l -> Just $! l - B.length s
                                Nothing -> Nothing
                            in checkIfEnough n s (s0 : ss) b0 (Incomplete mb') kf ks
        -- A new chunk @s0@ arrived while @n@ bytes were missing: either
        -- succeed with everything accumulated or keep asking.
        checkIfEnough !n s0 ss b0 m0 kf ks = let
            n' = n - B.length s0
            in if n' <= 0
                then let
                    !s = finalInput s0 ss
                    !b = finalBuffer b0 s0 ss
                    in ks s b m0 s
                else getMore n' s0 ss b0 m0 kf ks
-- | Run a parser on exactly the next @n@ bytes of input. Fails if @n@ is
-- negative, if fewer than @n@ bytes are available, or if the parser leaves
-- any of the isolated bytes unconsumed.
isolate :: Int -> Get a -> Get a
isolate n m = do
  M.when (n < 0) $
    fail "Attempted to isolate a negative number of bytes"
  available <- ensure n
  let window = B.take n available
      after  = B.drop n available
  -- Restrict the inner parser to the isolated window.
  put window
  result   <- m
  leftover <- get
  if B.null leftover
    then put after
    else fail "not all bytes parsed in isolate"
  return result
-- | Fail with the given description, prefixed by @"Failed reading: "@, and an
-- empty context list.
failDesc :: String -> Get a
failDesc err = Get $ \ input buf more onFail _ ->
  onFail input buf more [] ("Failed reading: " ++ err)
-- | Skip ahead @n@ bytes. Fails if fewer than @n@ bytes are available.
skip :: Int -> Get ()
skip n = ensure n >>= \ input -> put (B.drop n input)
-- | Drop up to @n@ bytes from the current input without checking that they
-- are available and without requesting more input.
uncheckedSkip :: Int -> Get ()
uncheckedSkip n = get >>= \ input -> put (B.drop n input)
-- | Run a parser without consuming input. The parser runs with a fresh
-- buffer; afterwards the input is reset to the original input (on success,
-- followed by any bytes received while it ran), and those bytes are appended
-- to the caller's buffer.
lookAhead :: Get a -> Get a
lookAhead ga = Get $ \ input outerBuf more onFail onOk ->
  let -- The caller's buffer extended with what the sub-parser buffered.
      merged seen    = outerBuf `append` seen
      onOk' _ seen   = onOk (input `B.append` bufferBytes seen) (merged seen)
      onFail' _ seen = onFail input (merged seen)
  in unGet ga input emptyBuffer more onFail' onOk'
-- | Run a parser; if it returns 'Nothing', restore the input that was current
-- before it ran. A 'Just' result keeps whatever the parser consumed.
lookAheadM :: Get (Maybe a) -> Get (Maybe a)
lookAheadM gma = do
  saved   <- get
  outcome <- gma
  case outcome of
    Nothing -> put saved
    Just _  -> return ()
  return outcome
-- | Run a parser; if it returns 'Left', restore the input that was current
-- before it ran. A 'Right' result keeps whatever the parser consumed.
lookAheadE :: Get (Either a b) -> Get (Either a b)
lookAheadE gea = do
  saved   <- get
  outcome <- gea
  either (const (put saved)) (const (return ())) outcome
  return outcome
-- | Return up to @n@ bytes from the front of the current input without
-- consuming them. Fewer bytes are returned if fewer are available.
uncheckedLookAhead :: Int -> Get B.ByteString
uncheckedLookAhead n = fmap (B.take n) get
-- | The length of the current input plus the count recorded in the 'More'
-- state (see 'moreLength').
remaining :: Get Int
remaining = Get $ \ input buf more _ onOk ->
  let total = B.length input + moreLength more
  in onOk input buf more total
-- | 'True' when the current input is empty and the 'More' state records no
-- further bytes (see 'moreLength').
isEmpty :: Get Bool
isEmpty = Get $ \ input buf more _ onOk ->
  let exhausted = B.null input && moreLength more == 0
  in onOk input buf more exhausted
-- | Read @n@ bytes and return a fresh copy of them, so the result does not
-- share storage with the input.
getByteString :: Int -> Get B.ByteString
getByteString n = getBytes n >>= \ bs -> return $! B.copy bs
-- | Read @n@ bytes with 'getByteString' and wrap them as a lazy
-- 'L.ByteString'.
getLazyByteString :: Int64 -> Get L.ByteString
getLazyByteString n = do
  strict <- getByteString (fromIntegral n)
  return (L.fromChunks [strict])
-- | Read @n@ bytes and convert them to a 'BS.ShortByteString'.
getShortByteString :: Int -> Get BS.ShortByteString
getShortByteString n = getBytes n >>= \ bs -> return $! BS.toShort bs
-- | Consume and return the next @n@ bytes of input. Fails when @n@ is
-- negative or when fewer than @n@ bytes are available. The result is a slice
-- of the input, not a copy.
getBytes :: Int -> Get B.ByteString
getBytes n
  | n < 0     = fail "getBytes: negative length requested"
  | otherwise = do
      input <- ensure n
      -- 'ensure' guarantees at least n bytes, so the unchecked slicing
      -- functions are safe here.
      put (B.unsafeDrop n input)
      return (B.unsafeTake n input)
{-# INLINE getBytes #-}
-- | Consume @n@ bytes and read a 'Storable' value from the start of them with
-- 'peek'. The caller must pass the value's size as @n@.
getPtr :: Storable a => Int -> Get a
getPtr n = fmap B.toForeignPtr (getBytes n) >>= \ (fptr, off, _) ->
    let -- Read the value at the slice's offset into the underlying buffer.
        readAt base = peek (castPtr (base `plusPtr` off))
    in return (unsafeDupablePerformIO (withForeignPtr fptr readAt))
{-# INLINE getPtr #-}
-- | Read a single byte as an 'Int8'.
getInt8 :: Get Int8
getInt8 = getBytes 1 >>= \ s -> return $! fromIntegral (B.unsafeHead s)
{-# INLINE getInt8 #-}

-- | Read an 'Int16' in big endian format.
getInt16be :: Get Int16
getInt16be = do
    s <- getBytes 2
    let byte i = fromIntegral (B.unsafeIndex s i)
    return $! (byte 0 `shiftL` 8) .|. byte 1
{-# INLINE getInt16be #-}

-- | Read an 'Int16' in little endian format.
getInt16le :: Get Int16
getInt16le = do
    s <- getBytes 2
    let byte i = fromIntegral (B.unsafeIndex s i)
    return $! (byte 1 `shiftL` 8) .|. byte 0
{-# INLINE getInt16le #-}

-- | Read an 'Int32' in big endian format.
getInt32be :: Get Int32
getInt32be = do
    s <- getBytes 4
    let byte i = fromIntegral (B.unsafeIndex s i)
    return $! (byte 0 `shiftL` 24) .|.
              (byte 1 `shiftL` 16) .|.
              (byte 2 `shiftL`  8) .|.
              byte 3
{-# INLINE getInt32be #-}

-- | Read an 'Int32' in little endian format.
getInt32le :: Get Int32
getInt32le = do
    s <- getBytes 4
    let byte i = fromIntegral (B.unsafeIndex s i)
    return $! (byte 3 `shiftL` 24) .|.
              (byte 2 `shiftL` 16) .|.
              (byte 1 `shiftL`  8) .|.
              byte 0
{-# INLINE getInt32le #-}

-- | Read an 'Int64' in big endian format.
getInt64be :: Get Int64
getInt64be = do
    s <- getBytes 8
    let byte i = fromIntegral (B.unsafeIndex s i)
    return $! (byte 0 `shiftL` 56) .|.
              (byte 1 `shiftL` 48) .|.
              (byte 2 `shiftL` 40) .|.
              (byte 3 `shiftL` 32) .|.
              (byte 4 `shiftL` 24) .|.
              (byte 5 `shiftL` 16) .|.
              (byte 6 `shiftL`  8) .|.
              byte 7
{-# INLINE getInt64be #-}

-- | Read an 'Int64' in little endian format.
getInt64le :: Get Int64
getInt64le = do
    s <- getBytes 8
    let byte i = fromIntegral (B.unsafeIndex s i)
    return $! (byte 7 `shiftL` 56) .|.
              (byte 6 `shiftL` 48) .|.
              (byte 5 `shiftL` 40) .|.
              (byte 4 `shiftL` 32) .|.
              (byte 3 `shiftL` 24) .|.
              (byte 2 `shiftL` 16) .|.
              (byte 1 `shiftL`  8) .|.
              byte 0
{-# INLINE getInt64le #-}
-- | Read a single byte as a 'Word8'.
getWord8 :: Get Word8
getWord8 = fmap B.unsafeHead (getBytes 1)
{-# INLINE getWord8 #-}

-- | Read a 'Word16' in big endian format.
getWord16be :: Get Word16
getWord16be = do
    s <- getBytes 2
    let byte i = fromIntegral (B.unsafeIndex s i)
    return $! (byte 0 `shiftl_w16` 8) .|. byte 1
{-# INLINE getWord16be #-}

-- | Read a 'Word16' in little endian format.
getWord16le :: Get Word16
getWord16le = do
    s <- getBytes 2
    let byte i = fromIntegral (B.unsafeIndex s i)
    return $! (byte 1 `shiftl_w16` 8) .|. byte 0
{-# INLINE getWord16le #-}

-- | Read a 'Word32' in big endian format.
getWord32be :: Get Word32
getWord32be = do
    s <- getBytes 4
    let byte i = fromIntegral (B.unsafeIndex s i)
    return $! (byte 0 `shiftl_w32` 24) .|.
              (byte 1 `shiftl_w32` 16) .|.
              (byte 2 `shiftl_w32`  8) .|.
              byte 3
{-# INLINE getWord32be #-}

-- | Read a 'Word32' in little endian format.
getWord32le :: Get Word32
getWord32le = do
    s <- getBytes 4
    let byte i = fromIntegral (B.unsafeIndex s i)
    return $! (byte 3 `shiftl_w32` 24) .|.
              (byte 2 `shiftl_w32` 16) .|.
              (byte 1 `shiftl_w32`  8) .|.
              byte 0
{-# INLINE getWord32le #-}

-- | Read a 'Word64' in big endian format.
getWord64be :: Get Word64
getWord64be = do
    s <- getBytes 8
    let byte i = fromIntegral (B.unsafeIndex s i)
    return $! (byte 0 `shiftl_w64` 56) .|.
              (byte 1 `shiftl_w64` 48) .|.
              (byte 2 `shiftl_w64` 40) .|.
              (byte 3 `shiftl_w64` 32) .|.
              (byte 4 `shiftl_w64` 24) .|.
              (byte 5 `shiftl_w64` 16) .|.
              (byte 6 `shiftl_w64`  8) .|.
              byte 7
{-# INLINE getWord64be #-}

-- | Read a 'Word64' in little endian format.
getWord64le :: Get Word64
getWord64le = do
    s <- getBytes 8
    let byte i = fromIntegral (B.unsafeIndex s i)
    return $! (byte 7 `shiftl_w64` 56) .|.
              (byte 6 `shiftl_w64` 48) .|.
              (byte 5 `shiftl_w64` 40) .|.
              (byte 4 `shiftl_w64` 32) .|.
              (byte 3 `shiftl_w64` 24) .|.
              (byte 2 `shiftl_w64` 16) .|.
              (byte 1 `shiftl_w64`  8) .|.
              byte 0
{-# INLINE getWord64le #-}
-- | Read a 'Word' in the host's native size and byte order.
getWordhost :: Get Word
getWordhost = getPtr (sizeOf (0 :: Word))

-- | Read a 'Word16' in the host's native byte order.
getWord16host :: Get Word16
getWord16host = getPtr (sizeOf (0 :: Word16))

-- | Read a 'Word32' in the host's native byte order.
getWord32host :: Get Word32
getWord32host = getPtr (sizeOf (0 :: Word32))

-- | Read a 'Word64' in the host's native byte order.
getWord64host :: Get Word64
getWord64host = getPtr (sizeOf (0 :: Word64))
-- Left shifts used by the fixed-width word readers. On GHC they are written
-- with primops that skip the shift-amount range check; callers only pass
-- constant amounts smaller than the word width. Other compilers fall back to
-- 'shiftL'.
shiftl_w16 :: Word16 -> Int -> Word16
shiftl_w32 :: Word32 -> Int -> Word32
shiftl_w64 :: Word64 -> Int -> Word64

#if defined(__GLASGOW_HASKELL__) && !defined(__HADDOCK__)
#if __GLASGOW_HASKELL__ >= 902
-- Since GHC 9.2, W16# and W32# wrap the sized primitive types Word16# and
-- Word32#, so the Word#-based primop no longer type-checks.
shiftl_w16 (W16# w) (I# i) = W16# (w `uncheckedShiftLWord16#` i)
shiftl_w32 (W32# w) (I# i) = W32# (w `uncheckedShiftLWord32#` i)
#else
shiftl_w16 (W16# w) (I# i) = W16# (w `uncheckedShiftL#` i)
shiftl_w32 (W32# w) (I# i) = W32# (w `uncheckedShiftL#` i)
#endif

#if WORD_SIZE_IN_BITS < 64
shiftl_w64 (W64# w) (I# i) = W64# (w `uncheckedShiftL64#` i)

#if __GLASGOW_HASKELL__ <= 606
foreign import ccall unsafe "stg_uncheckedShiftL64"
    uncheckedShiftL64# :: Word64# -> Int# -> Word64#
#endif

#else
#if __GLASGOW_HASKELL__ >= 904
-- Since GHC 9.4, W64# wraps Word64# on every platform.
shiftl_w64 (W64# w) (I# i) = W64# (w `uncheckedShiftL64#` i)
#else
shiftl_w64 (W64# w) (I# i) = W64# (w `uncheckedShiftL#` i)
#endif
#endif

#else
shiftl_w16 = shiftL
shiftl_w32 = shiftL
shiftl_w64 = shiftL
#endif
-- | Run two parsers in sequence and pair their results.
getTwoOf :: Get a -> Get b -> Get (a,b)
getTwoOf ma mb = do
  a <- ma
  b <- mb
  return (a, b)
-- | Read a big endian 'Word64' element count, then that many elements with
-- the given parser. Each element is forced to weak head normal form as it is
-- read.
getListOf :: Get a -> Get [a]
getListOf m = getWord64be >>= collect []
  where
    -- Elements are accumulated in reverse and flipped once at the end.
    collect acc left
      | left == 0 = return $! reverse acc
      | otherwise = do
          item <- m
          item `seq` collect (item : acc) (left - 1)
-- | Read an array: first the lower and upper bounds with the index parser,
-- then the elements as a list (see 'getListOf').
getIArrayOf :: (Ix i, IArray a e) => Get i -> Get e -> Get (a i e)
getIArrayOf ix e = do
  bnds  <- getTwoOf ix ix
  items <- getListOf e
  return (listArray bnds items)
-- | Read a big endian 'Word64' element count, then that many elements,
-- appending each to the right end of a 'Seq.Seq'.
getSeqOf :: Get a -> Get (Seq.Seq a)
getSeqOf m = getWord64be >>= build Seq.empty
  where
    build acc left
      | left == 0 = return $! acc
      | otherwise = acc `seq` do
          item <- m
          build (acc Seq.|> item) (left - 1)
-- | Read a tree: the root label first, then the list of subtrees (see
-- 'getListOf'), each read recursively.
getTreeOf :: Get a -> Get (T.Tree a)
getTreeOf m = do
  root <- m
  kids <- getListOf (getTreeOf m)
  return (T.Node root kids)
-- | Read a list of key/value pairs (see 'getListOf') and build a 'Map.Map'.
getMapOf :: Ord k => Get k -> Get a -> Get (Map.Map k a)
getMapOf k m = fmap Map.fromList (getListOf (getTwoOf k m))

-- | Read a list of key/value pairs (see 'getListOf') and build an
-- 'IntMap.IntMap'.
getIntMapOf :: Get Int -> Get a -> Get (IntMap.IntMap a)
getIntMapOf i m = fmap IntMap.fromList (getListOf (getTwoOf i m))

-- | Read a list of elements (see 'getListOf') and build a 'Set.Set'.
getSetOf :: Ord a => Get a -> Get (Set.Set a)
getSetOf m = fmap Set.fromList (getListOf m)

-- | Read a list of 'Int's (see 'getListOf') and build an 'IntSet.IntSet'.
getIntSetOf :: Get Int -> Get IntSet.IntSet
getIntSetOf m = fmap IntSet.fromList (getListOf m)
-- | Read a one-byte tag: 0 means 'Nothing'; any other value means a 'Just'
-- whose payload is read with the given parser.
getMaybeOf :: Get a -> Get (Maybe a)
getMaybeOf m = do
  tag <- getWord8
  if tag == 0
    then return Nothing
    else fmap Just m
-- | Read a one-byte tag: 0 selects the 'Left' parser; any other value selects
-- the 'Right' parser.
getEitherOf :: Get a -> Get b -> Get (Either a b)
getEitherOf ma mb = do
  tag <- getWord8
  if tag == 0
    then fmap Left ma
    else fmap Right mb
-- | Read a length with the first parser, then run the second parser isolated
-- to exactly that many bytes (see 'isolate').
getNested :: Get Int -> Get a -> Get a
getNested getLen getVal = getLen >>= \ n -> isolate n getVal