{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module Proto3.Wire.Decode
(
ParsedField(..)
, decodeWire
, Parser(..)
, RawPrimitive
, RawField
, RawMessage
, ParseError(..)
, foldFields
, parse
, bool
, int32
, int64
, uint32
, uint64
, sint32
, sint64
, enum
, byteString
, lazyByteString
, text
, packedVarints
, packedFixed32
, packedFixed64
, packedFloats
, packedDoubles
, fixed32
, fixed64
, sfixed32
, sfixed64
, float
, double
, at
, oneof
, one
, repeated
, embedded
, embedded'
) where
import Control.Applicative
import Control.Arrow (first)
import Control.Exception ( Exception )
import Control.Monad ( msum, foldM )
import Data.Bits
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Foldable ( foldl' )
import qualified Data.IntMap.Strict as M
import Data.Maybe ( fromMaybe )
import Data.Monoid ( (<>) )
import Data.Serialize.Get ( Get, getWord8, getInt32le
, getInt64le, getWord32le, getWord64le
, runGet )
import Data.Serialize.IEEE754 ( getFloat32le, getFloat64le )
import Data.Text.Lazy ( Text, pack )
import Data.Text.Lazy.Encoding ( decodeUtf8' )
import qualified Data.Traversable as T
import Data.Int ( Int32, Int64 )
import Data.Word ( Word8, Word32, Word64 )
import Proto3.Wire.Types
import qualified Safe
-- | Undo zigzag encoding: shift the value right by one and, when the low bit
-- was set, flip every bit (xor with the two's-complement negation of 1).
zigZagDecode :: (Num a, Bits a) => a -> a
zigZagDecode i = xor (i `shiftR` 1) (negate (i .&. 1))
-- | One field value exactly as it appeared on the wire, tagged by wire type
-- and not yet interpreted as any particular protobuf scalar type.
data ParsedField = VarintField Word64
                 -- ^ the decoded varint value
                 | Fixed32Field B.ByteString
                 -- ^ the 4 raw payload bytes
                 | Fixed64Field B.ByteString
                 -- ^ the 8 raw payload bytes
                 | LengthDelimitedField B.ByteString
                 -- ^ the payload bytes following the length prefix
  deriving (Show, Eq)
-- | Group decoded fields by field number. 'M.fromListWith' combines as
-- @new <> old@, so each list ends up in reverse wire order (last occurrence
-- first).
toMap :: [(FieldNumber, v)] -> M.IntMap [v]
toMap kvs =
  M.fromListWith (<>)
    [ (fromIntegral (getFieldNumber fieldNum), [value])
    | (fieldNum, value) <- kvs
    ]
-- | Split a serialized message into its (field number, raw field) pairs, in
-- the order they appear in the input.
--
-- Each field is a varint key (low 3 bits = wire type, remaining bits = field
-- number) followed by a payload whose shape depends on the wire type.
decodeWire :: B.ByteString -> Either String [(FieldNumber, ParsedField)]
decodeWire = go []
  where
    go acc !input
      | B.null input = Right (reverse acc)
      | otherwise = do
          (key, afterKey) <- takeVarInt input
          wireTy <- gwireType (fromIntegral (key .&. 7))
          (field, afterField) <- takeWT wireTy afterKey
          go ((FieldNumber (key `shiftR` 3), field) : acc) afterField
-- | Split off the first byte, failing with a varint error on empty input.
eitherUncons :: B.ByteString -> Either String (Word8, B.ByteString)
eitherUncons bs = case B.uncons bs of
  Nothing -> Left "failed to parse varint128"
  Just split -> Right split
-- | Decode one base-128 varint from the front of the input, returning the
-- value and the unconsumed remainder.
--
-- Each byte contributes its low 7 bits, least-significant group first; a
-- byte with the high bit clear terminates the varint. At most 10 bytes are
-- accepted, which covers all 64 bits of a 'Word64'.
--
-- Returns @Left@ when:
--
-- * the input is empty or ends before the terminating byte, or
-- * the varint is longer than 10 bytes.
takeVarInt :: B.ByteString -> Either String (Word64, B.ByteString)
takeVarInt = go 0 0
  where
    go :: Int -> Word64 -> B.ByteString -> Either String (Word64, B.ByteString)
    go shiftBy acc input
      -- Shifts 0,7,..,63 are the ten groups that fit in a Word64; needing an
      -- eleventh means the encoding is too long.
      | shiftBy > 63 = Left ("failed to parse varint128: too big; " ++ show acc)
      | otherwise =
          case B.uncons input of
            -- Empty or truncated input is an error, never an implicit zero.
            Nothing -> Left "failed to parse varint128"
            Just (byte, rest) ->
              let acc' = acc .|. (fromIntegral (byte .&. 0x7f) `shiftL` shiftBy)
              in if byte < 0x80
                   then Right (acc', rest)
                   else acc' `seq` go (shiftBy + 7) acc' rest
-- | Map the 3-bit wire-type code from a field key to a 'WireType'. Only codes
-- 0, 1, 2 and 5 are recognised; anything else is reported as an error.
gwireType :: Word8 -> Either String WireType
gwireType code = case code of
  0 -> Right Varint
  1 -> Right Fixed64
  2 -> Right LengthDelimited
  5 -> Right Fixed32
  other -> Left ("wireType got unknown wire type: " ++ show other)
-- | Take exactly @i@ bytes from the front of the input, returning them and
-- the remainder.
--
-- Fails if fewer than @i@ bytes remain. It also fails if @i@ is negative:
-- 'B.splitAt' would otherwise return an empty prefix without consuming
-- anything.
safeSplit :: Int -> B.ByteString -> Either String (B.ByteString, B.ByteString)
safeSplit i b
  | i < 0 = Left "failed to split bytes: negative length"
  | B.length b < i = Left "failed to parse varint128: not enough bytes"
  | otherwise = Right (B.splitAt i b)
-- | Read one field payload of the given wire type from the front of the
-- input, returning it and the remainder.
--
-- * 'Varint': one varint.
-- * 'Fixed32' / 'Fixed64': exactly 4 / 8 raw bytes.
-- * 'LengthDelimited': a varint length followed by that many bytes.
takeWT :: WireType -> B.ByteString -> Either String (ParsedField, B.ByteString)
takeWT Varint !b = first VarintField <$> takeVarInt b
takeWT Fixed32 !b = first Fixed32Field <$> safeSplit 4 b
takeWT Fixed64 !b = first Fixed64Field <$> safeSplit 8 b
takeWT LengthDelimited b = do
  (!len, rest) <- takeVarInt b
  -- Bounds-check in Word64 before narrowing. A declared length above
  -- maxBound :: Int would otherwise wrap in fromIntegral (possibly to a
  -- negative Int) and be mis-split instead of rejected.
  if len > fromIntegral (B.length rest)
    then Left "failed to parse varint128: not enough bytes"
    else first LengthDelimitedField <$> safeSplit (fromIntegral len) rest
-- | Errors produced while decoding a message.
data ParseError =
    -- | A field's wire type did not match what the parser expected.
    WireTypeError Text
    -- | The bytes themselves could not be decoded: malformed wire data,
    -- a cereal failure, or invalid UTF-8.
  | BinaryError Text
    -- | An embedded message failed to decode. Carries a description and,
    -- when available, the error from the inner parse.
  | EmbeddedError Text
                  (Maybe ParseError)
  deriving (Show, Eq, Ord)

-- | Lets a 'ParseError' be thrown and caught as an exception.
instance Exception ParseError
-- | A decoder that reads some input (a single wire value, all values of one
-- field, or a whole message) and yields an @a@ or a 'ParseError'.
--
-- Every stage of a composed parser sees the same input: this is a reader over
-- @input@ layered on 'Either'.
newtype Parser input a = Parser { runParser :: input -> Either ParseError a }
  deriving Functor

instance Applicative (Parser input) where
  pure x = Parser (\_ -> Right x)
  Parser pf <*> Parser px = Parser $ \input ->
    case pf input of
      Left err -> Left err
      Right f -> fmap f (px input)

instance Monad (Parser input) where
  Parser p >>= k = Parser $ \input ->
    case p input of
      Left err -> Left err
      Right a -> runParser (k a) input
-- | A single wire-level field value.
type RawPrimitive = ParsedField
-- | Every value seen for one field number. When produced by 'toMap' the list
-- is in reverse wire order (most recent occurrence first).
type RawField = [RawPrimitive]
-- | A decoded message: each field number mapped to its 'RawField'.
type RawMessage = M.IntMap RawField
-- | Fold over decoded fields in order, threading an accumulator.
--
-- Each field is handed to the parser registered for its field number, and
-- the parsed value is combined into the accumulator. Fields with no
-- registered parser are skipped. The first parse failure aborts the fold.
foldFields :: M.IntMap (Parser RawPrimitive a, a -> acc -> acc)
           -> acc
           -> [(FieldNumber, ParsedField)]
           -> Either ParseError acc
foldFields parsers = foldM step
  where
    step acc (fieldNum, field) =
      case M.lookup (fromIntegral (getFieldNumber fieldNum)) parsers of
        Nothing -> Right acc
        Just (parser, combine) -> (`combine` acc) <$> runParser parser field
-- | Decode a serialized message and run a message parser over it. Wire-level
-- decoding failures are reported as 'BinaryError'.
parse :: Parser RawMessage a -> B.ByteString -> Either ParseError a
parse parser bs =
  either (Left . BinaryError . pack) (runParser parser . toMap) (decodeWire bs)
-- | The head of a field's value list, if any. Given the reverse-order lists
-- built by 'toMap', this is the last occurrence on the wire.
parsedField :: RawField -> Maybe RawPrimitive
parsedField (latest : _) = Just latest
parsedField [] = Nothing
-- | Fail with a 'WireTypeError' naming the expected kind of field and showing
-- the value that was actually found.
throwWireTypeError :: Show input
                   => String
                   -> input
                   -> Either ParseError expected
throwWireTypeError expected wrong =
  Left . WireTypeError . pack $
    concat ["Wrong wiretype. Expected ", expected, " but got ", show wrong]
-- | Fail with a 'BinaryError' that wraps an error message returned by
-- cereal's 'runGet'.
throwCerealError :: String -> String -> Either ParseError a
throwCerealError expected cerealErr =
  Left . BinaryError . pack $
    concat [ "Failed to parse contents of ", expected, " field. "
           , "Error from cereal was: ", cerealErr
           ]
-- | Accept a varint field and convert it with 'fromIntegral'. The value is
-- narrowed or wrapped to fit the target type.
parseVarInt :: Integral a => Parser RawPrimitive a
parseVarInt = Parser fromVarint
  where
    fromVarint (VarintField n) = Right (fromIntegral n)
    fromVarint other = throwWireTypeError "varint" other
-- | Run a cereal decoder over the payload of a length-delimited field. This is
-- used for packed repeated fields.
runGetPacked :: Get a -> Parser RawPrimitive a
runGetPacked g = Parser decodePacked
  where
    decodePacked (LengthDelimitedField bs) =
      either (throwCerealError "packed repeated field") Right (runGet g bs)
    decodePacked other = throwWireTypeError "packed repeated field" other
-- | Run a cereal decoder over the 4 raw bytes of a fixed32 field.
runGetFixed32 :: Get a -> Parser RawPrimitive a
runGetFixed32 g = Parser decodeFixed
  where
    decodeFixed (Fixed32Field bs) =
      either (throwCerealError "fixed32 field") Right (runGet g bs)
    decodeFixed other = throwWireTypeError "fixed 32 field" other
-- | Run a cereal decoder over the 8 raw bytes of a fixed64 field.
runGetFixed64 :: Get a -> Parser RawPrimitive a
runGetFixed64 g = Parser decodeFixed
  where
    decodeFixed (Fixed64Field bs) =
      either (throwCerealError "fixed 64 field") Right (runGet g bs)
    decodeFixed other = throwWireTypeError "fixed 64 field" other
-- | Accept a length-delimited field and return a fresh copy of its payload.
-- The copy is made with 'B.copy' and forced before being returned, so the
-- result does not keep the original input buffer alive.
bytes :: Parser RawPrimitive B.ByteString
bytes = Parser $ \field -> case field of
  LengthDelimitedField payload -> Right $! B.copy payload
  other -> throwWireTypeError "bytes" other
-- | Parse a @bool@ field. Any non-zero varint decodes as 'True' and only 0
-- decodes as 'False', matching how protobuf decoders treat booleans.
-- Mapping through 'toEnum' would wrongly turn e.g. 2 into 'False'.
bool :: Parser RawPrimitive Bool
bool = fmap (/= (0 :: Word64)) parseVarInt
-- | Parse an @int32@ field; the varint is narrowed with 'fromIntegral'.
int32 :: Parser RawPrimitive Int32
int32 = parseVarInt

-- | Parse an @int64@ field.
int64 :: Parser RawPrimitive Int64
int64 = parseVarInt

-- | Parse a @uint32@ field; the varint is narrowed with 'fromIntegral'.
uint32 :: Parser RawPrimitive Word32
uint32 = parseVarInt

-- | Parse a @uint64@ field.
uint64 :: Parser RawPrimitive Word64
uint64 = parseVarInt

-- | Parse a zigzag-encoded @sint32@ field. The varint is narrowed to
-- 'Word32', zigzag-decoded as unsigned, then reinterpreted as 'Int32'.
sint32 :: Parser RawPrimitive Int32
sint32 = fmap (\raw -> fromIntegral (zigZagDecode (raw :: Word32))) parseVarInt

-- | Parse a zigzag-encoded @sint64@ field. The varint is zigzag-decoded as
-- 'Word64', then reinterpreted as 'Int64'.
sint64 :: Parser RawPrimitive Int64
sint64 = fmap (\raw -> fromIntegral (zigZagDecode (raw :: Word64))) parseVarInt
-- | Parse a @bytes@ field as a strict 'B.ByteString' (a copied payload).
byteString :: Parser RawPrimitive B.ByteString
byteString = bytes

-- | Parse a @bytes@ field as a lazy 'BL.ByteString'.
lazyByteString :: Parser RawPrimitive BL.ByteString
lazyByteString = BL.fromStrict <$> bytes
-- | Parse a @string@ field by decoding its payload as UTF-8. Invalid UTF-8 is
-- reported as a 'BinaryError'.
text :: Parser RawPrimitive Text
text = Parser $ \field -> case field of
  LengthDelimitedField payload ->
    either
      (\err -> Left (BinaryError (pack ("Failed to decode UTF-8: " ++ show err))))
      Right
      (decodeUtf8' (BL.fromStrict payload))
  other -> throwWireTypeError "string" other
-- | Parse an enum field. Returns @Right@ the constructor when the numeric
-- value is within the enum's bounds, otherwise @Left@ the raw 'Int'. This
-- keeps unknown values instead of failing.
enum :: forall e. (Enum e, Bounded e) => Parser RawPrimitive (Either Int e)
enum = fmap classify parseVarInt
  where
    classify :: Int -> Either Int e
    classify n = maybe (Left n) Right (Safe.toEnumMay n)
-- | Parse a packed repeated varint field. Each element is converted with
-- 'fromIntegral'.
packedVarints :: Integral a => Parser RawPrimitive [a]
packedVarints = map fromIntegral <$> runGetPacked (many getBase128Varint)
-- | cereal decoder for one base-128 varint. It ORs together the low 7 bits of
-- each byte, least-significant group first, and stops at the first byte whose
-- high bit is clear.
--
-- NOTE(review): there is no limit on the number of bytes consumed. Groups
-- shifted past bit 63 are discarded by 'shiftL' rather than rejected.
getBase128Varint :: Get Word64
getBase128Varint = go 0 0
  where
    go :: Int -> Word64 -> Get Word64
    go !shiftBy !acc = do
      byte <- getWord8
      let acc' = acc .|. (fromIntegral (byte `clearBit` 7) `shiftL` shiftBy)
      if testBit byte 7
        then go (shiftBy + 7) acc'
        else return acc'
-- | Parse a packed repeated @float@ field (little-endian IEEE-754 singles).
packedFloats :: Parser RawPrimitive [Float]
packedFloats = runGetPacked (many getFloat32le)

-- | Parse a packed repeated @double@ field (little-endian IEEE-754 doubles).
packedDoubles :: Parser RawPrimitive [Double]
packedDoubles = runGetPacked (many getFloat64le)

-- | Parse a packed repeated 32-bit fixed-width field. Each little-endian word
-- is converted with 'fromIntegral'.
packedFixed32 :: Integral a => Parser RawPrimitive [a]
packedFixed32 = map fromIntegral <$> runGetPacked (many getWord32le)

-- | Parse a packed repeated 64-bit fixed-width field. Each little-endian word
-- is converted with 'fromIntegral'.
packedFixed64 :: Integral a => Parser RawPrimitive [a]
packedFixed64 = map fromIntegral <$> runGetPacked (many getWord64le)
-- | Parse a @float@ field from its 4 little-endian bytes.
float :: Parser RawPrimitive Float
float = runGetFixed32 getFloat32le

-- | Parse a @double@ field from its 8 little-endian bytes.
double :: Parser RawPrimitive Double
double = runGetFixed64 getFloat64le

-- | Parse a @fixed32@ field as an unsigned little-endian word.
fixed32 :: Parser RawPrimitive Word32
fixed32 = runGetFixed32 getWord32le

-- | Parse a @fixed64@ field as an unsigned little-endian word.
fixed64 :: Parser RawPrimitive Word64
fixed64 = runGetFixed64 getWord64le

-- | Parse an @sfixed32@ field as a signed little-endian word.
sfixed32 :: Parser RawPrimitive Int32
sfixed32 = runGetFixed32 getInt32le

-- | Parse an @sfixed64@ field as a signed little-endian word.
sfixed64 :: Parser RawPrimitive Int64
sfixed64 = runGetFixed64 getInt64le
-- | Focus a field parser on one field number of a message. A missing field is
-- presented to the parser as an empty list of values.
at :: Parser RawField a -> FieldNumber -> Parser RawMessage a
at parser fn = Parser $ \msg ->
  runParser parser (M.findWithDefault [] (fromIntegral (getFieldNumber fn)) msg)
-- | Parse a oneof group.
--
-- The candidate field numbers are tried in list order, and the parser for
-- the first one present in the message is run. If none is present, the
-- default is returned.
--
-- NOTE(review): the winner is decided by list order, not wire order (the
-- message map does not keep ordering across field numbers). This may differ
-- from protobuf's "last member seen wins" rule — confirm this is intended.
oneof :: a
      -> [(FieldNumber, Parser RawField a)]
      -> Parser RawMessage a
oneof def parsersByFieldNum = Parser $ \input ->
  let present =
        [ (p, values)
        | (num, p) <- parsersByFieldNum
        , Just values <- [M.lookup (fromIntegral (getFieldNumber num)) input]
        ]
  in case present of
       [] -> Right def
       (p, values) : _ -> runParser p values
-- | Parse a singular field from the head of its value list (given 'toMap''s
-- ordering, the last occurrence on the wire). Returns the default when the
-- field is absent.
one :: Parser RawPrimitive a -> a -> Parser RawField a
one parser def = Parser $ \values -> case values of
  [] -> Right def
  latest : _ -> runParser parser latest
-- | Parse every value of a repeated field. The list is reversed after
-- parsing to undo the newest-first order produced by 'toMap'.
repeated :: Parser RawPrimitive a -> Parser RawField [a]
repeated parser = Parser $ \values -> reverse <$> traverse (runParser parser) values
-- | Decode the payload of a length-delimited field as a nested message map.
-- Wire-level failures become an 'EmbeddedError' with no inner error.
embeddedToParsedFields :: RawPrimitive -> Either ParseError RawMessage
embeddedToParsedFields field = case field of
  LengthDelimitedField payload ->
    either
      (\err -> Left (EmbeddedError ("Failed to parse embedded message: " <> pack err) Nothing))
      (Right . toMap)
      (decodeWire payload)
  other -> throwWireTypeError "embedded" other
-- | Parse an embedded-message field.
--
-- Returns 'Nothing' when the field is absent. Otherwise it decodes every
-- occurrence, merges them into one message map with 'M.unionWith' (values
-- for the same field number are concatenated), and runs the message parser
-- on the result.
embedded :: Parser RawMessage a -> Parser RawField (Maybe a)
embedded p = Parser $ \case
  [] -> Right Nothing
  occurrences -> do
    innerMaps <- T.mapM embeddedToParsedFields occurrences
    let combinedMap = foldl' (M.unionWith (<>)) M.empty innerMaps
    Just <$> runParser p combinedMap
-- | Parse a single length-delimited value as an embedded message. Any failure
-- of the inner parse is wrapped in an 'EmbeddedError' that carries the inner
-- error.
embedded' :: Parser RawMessage a -> Parser RawPrimitive a
embedded' parser = Parser $ \field -> case field of
  LengthDelimitedField payload ->
    either
      (\err -> Left (EmbeddedError "Failed to parse embedded message." (Just err)))
      Right
      (parse parser payload)
  other -> throwWireTypeError "embedded" other