{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Data.ProtoLens.TextFormat(
showMessage,
showMessageWithRegistry,
showMessageShort,
pprintMessage,
pprintMessageWithRegistry,
readMessage,
readMessageWithRegistry,
readMessageOrDie,
) where
import Lens.Family2 ((&),(^.),(.~), set, over, view)
import Control.Arrow (left)
import Data.Bifunctor (first)
import qualified Data.ByteString
import Data.Char (isPrint, isAscii, chr)
import Data.Foldable (foldlM, foldl')
import qualified Data.Map as Map
import Data.Maybe (catMaybes)
import Data.Proxy (Proxy(Proxy))
import Data.ProtoLens.Encoding (encodeMessage, decodeMessage)
import qualified Data.Set as Set
import qualified Data.Text.Encoding as Text
import qualified Data.Text.Lazy as Lazy
import qualified Data.Text as Text (unpack)
import Numeric (showOct)
import Text.Parsec (parse)
import Text.PrettyPrint
#if MIN_VERSION_base(4,11,0)
import Prelude hiding ((<>))
#endif
import Data.ProtoLens.Encoding.Wire
import Data.ProtoLens.Message
import qualified Data.ProtoLens.TextFormat.Parser as Parser
-- | Pretty-print a message as a 'Doc', using an empty ('mempty') 'Registry'.
--
-- This is 'pprintMessageWithRegistry' applied to 'mempty'.
pprintMessage :: Message msg => msg -> Doc
pprintMessage msg = pprintMessageWithRegistry mempty msg
-- | Pretty-print a message as a 'Doc'.  The 'Registry' is handed down to the
-- per-field printers.
--
-- The docs for the known fields (one 'pprintField' call per entry of
-- 'allFields') come first, followed by one doc per unknown field.  All of
-- them are combined by a single 'sep', so the layout choice (horizontal or
-- vertical) is made once for the whole message.
pprintMessageWithRegistry :: Message msg => Registry -> msg -> Doc
pprintMessageWithRegistry reg msg = sep (knownDocs ++ unknownDocs)
  where
    knownDocs = concatMap (pprintField reg msg) allFields
    unknownDocs = map pprintTaggedValue (msg ^. unknownFields)
-- | Render a message to a 'String' by pretty-printing it with
-- 'pprintMessage' and rendering the result with the default 'render' style.
showMessage :: Message msg => msg -> String
showMessage msg = render (pprintMessage msg)
-- | Like 'showMessage', but pretty-prints with
-- 'pprintMessageWithRegistry' and the supplied 'Registry'.
showMessageWithRegistry :: Message msg => Registry -> msg -> String
showMessageWithRegistry reg msg = render (pprintMessageWithRegistry reg msg)
-- | Render a message to a 'String' using 'OneLineMode', i.e. without the
-- multi-line layout that 'showMessage' may produce.
showMessageShort :: Message msg => msg -> String
showMessageShort msg = renderStyle oneLine (pprintMessage msg)
  where
    -- Line length and ribbon ratio are passed through unchanged; the mode is
    -- what selects the one-line rendering.
    oneLine = Style OneLineMode maxBound 1.5
-- | Produce one 'Doc' per value stored in the given field of the message.
--
-- * A plain field whose descriptor is 'Optional' and whose value equals
--   'fieldDefault' yields nothing; any other plain field yields its value.
-- * An optional field yields its value only when it is 'Just'.
-- * A repeated field yields every element.
-- * A map field yields one entry message per key/value pair, built from
--   'def' by setting the key and value lenses.
pprintField :: Registry -> msg -> FieldDescriptor msg -> [Doc]
pprintField reg msg (FieldDescriptor name typeDescr accessor) =
    map (pprintFieldValue reg name typeDescr) fieldValues
  where
    fieldValues = case accessor of
        PlainField presence l
            | Optional <- presence, current == fieldDefault -> []
            | otherwise -> [current]
          where
            current = view l msg
        OptionalField l -> maybe [] (: []) (view l msg)
        RepeatedField _ l -> view l msg
        MapField keyLens valueLens l ->
            [ def & keyLens .~ k & valueLens .~ v
            | (k, v) <- Map.assocs (view l msg)
            ]
-- | Pretty-print a single value of the field called @name@.
--
-- For a 'MessageType' field that 'matchAnyMessage' recognises, whose type
-- URL is found in the registry and whose payload decodes successfully, the
-- decoded message is printed inside a @[type-url] { ... }@ wrapper.  If any
-- of those steps fails, the value is printed as an ordinary submessage.
-- Group fields are printed as submessages and scalar fields as
-- @name: value@.
pprintFieldValue :: Registry -> String -> FieldTypeDescriptor value -> value -> Doc
pprintFieldValue reg name field@(MessageField MessageType) m
    | Just AnyMessageDescriptor { anyTypeUrlLens, anyValueLens } <- matchAnyMessage field
    , let typeUri = m ^. anyTypeUrlLens
    , Just (SomeMessageType (Proxy :: Proxy value')) <- lookupRegistered typeUri reg
    , Right (anyValue :: value') <- decodeMessage (m ^. anyValueLens)
    = pprintSubmessage name $ sep
        [ brackets (text (Text.unpack typeUri)) <+> lbrace
        , nest 2 (pprintMessageWithRegistry reg anyValue)
        , rbrace
        ]
    | otherwise
    = pprintSubmessage name (pprintMessageWithRegistry reg m)
pprintFieldValue reg name (MessageField GroupType) m =
    pprintSubmessage name (pprintMessageWithRegistry reg m)
pprintFieldValue _ name (ScalarField f) x =
    named name (pprintScalarValue f x)
-- | Prefix a document with a label and a colon: @label: body@.
named :: String -> Doc -> Doc
named label body = hcat [text label, colon] <+> body
-- | Pretty-print a scalar value according to its field type.
--
-- Enums are printed with 'showEnum', all numeric types with 'primField'
-- (that is, via 'show'), booleans with 'boolValue', and both strings
-- (after UTF-8 encoding) and bytes with 'pprintByteString'.
pprintScalarValue :: ScalarField value -> value -> Doc
pprintScalarValue field x = case field of
    EnumField     -> text (showEnum x)
    Int32Field    -> primField x
    Int64Field    -> primField x
    UInt32Field   -> primField x
    UInt64Field   -> primField x
    SInt32Field   -> primField x
    SInt64Field   -> primField x
    Fixed32Field  -> primField x
    Fixed64Field  -> primField x
    SFixed32Field -> primField x
    SFixed64Field -> primField x
    FloatField    -> primField x
    DoubleField   -> primField x
    BoolField     -> boolValue x
    StringField   -> pprintByteString (Text.encodeUtf8 x)
    BytesField    -> pprintByteString x
-- | Wrap already-printed contents as a named submessage:
-- @name { contents }@, with the contents nested by two columns when the
-- layout is vertical.
pprintSubmessage :: String -> Doc -> Doc
pprintSubmessage name contents = sep [opening, body, rbrace]
  where
    opening = text name <+> lbrace
    body = nest 2 contents
-- | Print a byte string as a double-quoted, escaped literal.
--
-- Newline, carriage return, tab, both quote characters and the backslash
-- get their two-character escapes.  Any other printable ASCII byte is
-- emitted as-is, and every remaining byte becomes a backslash followed by
-- exactly three octal digits.
pprintByteString :: Data.ByteString.ByteString -> Doc
pprintByteString bytes =
    doubleQuotes . text $ concatMap escapeByte (Data.ByteString.unpack bytes)
  where
    escapeByte byte = case chr (fromIntegral byte) of
        '\n' -> "\\n"
        '\r' -> "\\r"
        '\t' -> "\\t"
        '\"' -> "\\\""
        '\'' -> "\\\'"
        '\\' -> "\\\\"
        c | isPrint c && isAscii c -> [c]
          | otherwise -> '\\' : octal3 byte

    -- Octal rendering, left-padded with zeros to three digits.
    octal3 byte =
        let digits = showOct byte ""
        in replicate (3 - length digits) '0' ++ digits
-- | Print a value using its 'Show' instance.
primField :: Show value => value -> Doc
primField = text . show
-- | Print a boolean as the lowercase literal @true@ or @false@.
boolValue :: Bool -> Doc
boolValue b = text (if b then "true" else "false")
-- | Pretty-print an unknown field, using its numeric tag as the field name.
--
-- Varint and fixed-width values are printed via 'primField'.  A
-- length-delimited value is printed as a submessage when 'decodeFieldSet'
-- succeeds on it, and as an escaped byte string otherwise.  Group markers
-- are printed as the literals @start_group@ and @end_group@.
pprintTaggedValue :: TaggedValue -> Doc
pprintTaggedValue (TaggedValue tag (WireValue wireType payload)) =
    case wireType of
        VarInt -> labelled (primField payload)
        Fixed64 -> labelled (primField payload)
        Fixed32 -> labelled (primField payload)
        Lengthy ->
            either
                (const (labelled (pprintByteString payload)))
                (pprintSubmessage label . sep . map pprintTaggedValue)
                (decodeFieldSet payload)
        StartGroup -> labelled (text "start_group")
        EndGroup -> labelled (text "end_group")
  where
    label = show (unTag tag)
    labelled = named label
-- | Parse a message from lazy text, using an empty ('mempty') 'Registry'.
-- Returns 'Left' with an error description on failure.
readMessage :: Message msg => Lazy.Text -> Either String msg
readMessage str = readMessageWithRegistry mempty str
-- | Like 'readMessage', but calls 'error' (with the message prefixed by
-- @readMessageOrDie: @) when the input cannot be read.
readMessageOrDie :: Message msg => Lazy.Text -> msg
readMessageOrDie str = either die id (readMessage str)
  where
    die reason = error ("readMessageOrDie: " ++ reason)
-- | Parse a message from lazy text.
--
-- The text is first parsed by 'Parser.parser' (parse errors are converted
-- with 'show'), and the parsed fields are then turned into a message by
-- 'buildMessage' using the supplied 'Registry'.
readMessageWithRegistry :: Message msg => Registry -> Lazy.Text -> Either String msg
readMessageWithRegistry reg str = do
    parsed <- left show (parse Parser.parser "" str)
    buildMessage reg parsed
-- | Build a message from parsed fields.
--
-- Fails with @Missing fields ...@ when 'missingFields' reports any names.
-- Otherwise the fields are folded into 'def' by
-- 'buildMessageFromDescriptor' and the result is passed through
-- 'reverseRepeatedFields'.
buildMessage :: forall msg . Message msg => Registry -> Parser.Message -> Either String msg
buildMessage reg fields = case missingFields (Proxy @msg) fields of
    [] ->
        -- 'modifyField' prepends each repeated element, so the accumulated
        -- lists are in reverse input order at this point.
        reverseRepeatedFields fieldsByTag
            <$> buildMessageFromDescriptor reg def fields
    missing -> Left ("Missing fields " ++ show missing)
-- | List the required fields of @msg@ that do not occur among the parsed
-- fields.
--
-- The starting set is every key of 'fieldsByTextFormatName' whose
-- descriptor satisfies 'isRequired'.  A parsed field removes its own name
-- from the set; a field given by number removes the name of the descriptor
-- with that tag, when there is one.  Whatever remains is returned in
-- ascending order.
missingFields :: forall msg . Message msg => Proxy msg -> Parser.Message -> [String]
missingFields _ parsed = Set.toList (foldl' markPresent required parsed)
  where
    required :: Set.Set String
    required = Map.keysSet (Map.filter isRequired (fieldsByTextFormatName @msg))

    markPresent :: Set.Set String -> Parser.Field -> Set.Set String
    markPresent remaining (Parser.Field (Parser.Key name) _) =
        Set.delete name remaining
    markPresent remaining (Parser.Field (Parser.UnknownKey n) _)
        | Just d <- Map.lookup (Tag (fromIntegral n)) (fieldsByTag @msg)
        = Set.delete (fieldDescriptorName d) remaining
    -- Any other key, or a numeric key with no matching descriptor, leaves
    -- the set untouched.
    markPresent remaining _ = remaining
-- | Apply every parsed field to the starting message, left to right, with
-- 'addField'.  The first field that fails aborts the fold with its error.
buildMessageFromDescriptor
    :: Message msg => Registry -> msg -> Parser.Message -> Either String msg
buildMessageFromDescriptor reg initial fields =
    foldlM (addField reg) initial fields
-- | Apply one parsed field to a message.
--
-- The field's descriptor is looked up by text-format name for a named key
-- and by tag for a numeric key.  If no descriptor is found the result is
-- @Left "Unrecognized field ..."@.  Otherwise the raw value is converted
-- with 'makeValue' and stored with 'modifyField'.
addField :: forall msg . Message msg => Registry -> msg -> Parser.Field -> Either String msg
addField reg msg (Parser.Field key rawValue) = do
    FieldDescriptor name typeDescriptor accessor <- descriptor
    parsed <- makeValue name reg typeDescriptor rawValue
    pure (modifyField accessor parsed msg)
  where
    descriptor :: Either String (FieldDescriptor msg)
    descriptor =
        maybe (Left ("Unrecognized field " ++ show key)) Right candidate

    candidate :: Maybe (FieldDescriptor msg)
    candidate
        | Parser.Key name <- key = Map.lookup name fieldsByTextFormatName
        | Parser.UnknownKey tag <- key = Map.lookup (fromIntegral tag) fieldsByTag
        | otherwise = Nothing
-- | Store a value into a message through the given accessor.
--
-- Plain and optional fields are overwritten.  A repeated field has the
-- value prepended to its list.  For a map field the value is an entry
-- message, whose key and value are read out and inserted into the map.
modifyField :: FieldAccessor msg value -> value -> msg -> msg
modifyField accessor newValue = case accessor of
    PlainField _ l -> set l newValue
    OptionalField l -> set l (Just newValue)
    RepeatedField _ l -> over l (newValue :)
    MapField keyLens valueLens l ->
        over l (Map.insert (view keyLens newValue) (view valueLens newValue))
-- | Convert a parsed value into a value of the field's type.  The field
-- name is used only in error messages.
--
-- * Scalar fields go through 'makeScalarValue'; its errors are prefixed
--   with @Error parsing field ...@.
-- * A message value with an explicit type URI is accepted only when the
--   field is recognised by 'matchAnyMessage'.  The URI must be in the
--   registry; the nested fields are built as that registered type and
--   stored, encoded, together with the URI.  For any other message field
--   an explicit type URI is a type mismatch.
-- * Any other message value is built with 'buildMessage'.
-- * A non-message value for a message field is a type mismatch.
makeValue
    :: forall value
     . String
    -> Registry
    -> FieldTypeDescriptor value
    -> Parser.Value
    -> Either String value
makeValue name _ (ScalarField f) v =
    first (prefix ++) (makeScalarValue f v)
  where
    prefix = "Error parsing field " ++ show name ++ ": "
makeValue name reg field@(MessageField MessageType) (Parser.MessageValue (Just typeUri) x)
    | Just AnyMessageDescriptor { anyTypeUrlLens, anyValueLens } <- matchAnyMessage field =
        case lookupRegistered typeUri reg of
            Just (SomeMessageType (Proxy :: Proxy value')) -> do
                inner <- (buildMessage reg x :: Either String value')
                return $ def & anyTypeUrlLens .~ typeUri
                             & anyValueLens .~ encodeMessage inner
            Nothing -> Left $ concat
                [ "Could not decode google.protobuf.Any for field "
                , show name
                , ": unregistered type URI "
                , show typeUri
                ]
    | otherwise = Left $ concat
        [ "Type mismatch parsing explicitly typed message. Expected "
        , show (messageName (Proxy @value))
        , ", got "
        , show typeUri
        ]
makeValue _ reg (MessageField _) (Parser.MessageValue _ x) = buildMessage reg x
makeValue name _ (MessageField _) val = Left $ concat
    [ "Type mismatch for field "
    , show name
    , ": expected message, found "
    , show val
    ]
-- | Convert a parsed scalar into a value of the given scalar field type.
--
-- Accepted combinations:
--
-- * every integer, float and double field accepts an integer literal
--   (converted with 'fromInteger');
-- * float and double fields also accept a floating-point literal;
-- * bool fields accept the integers @0@ and @1@ and the identifiers
--   @true@ and @false@;
-- * string and bytes fields accept a byte-string literal; a string field
--   additionally requires the bytes to be valid UTF-8;
-- * enum fields accept an integer (via 'maybeToEnum') or an identifier
--   (via 'readEnum').
--
-- Everything else yields @Left "Type mismatch: ..."@.
--
-- NOTE(review): 'fromInteger' performs no range check, so an out-of-range
-- literal for a fixed-width field presumably wraps silently -- confirm
-- whether that is intended.
makeScalarValue :: ScalarField value -> Parser.Value -> Either String value
makeScalarValue Int32Field (Parser.IntValue x) = Right (fromInteger x)
makeScalarValue Int64Field (Parser.IntValue x) = Right (fromInteger x)
makeScalarValue UInt32Field (Parser.IntValue x) = Right (fromInteger x)
makeScalarValue UInt64Field (Parser.IntValue x) = Right (fromInteger x)
makeScalarValue SInt32Field (Parser.IntValue x) = Right (fromInteger x)
makeScalarValue SInt64Field (Parser.IntValue x) = Right (fromInteger x)
makeScalarValue Fixed32Field (Parser.IntValue x) = Right (fromInteger x)
makeScalarValue Fixed64Field (Parser.IntValue x) = Right (fromInteger x)
makeScalarValue SFixed32Field (Parser.IntValue x) = Right (fromInteger x)
makeScalarValue SFixed64Field (Parser.IntValue x) = Right (fromInteger x)
makeScalarValue FloatField (Parser.IntValue x) = Right (fromInteger x)
makeScalarValue DoubleField (Parser.IntValue x) = Right (fromInteger x)
makeScalarValue BoolField (Parser.IntValue x)
    | x == 0 = Right False
    | x == 1 = Right True
    | otherwise = Left $ "Unrecognized bool value " ++ show x
makeScalarValue DoubleField (Parser.DoubleValue x) = Right x
makeScalarValue FloatField (Parser.DoubleValue x) = Right (realToFrac x)
makeScalarValue BoolField (Parser.EnumValue x)
    | x == "true" = Right True
    | x == "false" = Right False
    | otherwise = Left $ "Unrecognized bool value " ++ show x
-- Use the total decoder: 'Text.decodeUtf8' would throw an impure exception
-- on malformed input (e.g. a literal containing the escape @\377@), which
-- callers of this 'Either'-returning API cannot handle as a 'Left'.
makeScalarValue StringField (Parser.ByteStringValue x) =
    first (\err -> "Invalid UTF-8 in string value: " ++ show err)
          (Text.decodeUtf8' x)
makeScalarValue BytesField (Parser.ByteStringValue x) = Right x
makeScalarValue EnumField (Parser.IntValue x) =
    maybe (Left $ "Unrecognized enum value " ++ show x) Right
          (maybeToEnum $ fromInteger x)
makeScalarValue EnumField (Parser.EnumValue x) =
    maybe (Left $ "Unrecognized enum value " ++ show x) Right
          (readEnum x)
makeScalarValue f val = Left $ "Type mismatch: " ++ show (f, val)