-- | Module defining the individual base wire types (e.g. VarInt, Fixed64).
--
-- They are used to represent the @unknownFields@ within the proto message.
--
-- Upstream docs:
-- <https://developers.google.com/protocol-buffers/docs/encoding#structure>
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Data.ProtoLens.Encoding.Wire
    ( Tag(..)
    , TaggedValue(..)
    , WireValue(..)
    , FieldSet
    , splitTypeAndTag
    , joinTypeAndTag
    , parseFieldSet
    , buildFieldSet
    , buildMessageSet
    , parseTaggedValueFromWire
    , parseMessageSetTaggedValueFromWire
    ) where

import Control.DeepSeq (NFData(..))
import Data.Bits ((.&.), (.|.), shiftL, shiftR)
import qualified Data.ByteString as B
#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup ((<>))
#endif
import Data.Word (Word8, Word32, Word64)

import Data.ProtoLens.Encoding.Bytes

-- | A tag that identifies a particular field of the message when converting
-- to/from the wire format.
--
-- This is the field number, i.e. a field key with its low three wire-type
-- bits removed (see 'splitTypeAndTag' and 'joinTypeAndTag').
newtype Tag = Tag { unTag :: Int }
    deriving (Show, Eq, Ord, Num, NFData)

-- | The encoding of some unknown field on the wire.
--
-- Each constructor corresponds to one wire type; 'wireValueToInt' gives the
-- numeric wire type (0-5) written into the field key. 'StartGroup' and
-- 'EndGroup' carry no payload of their own.
--
-- NOTE: the constructor order determines the derived 'Ord' instance, so
-- reordering the constructors changes how field sets sort.
data WireValue
    = VarInt !Word64
    | Fixed64 !Word64
    | Lengthy !B.ByteString
    | StartGroup
    | EndGroup
    | Fixed32 !Word32
    deriving (Eq, Ord)

-- | A pair of an encoded field and a value.
data TaggedValue = TaggedValue !Tag !WireValue
    deriving (Eq, Ord)

-- | A sequence of fields in the order they appear on the wire; used to hold
-- the @unknownFields@ of a message.
type FieldSet = [TaggedValue]

-- TaggedValue, WireValue and Tag are strict, so their NFData instances are
-- trivial: evaluating a value to weak head normal form already evaluates all
-- of its (strict) fields.
instance NFData TaggedValue where
    rnf = (`seq` ())

-- Every payload-carrying constructor of WireValue has strict fields, so
-- forcing to weak head normal form is enough.
instance NFData WireValue where
    rnf = (`seq` ())

-- | Encodes one field: its key (field number and wire type, packed by
-- 'joinTypeAndTag') as a varint, followed by the value's payload.
buildTaggedValue :: TaggedValue -> Builder
buildTaggedValue (TaggedValue tag wv) =
    putVarInt (joinTypeAndTag tag (wireValueToInt wv))
        <> buildWireValue wv

-- | Encodes one field as an item of the legacy MessageSet format: a group
-- with field number 1 containing the field number as a varint (field 2) and
-- the payload (field 3), then the closing group marker.
--
-- See https://github.com/protocolbuffers/protobuf/blob/dec4939439d9ca2adf2bb14edccf876c2587faf2/src/google/protobuf/descriptor.proto#L444
--
-- NOTE(review): the payload is written with whatever wire type @wv@ has,
-- while 'parseMessageSetTaggedValueFromWire' only accepts a 'Lengthy'
-- field 3; presumably callers only pass 'Lengthy' values here — confirm.
buildTaggedValueAsMessageSet :: TaggedValue -> Builder
buildTaggedValueAsMessageSet (TaggedValue (Tag t) wv) =
    buildTaggedValue (TaggedValue 1 StartGroup)
        <> buildTaggedValue (TaggedValue 2 (VarInt $ fromIntegral t))
        <> buildTaggedValue (TaggedValue 3 wv)
        <> buildTaggedValue (TaggedValue 1 EndGroup)

-- | Encodes only the payload of a wire value (not its field key).
--
-- 'Lengthy' payloads are prefixed with their byte length as a varint; the
-- group markers have no payload, since the key alone encodes them.
buildWireValue :: WireValue -> Builder
buildWireValue (VarInt w) = putVarInt w
buildWireValue (Fixed64 w) = putFixed64 w
buildWireValue (Fixed32 w) = putFixed32 w
buildWireValue (Lengthy b) =
    putVarInt (fromIntegral $ B.length b)
        <> putBytes b
buildWireValue StartGroup = mempty
buildWireValue EndGroup = mempty

