{- Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. -} {-# LANGUAGE CPP #-} {-# LANGUAGE MagicHash #-} module Thrift.Protocol.Compact ( serializeCompact , deserializeCompact , Compact ) where import Control.Monad import Thrift.Binary.Parser as P import Data.Bits import Data.ByteString (ByteString) import Data.ByteString.Builder as B import Data.ByteString.Builder.Prim as Prim import Data.Int import Data.List import Data.Proxy import Data.Text.Encoding import Data.Word import GHC.Base hiding (build, Type) import GHC.Word import qualified Data.ByteString as BS import qualified Data.HashMap.Strict as HashMap import Thrift.Protocol import Thrift.Protocol.Binary.Internal data Compact serializeCompact :: ThriftStruct a => a -> ByteString serializeCompact = serializeGen (Proxy :: Proxy Compact) deserializeCompact :: ThriftStruct a => ByteString -> Either String a deserializeCompact = deserializeGen (Proxy :: Proxy Compact) -- Macros for ThriftTypes ------------------------------------------------------ #define TSTOP 0x00 #define TTRUE 0x01 #define TFALSE 0x02 #define TBYTE 0x03 #define TI16 0x04 #define TI32 0x05 #define TI64 0x06 #define TDOUBLE 0x07 #define TSTRING 0x08 #define TLIST 0x09 #define TSET 0x0A #define TMAP 0x0B #define TSTRUCT 0x0C #define TFLOAT 0x0D protocolVersion :: Word8 protocolVersion = 0x02 protocolID :: Word8 protocolID = 0x82 typeMask :: Word8 typeMask = 0xE0 versionMask :: Word8 versionMask = 0x1F shiftAmt :: Int shiftAmt = 5 -- BinaryWriter instance ------------------------------------------------------- instance Protocol Compact where -- Generators for Types getByteType _ = TBYTE getI16Type _ = TI16 getI32Type _ = TI32 getI64Type _ = TI64 getFloatType _ = TFLOAT getDoubleType _ = TDOUBLE getBoolType _ = TTRUE getStringType _ = TSTRING getListType _ = TLIST getSetType _ = TSET getMapType _ = TMAP getStructType _ = TSTRUCT -- Generators for tokens genMsgBegin proxy name msgType seqNum = primBounded (liftFixedToBounded Prim.word8 >*< liftFixedToBounded Prim.word8 >*< compactI32) (protocolID, (protocolVersion .|. ((msgType `shiftL` shiftAmt) .&. typeMask), seqNum)) <> genText proxy name genStruct _ fields = mconcat fields <> B.word8 TSTOP -- NOTE: This will not work correctly for generating boolean fields. Boolean -- fields MUST be generated using genFieldPrim genField _ _ ty fid lastId val = primBounded (compactFieldBegin ty) (fid, lastId) <> val genFieldPrim _ _ ty fid lastId prim val = primBounded (compactFieldBegin ty >*< prim) ((fid, lastId), val) genFieldBool _ _ fid lastId val = primBounded (compactFieldBegin ty) (fid, lastId) where ty = if val then TTRUE else TFALSE genList _ ty build elems = genListBegin ty elems <> mconcat (map build elems) genListPrim _ ty bounded elems = genListBegin ty elems <> primMapListBounded bounded elems genMap _ kt vt _ kbuild vbuild elems = genMapBegin kt vt elems <> mconcat (map (\(k, v) -> kbuild k <> vbuild v) elems) genMapPrim _ kt vt _ kbuild vbuild elems = genMapBegin kt vt elems <> primMapListBounded (kbuild >*< vbuild) elems -- Generators for base types genBytePrim _ = liftFixedToBounded Prim.int8 genI16Prim _ = compactI16 genI32Prim _ = compactI32 genI64Prim _ = compactI64 genFloat _ = B.floatBE genDouble _ = B.doubleBE genBoolPrim _ = liftFixedToBounded $ getVal >$< Prim.word8 where getVal b = if b then TTRUE else TFALSE genByteString _ s = primBounded buildVarint (W# (int2Word# len)) <> byteString s where !(I# len) = BS.length s -- Parsers for tokens parseMsgBegin proxy = do pid <- P.anyWord8 when (pid /= protocolID) $ fail "parseMsgBegin: invalid protocol id" versionAndType <- P.anyWord8 let version = versionAndType .&. versionMask msgType = (versionAndType .&. typeMask) `shiftR` shiftAmt when (version /= protocolVersion) $ fail "parseMsgBegin: invalid protocol version" seqNum <- parseCompactI32 name <- parseText proxy return $ MsgBegin name msgType seqNum parseFieldBegin _ lastId _ = do byte <- anyWord8 let rawType = byte .&. 0x0F if rawType == TSTOP then pure FieldEnd else do -- Boolean Value is encoded in the type let (!ty, !bool) = case rawType of 0x01 -> (TTRUE, True) 0x02 -> (TTRUE, False) mask -> (mask, False) -- Field Id depends on previous field Id fid <- case byte `shiftR` 4 of 0x00 -> parseCompactI16 offs -> pure $ fromIntegral offs + lastId return $ FieldBegin ty fid bool parseList _ p = do byte <- anyWord8 let _ty = byte .&. 0x0F len <- case byte `shiftR` 4 of 0x0F -> parseVarint size -> pure $ fromIntegral size ps <- replicateM (fromIntegral len) p return (fromIntegral len, ps) parseMap _ pk pv _ = do len <- parseVarint if len == 0 then pure [] else do byte <- anyWord8 let _ktype = byte `shiftR` 4 _vtype = byte .&. 0x0F replicateM (fromIntegral len) $ (,) <$> pk <*> pv -- Parser for base types parseByte _ = fromIntegral <$> anyWord8 parseI16 _ = parseCompactI16 parseI32 _ = parseCompactI32 parseI64 _ = parseCompactI64 parseFloat _ = binaryFloat parseDouble _ = binaryDouble parseBool _ = do byte <- P.anyWord8 case byte of TTRUE -> pure True TFALSE -> pure False n -> fail $ "invalid boolean value: " ++ show n parseBoolF _ = pure parseByteString _ = getBuffer parseVarint BS.copy parseText _ = getBuffer parseVarint decodeUtf8 parseSkip _ TTRUE Nothing = P.skipN 1 parseSkip _ TFALSE Nothing = P.skipN 1 parseSkip _ TTRUE Just{} = pure () parseSkip _ TFALSE Just{} = pure () parseSkip _ TBYTE _ = P.skipN 1 parseSkip _ TI16 _ = void parseCompactI16 parseSkip _ TI32 _ = void parseCompactI32 parseSkip _ TI64 _ = void parseCompactI64 parseSkip _ TDOUBLE _ = P.skipN 8 parseSkip _ TSTRING _ = void $ parseVarint >>= P.skipN . fromIntegral parseSkip proxy TLIST _ = do byte <- anyWord8 let ty = byte .&. 0x0F len <- case byte `shiftR` 4 of 0x0F -> parseVarint size -> pure $ fromIntegral size void $ replicateM (fromIntegral len) (parseSkip proxy ty Nothing) parseSkip proxy TSET _ = parseSkip proxy TLIST Nothing parseSkip proxy TMAP _ = do len <- parseVarint if len == 0 then pure () else do byte <- anyWord8 let ktype = byte `shiftR` 4 vtype = byte .&. 0x0F void $ replicateM (fromIntegral len) $ (,) <$> parseSkip proxy ktype Nothing <*> parseSkip proxy vtype Nothing parseSkip proxy TSTRUCT _ = do -- The last id deosn't matter since we do not need to correctly parse field -- ids. We don't recognize this field, therefore it will be discarded anyway fieldBegin <- parseFieldBegin proxy 0 HashMap.empty case fieldBegin of FieldBegin ty _ bool -> parseSkip proxy ty (Just bool) *> parseSkip proxy TSTRUCT Nothing FieldEnd -> pure () parseSkip _ n _ = fail $ "unrecognized type code: " ++ show n -- Helpers --------------------------------------------------------------------- {-# INLINE compactFieldBegin #-} compactFieldBegin :: Type -> BoundedPrim (FieldId, FieldId) compactFieldBegin ty = condB (\(fid, lastId) -> fid > lastId && fid - lastId < 16) ((\(fid, lastId) -> fromIntegral (fid - lastId) `shiftL` 4 .|. ty) >$< liftFixedToBounded Prim.word8) ((\(fid, _) -> (ty, fid)) >$< liftFixedToBounded Prim.word8 >*< compactI16) compactI16 :: BoundedPrim Int16 compactI16 = i16ToZigZag >$< buildVarint2 compactI32 :: BoundedPrim Int32 compactI32 = i32ToZigZag >$< buildVarint4 compactI64 :: BoundedPrim Int64 compactI64 = i64ToZigZag >$< buildVarint parseCompactI16 :: Parser Int16 parseCompactI16 = zigZagToInt <$> parseVarint parseCompactI32 :: Parser Int32 parseCompactI32 = zigZagToInt <$> parseVarint parseCompactI64 :: Parser Int64 parseCompactI64 = zigZagToInt <$> parseVarint genListBegin :: Type -> [a] -> Builder genListBegin ty elems = (`primBounded` len) $ condB (< 15) ((\l -> (fromIntegral l `shiftL` 4) .|. ty) >$< liftFixedToBounded Prim.word8) ((\l -> (0xF0 .|. ty, fromIntegral l)) >$< (liftFixedToBounded Prim.word8 >*< buildVarint)) where len = length elems genMapBegin :: Type -> Type -> [a] -> Builder genMapBegin kt vt elems | null elems = B.word8 0x00 | otherwise = (`primBounded` len) $ (\l -> (l, kt `shiftL` 4 .|. vt)) >$< (buildVarint >*< liftFixedToBounded Prim.word8) where len = genericLength elems -- Variable Length Encoded Integers -------------------------------------------- -- Signed numbers must be converted to "Zig Zag" format before they can be -- serialized in the Varint format {-# INLINE iToZigZag #-} iToZigZag :: (Bits i, Integral i) => Int -> i -> Word iToZigZag s n = fromIntegral $ (n `shiftL` 1) `xor` (n `shiftR` s) {-# INLINE i16ToZigZag #-} i16ToZigZag :: Int16 -> Word i16ToZigZag n = 0xFFFF .&. iToZigZag 15 n {-# INLINE i32ToZigZag #-} i32ToZigZag :: Int32 -> Word i32ToZigZag n = 0xFFFFFFFF .&. iToZigZag 31 n {-# INLINE i64ToZigZag #-} i64ToZigZag :: Int64 -> Word i64ToZigZag = iToZigZag 63 -- This is janky, but BoundedPrim requires the size of the value it is building -- to be bounded. Varints *are* bounded, but the output size depends on their -- value. A 64-bit value can take up to 80 bits to encode, because each 7-bits -- in the input map to 8 bits in the output. We therefore need 10 builders, each -- one decreasing in potential maximum size. buildVarint :: BoundedPrim Word buildVarint = buildVarintBase buildVarint8 buildVarint8 :: BoundedPrim Word buildVarint8 = buildVarintBase buildVarint7 buildVarint7 :: BoundedPrim Word buildVarint7 = buildVarintBase buildVarint6 buildVarint6 :: BoundedPrim Word buildVarint6 = buildVarintBase buildVarint5 buildVarint5 :: BoundedPrim Word buildVarint5 = buildVarintBase buildVarint4 buildVarint4 :: BoundedPrim Word buildVarint4 = buildVarintBase buildVarint3 buildVarint3 :: BoundedPrim Word buildVarint3 = buildVarintBase buildVarint2 buildVarint2 :: BoundedPrim Word buildVarint2 = buildVarintBase buildVarint1 buildVarint1 :: BoundedPrim Word buildVarint1 = buildVarintBase buildVarint0 buildVarint0 :: BoundedPrim Word buildVarint0 = fromIntegral >$< liftFixedToBounded Prim.word8 {-# INLINE buildVarintBase #-} buildVarintBase :: BoundedPrim Word -> BoundedPrim Word buildVarintBase base = condB (\n -> n .&. complement 0x7F == 0) buildVarint0 ((\(W# n) -> (mkW8# (0x80## `or#` (n `and#` 0x7f##)), W# (n `uncheckedShiftRL#` 7#))) >$< (liftFixedToBounded Prim.word8 >*< base)) where #if __GLASGOW_HASKELL__ >= 902 mkW8# x = W8# (wordToWord8# x) #else mkW8# = W8# #endif {-# INLINE zigZagToInt #-} zigZagToInt :: (Bits i, Integral i) => Word -> i zigZagToInt w = fromIntegral (w `shiftR` 1) `xor` negate (fromIntegral w .&. 1) {-# INLINE parseVarint #-} parseVarint :: Parser Word parseVarint = go 0 0 where go !val !n = do w <- anyWord8 let newVal = val .|. (fromIntegral w .&. 0x7F) `shiftL` n if not (testBit w 7) then return newVal else go newVal (n + 7)