Add ability for custom a MessagePrinter to default back to the default implementation.
Prior to the CL we cannot use the printer as that would trigger a recursion. So the only way out would be to create a new printer that has the same abilities as the current printer.
PiperOrigin-RevId: 489024852
diff --git a/src/google/protobuf/text_format.cc b/src/google/protobuf/text_format.cc
index 7ab3d48..b8ddd08 100644
--- a/src/google/protobuf/text_format.cc
+++ b/src/google/protobuf/text_format.cc
@@ -1470,15 +1470,15 @@
// error.)
bool failed() const { return failed_; }
- void PrintMaybeWithMarker(absl::string_view text) {
+ void PrintMaybeWithMarker(MarkerToken, absl::string_view text) override {
Print(text.data(), text.size());
if (ConsumeInsertSilentMarker()) {
PrintLiteral(internal::kDebugStringSilentMarker);
}
}
- void PrintMaybeWithMarker(absl::string_view text_head,
- absl::string_view text_tail) {
+ void PrintMaybeWithMarker(MarkerToken, absl::string_view text_head,
+ absl::string_view text_tail) override {
Print(text_head.data(), text_head.size());
if (ConsumeInsertSilentMarker()) {
PrintLiteral(internal::kDebugStringSilentMarker);
@@ -1577,13 +1577,10 @@
void PrintMessageStart(const Message& /*message*/, int /*field_index*/,
int /*field_count*/, bool single_line_mode,
BaseTextGenerator* generator) const override {
- // This is safe as only TextGenerator is used with
- // DebugStringFieldValuePrinter.
- TextGenerator* text_generator = static_cast<TextGenerator*>(generator);
if (single_line_mode) {
- text_generator->PrintMaybeWithMarker(" ", "{ ");
+ generator->PrintMaybeWithMarker(MarkerToken(), " ", "{ ");
} else {
- text_generator->PrintMaybeWithMarker(" ", "{\n");
+ generator->PrintMaybeWithMarker(MarkerToken(), " ", "{\n");
}
}
};
@@ -2174,7 +2171,7 @@
} // namespace
bool TextFormat::Printer::PrintAny(const Message& message,
- TextGenerator* generator) const {
+ BaseTextGenerator* generator) const {
const FieldDescriptor* type_url_field;
const FieldDescriptor* value_field;
if (!internal::GetAnyFieldDescriptors(message, &type_url_field,
@@ -2222,7 +2219,7 @@
}
void TextFormat::Printer::Print(const Message& message,
- TextGenerator* generator) const {
+ BaseTextGenerator* generator) const {
const Reflection* reflection = message.GetReflection();
if (!reflection) {
// This message does not provide any way to describe its structure.
@@ -2242,10 +2239,20 @@
itr->second->Print(message, single_line_mode_, generator);
return;
}
+ PrintMessage(message, generator);
+}
+
+void TextFormat::Printer::PrintMessage(const Message& message,
+ BaseTextGenerator* generator) const {
+ if (generator == nullptr) {
+ return;
+ }
+ const Descriptor* descriptor = message.GetDescriptor();
if (descriptor->full_name() == internal::kAnyFullTypeName && expand_any_ &&
PrintAny(message, generator)) {
return;
}
+ const Reflection* reflection = message.GetReflection();
std::vector<const FieldDescriptor*> fields;
if (descriptor->options().map_entry()) {
fields.push_back(descriptor->field(0));
@@ -2460,7 +2467,7 @@
void TextFormat::Printer::PrintField(const Message& message,
const Reflection* reflection,
const FieldDescriptor* field,
- TextGenerator* generator) const {
+ BaseTextGenerator* generator) const {
if (use_short_repeated_primitives_ && field->is_repeated() &&
field->cpp_type() != FieldDescriptor::CPPTYPE_STRING &&
field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
@@ -2508,7 +2515,7 @@
printer->PrintMessageEnd(sub_message, field_index, count,
single_line_mode_, generator);
} else {
- generator->PrintMaybeWithMarker(": ");
+ generator->PrintMaybeWithMarker(MarkerToken(), ": ");
// Write the field value.
PrintFieldValue(message, reflection, field, field_index, generator);
if (single_line_mode_) {
@@ -2528,12 +2535,12 @@
void TextFormat::Printer::PrintShortRepeatedField(
const Message& message, const Reflection* reflection,
- const FieldDescriptor* field, TextGenerator* generator) const {
+ const FieldDescriptor* field, BaseTextGenerator* generator) const {
// Print primitive repeated field in short form.
int size = reflection->FieldSize(message, field);
PrintFieldName(message, /*field_index=*/-1, /*field_count=*/size, reflection,
field, generator);
- generator->PrintMaybeWithMarker(": ", "[");
+ generator->PrintMaybeWithMarker(MarkerToken(), ": ", "[");
for (int i = 0; i < size; i++) {
if (i > 0) generator->PrintLiteral(", ");
PrintFieldValue(message, reflection, field, i, generator);
@@ -2549,7 +2556,7 @@
int field_index, int field_count,
const Reflection* reflection,
const FieldDescriptor* field,
- TextGenerator* generator) const {
+ BaseTextGenerator* generator) const {
// if use_field_number_ is true, prints field number instead
// of field name.
if (use_field_number_) {
@@ -2566,7 +2573,7 @@
const Reflection* reflection,
const FieldDescriptor* field,
int index,
- TextGenerator* generator) const {
+ BaseTextGenerator* generator) const {
GOOGLE_DCHECK(field->is_repeated() || (index == -1))
<< "Index must be -1 for non-repeated fields";
@@ -2687,7 +2694,7 @@
}
void TextFormat::Printer::PrintUnknownFields(
- const UnknownFieldSet& unknown_fields, TextGenerator* generator,
+ const UnknownFieldSet& unknown_fields, BaseTextGenerator* generator,
int recursion_budget) const {
for (int i = 0; i < unknown_fields.field_count(); i++) {
const UnknownField& field = unknown_fields.field(i);
@@ -2696,7 +2703,7 @@
switch (field.type()) {
case UnknownField::TYPE_VARINT:
generator->PrintString(field_number);
- generator->PrintMaybeWithMarker(": ");
+ generator->PrintMaybeWithMarker(MarkerToken(), ": ");
generator->PrintString(absl::StrCat(field.varint()));
if (single_line_mode_) {
generator->PrintLiteral(" ");
@@ -2706,7 +2713,7 @@
break;
case UnknownField::TYPE_FIXED32: {
generator->PrintString(field_number);
- generator->PrintMaybeWithMarker(": ", "0x");
+ generator->PrintMaybeWithMarker(MarkerToken(), ": ", "0x");
generator->PrintString(
absl::StrCat(absl::Hex(field.fixed32(), absl::kZeroPad8)));
if (single_line_mode_) {
@@ -2718,7 +2725,7 @@
}
case UnknownField::TYPE_FIXED64: {
generator->PrintString(field_number);
- generator->PrintMaybeWithMarker(": ", "0x");
+ generator->PrintMaybeWithMarker(MarkerToken(), ": ", "0x");
generator->PrintString(
absl::StrCat(absl::Hex(field.fixed64(), absl::kZeroPad16)));
if (single_line_mode_) {
@@ -2743,9 +2750,9 @@
// This field is parseable as a Message.
// So it is probably an embedded message.
if (single_line_mode_) {
- generator->PrintMaybeWithMarker(" ", "{ ");
+ generator->PrintMaybeWithMarker(MarkerToken(), " ", "{ ");
} else {
- generator->PrintMaybeWithMarker(" ", "{\n");
+ generator->PrintMaybeWithMarker(MarkerToken(), " ", "{\n");
generator->Indent();
}
PrintUnknownFields(embedded_unknown_fields, generator,
@@ -2759,7 +2766,7 @@
} else {
// This field is not parseable as a Message (or we ran out of
// recursion budget). So it is probably just a plain string.
- generator->PrintMaybeWithMarker(": ", "\"");
+ generator->PrintMaybeWithMarker(MarkerToken(), ": ", "\"");
generator->PrintString(absl::CEscape(value));
if (single_line_mode_) {
generator->PrintLiteral("\" ");
@@ -2772,9 +2779,9 @@
case UnknownField::TYPE_GROUP:
generator->PrintString(field_number);
if (single_line_mode_) {
- generator->PrintMaybeWithMarker(" ", "{ ");
+ generator->PrintMaybeWithMarker(MarkerToken(), " ", "{ ");
} else {
- generator->PrintMaybeWithMarker(" ", "{\n");
+ generator->PrintMaybeWithMarker(MarkerToken(), " ", "{\n");
generator->Indent();
}
// For groups, we recurse without checking the budget. This is OK,
diff --git a/src/google/protobuf/text_format.h b/src/google/protobuf/text_format.h
index fbcbea4..2cd4024 100644
--- a/src/google/protobuf/text_format.h
+++ b/src/google/protobuf/text_format.h
@@ -115,7 +115,22 @@
const FieldDescriptor* field, int index,
std::string* output);
+ // Forward declare `Printer` for `BaseTextGenerator::MarkerToken` which
+ // restricts some methods of `BaseTextGenerator` to the class `Printer`.
+ class Printer;
+
class PROTOBUF_EXPORT BaseTextGenerator {
+ private:
+ // Passkey (go/totw/134#what-about-stdshared-ptr) that allows `Printer`
+ // (but not derived classes) to call `PrintMaybeWithMarker` and its
+ // `Printer::TextGenerator` to overload it.
+ // This prevents users from bypassing the marker generation.
+ class MarkerToken {
+ private:
+ explicit MarkerToken() = default; // 'explicit' prevents aggregate init.
+ friend class Printer;
+ };
+
public:
virtual ~BaseTextGenerator();
@@ -133,6 +148,20 @@
void PrintLiteral(const char (&text)[n]) {
Print(text, n - 1); // n includes the terminating zero character.
}
+
+ // Internal to Printer, access regulated by `MarkerToken`.
+ virtual void PrintMaybeWithMarker(MarkerToken, absl::string_view text) {
+ Print(text.data(), text.size());
+ }
+
+ // Internal to Printer, access regulated by `MarkerToken`.
+ virtual void PrintMaybeWithMarker(MarkerToken, absl::string_view text_head,
+ absl::string_view text_tail) {
+ Print(text_head.data(), text_head.size());
+ Print(text_tail.data(), text_tail.size());
+ }
+
+ friend class Printer;
};
// The default printer that converts scalar values from fields into their
@@ -381,6 +410,14 @@
bool RegisterMessagePrinter(const Descriptor* descriptor,
const MessagePrinter* printer);
+ // Default printing for messages, which allows registered message printers
+ // to fall back to default printing without losing the ability to control
+ // sub-messages or fields.
+ // NOTE: If the passed in `text_generaor` is not actually the current
+ // `TextGenerator`, then no output will be produced.
+ void PrintMessage(const Message& message,
+ BaseTextGenerator* generator) const;
+
private:
friend std::string Message::DebugString() const;
friend std::string Message::ShortDebugString() const;
@@ -404,6 +441,7 @@
// Forward declaration of an internal class used to print the text
// output to the OutputStream (see text_format.cc for implementation).
class TextGenerator;
+ using MarkerToken = BaseTextGenerator::MarkerToken;
// Forward declaration of an internal class used to print field values for
// DebugString APIs (see text_format.cc for implementation).
@@ -417,40 +455,40 @@
// Internal Print method, used for writing to the OutputStream via
// the TextGenerator class.
- void Print(const Message& message, TextGenerator* generator) const;
+ void Print(const Message& message, BaseTextGenerator* generator) const;
// Print a single field.
void PrintField(const Message& message, const Reflection* reflection,
const FieldDescriptor* field,
- TextGenerator* generator) const;
+ BaseTextGenerator* generator) const;
// Print a repeated primitive field in short form.
void PrintShortRepeatedField(const Message& message,
const Reflection* reflection,
const FieldDescriptor* field,
- TextGenerator* generator) const;
+ BaseTextGenerator* generator) const;
// Print the name of a field -- i.e. everything that comes before the
// ':' for a single name/value pair.
void PrintFieldName(const Message& message, int field_index,
int field_count, const Reflection* reflection,
const FieldDescriptor* field,
- TextGenerator* generator) const;
+ BaseTextGenerator* generator) const;
// Outputs a textual representation of the value of the field supplied on
// the message supplied or the default value if not set.
void PrintFieldValue(const Message& message, const Reflection* reflection,
const FieldDescriptor* field, int index,
- TextGenerator* generator) const;
+ BaseTextGenerator* generator) const;
// Print the fields in an UnknownFieldSet. They are printed by tag number
// only. Embedded messages are heuristically identified by attempting to
// parse them (subject to the recursion budget).
void PrintUnknownFields(const UnknownFieldSet& unknown_fields,
- TextGenerator* generator,
+ BaseTextGenerator* generator,
int recursion_budget) const;
- bool PrintAny(const Message& message, TextGenerator* generator) const;
+ bool PrintAny(const Message& message, BaseTextGenerator* generator) const;
const FastFieldValuePrinter* GetFieldPrinter(
const FieldDescriptor* field) const {
diff --git a/src/google/protobuf/text_format_unittest.cc b/src/google/protobuf/text_format_unittest.cc
index 6dcc044..1bca2df 100644
--- a/src/google/protobuf/text_format_unittest.cc
+++ b/src/google/protobuf/text_format_unittest.cc
@@ -780,15 +780,24 @@
~CustomNestedMessagePrinter() override {}
void Print(const Message& message, bool single_line_mode,
TextFormat::BaseTextGenerator* generator) const override {
- generator->PrintLiteral("custom");
+ generator->PrintLiteral("// custom\n");
+ if (printer_ != nullptr) {
+ printer_->PrintMessage(message, generator);
+ }
}
+
+ void SetPrinter(TextFormat::Printer* printer) { printer_ = printer; }
+
+ private:
+ TextFormat::Printer* printer_ = nullptr;
};
TEST_F(TextFormatTest, CustomMessagePrinter) {
TextFormat::Printer printer;
+ auto* custom_printer = new CustomNestedMessagePrinter;
printer.RegisterMessagePrinter(
unittest::TestAllTypes::NestedMessage::default_instance().descriptor(),
- new CustomNestedMessagePrinter);
+ custom_printer);
unittest::TestAllTypes message;
std::string text;
@@ -797,7 +806,11 @@
message.mutable_optional_nested_message()->set_bb(1);
EXPECT_TRUE(printer.PrintToString(message, &text));
- EXPECT_EQ("optional_nested_message {\n custom}\n", text);
+ EXPECT_EQ("optional_nested_message {\n // custom\n}\n", text);
+
+ custom_printer->SetPrinter(&printer);
+ EXPECT_TRUE(printer.PrintToString(message, &text));
+ EXPECT_EQ("optional_nested_message {\n // custom\n bb: 1\n}\n", text);
}
TEST_F(TextFormatTest, ParseBasic) {