-- | The numeric wire type of a value, as stored in the low three bits of a
-- field key. Must stay in sync with the decoding in
-- 'parseTaggedValueFromWire'.
wireValueToInt :: WireValue -> Word8
wireValueToInt VarInt{} = 0
wireValueToInt Fixed64{} = 1
wireValueToInt Lengthy{} = 2
wireValueToInt StartGroup{} = 3
wireValueToInt EndGroup{} = 4
wireValueToInt Fixed32{} = 5

-- | Parses one field: a varint field key followed by its payload.
parseTaggedValue :: Parser TaggedValue
parseTaggedValue = getVarInt >>= parseTaggedValueFromWire

-- | Parses the payload of a field whose key @t@ has already been read.
--
-- The key's low three bits select the wire type (see 'splitTypeAndTag').
-- Wire types 6 and 7 are unknown and fail the parse, as does a
-- length-delimited payload whose declared length cannot fit in an 'Int'.
parseTaggedValueFromWire :: Word64 -> Parser TaggedValue
parseTaggedValueFromWire t =
    let (tag, w) = splitTypeAndTag t
    in TaggedValue tag <$> case w of
        0 -> VarInt <$> getVarInt
        1 -> Fixed64 <$> getFixed64
        2 -> Lengthy <$> do
                len <- getVarInt
                -- Converting a length above maxBound :: Int with
                -- fromIntegral would wrap it into a negative (or otherwise
                -- bogus) byte count. No real input can be that long, so
                -- reject it instead of handing a wrapped value to getBytes.
                if len > fromIntegral (maxBound :: Int)
                    then fail $ "Length too large: " ++ show len
                    else getBytes $ fromIntegral len
        3 -> return StartGroup
        4 -> return EndGroup
        5 -> Fixed32 <$> getFixed32
        _ -> fail $ "Unknown wire type " ++ show w

-- | Like 'parseTaggedValueFromWire', but also recognizes items in the legacy
-- MessageSet encoding (as produced by 'buildTaggedValueAsMessageSet').
--
-- If the key @t@ opens group 1, the group must contain, in this order:
-- the field number as a varint (field 2), the payload as a length-delimited
-- value (field 3), and the closing group-1 marker. The result is that
-- payload tagged with the decoded field number. Any other key is parsed as
-- an ordinary field.
parseMessageSetTaggedValueFromWire :: Word64 -> Parser TaggedValue
parseMessageSetTaggedValueFromWire t = do
    v <- parseTaggedValueFromWire t
    case v of
        TaggedValue 1 StartGroup -> parseMessageSetItem
        _ -> return v
  where
    -- Parses the rest of a MessageSet item after its opening group marker.
    parseMessageSetItem :: Parser TaggedValue
    parseMessageSetItem = do
        fieldNum <- parseTaggedValue >>= \ft -> case ft of
            TaggedValue 2 (VarInt f) -> return f
            _ -> fail "missing field tag"
        payload <- parseTaggedValue >>= \dt -> case dt of
            TaggedValue 3 (Lengthy b) -> return b
            _ -> fail "missing message"
        parseTaggedValue >>= \et -> case et of
            TaggedValue 1 EndGroup ->
                return $ TaggedValue (Tag $ fromIntegral fieldNum) (Lengthy payload)
            _ -> fail "missing end_group"

-- | Splits a field key into its field number (all bits above the low three)
-- and its wire type (the low three bits, so always 0-7).
splitTypeAndTag :: Word64 -> (Tag, Word8)
splitTypeAndTag w = (fromIntegral $ w `shiftR` 3, fromIntegral (w .&. 7))

-- | Packs a field number and a wire type into a field key: the field number
-- is shifted above the low three bits, which hold the wire type.
--
-- For wire types 0-7 and non-negative field numbers this is the inverse of
-- 'splitTypeAndTag'. The wire type is not masked, so a value >= 8 would
-- overlap the field-number bits.
joinTypeAndTag :: Tag -> Word8 -> Word64
joinTypeAndTag (Tag t) w = fromIntegral t `shiftL` 3 .|. fromIntegral w

-- | Parses fields until the end of the input and returns them in the order
-- they appeared.
parseFieldSet :: Parser FieldSet
parseFieldSet = loop []
  where
    -- The accumulator holds the fields parsed so far, newest first; it is
    -- reversed once when the input is exhausted.
    loop :: FieldSet -> Parser FieldSet
    loop ws = do
        end <- atEnd
        if end
            then return $! reverse ws
            else do
                -- Force each value as it is parsed so the accumulator does
                -- not build up thunks.
                !w <- parseTaggedValue
                loop (w : ws)

-- | Encodes every field of the set, in order, with 'buildTaggedValue'.
buildFieldSet :: FieldSet -> Builder
buildFieldSet = foldMap buildTaggedValue

-- | Encodes every field of the set, in order, as a legacy MessageSet item
-- (see 'buildTaggedValueAsMessageSet').
buildMessageSet :: FieldSet -> Builder
buildMessageSet = foldMap buildTaggedValueAsMessageSet