Add ability for custom a MessagePrinter to default back to the default implementation.

Prior to the CL we cannot use the printer as that would trigger a recursion. So the only way out would be to create a new printer that has the same abilities as the current printer.

PiperOrigin-RevId: 489024852
diff --git a/src/google/protobuf/text_format.cc b/src/google/protobuf/text_format.cc
index 7ab3d48..b8ddd08 100644
--- a/src/google/protobuf/text_format.cc
+++ b/src/google/protobuf/text_format.cc
@@ -1470,15 +1470,15 @@
   // error.)
   bool failed() const { return failed_; }
 
-  void PrintMaybeWithMarker(absl::string_view text) {
+  void PrintMaybeWithMarker(MarkerToken, absl::string_view text) override {
     Print(text.data(), text.size());
     if (ConsumeInsertSilentMarker()) {
       PrintLiteral(internal::kDebugStringSilentMarker);
     }
   }
 
-  void PrintMaybeWithMarker(absl::string_view text_head,
-                            absl::string_view text_tail) {
+  void PrintMaybeWithMarker(MarkerToken, absl::string_view text_head,
+                            absl::string_view text_tail) override {
     Print(text_head.data(), text_head.size());
     if (ConsumeInsertSilentMarker()) {
       PrintLiteral(internal::kDebugStringSilentMarker);
@@ -1577,13 +1577,10 @@
   void PrintMessageStart(const Message& /*message*/, int /*field_index*/,
                          int /*field_count*/, bool single_line_mode,
                          BaseTextGenerator* generator) const override {
-    // This is safe as only TextGenerator is used with
-    // DebugStringFieldValuePrinter.
-    TextGenerator* text_generator = static_cast<TextGenerator*>(generator);
     if (single_line_mode) {
-      text_generator->PrintMaybeWithMarker(" ", "{ ");
+      generator->PrintMaybeWithMarker(MarkerToken(), " ", "{ ");
     } else {
-      text_generator->PrintMaybeWithMarker(" ", "{\n");
+      generator->PrintMaybeWithMarker(MarkerToken(), " ", "{\n");
     }
   }
 };
@@ -2174,7 +2171,7 @@
 }  // namespace
 
 bool TextFormat::Printer::PrintAny(const Message& message,
-                                   TextGenerator* generator) const {
+                                   BaseTextGenerator* generator) const {
   const FieldDescriptor* type_url_field;
   const FieldDescriptor* value_field;
   if (!internal::GetAnyFieldDescriptors(message, &type_url_field,
@@ -2222,7 +2219,7 @@
 }
 
 void TextFormat::Printer::Print(const Message& message,
-                                TextGenerator* generator) const {
+                                BaseTextGenerator* generator) const {
   const Reflection* reflection = message.GetReflection();
   if (!reflection) {
     // This message does not provide any way to describe its structure.
@@ -2242,10 +2239,20 @@
     itr->second->Print(message, single_line_mode_, generator);
     return;
   }
+  PrintMessage(message, generator);
+}
+
+void TextFormat::Printer::PrintMessage(const Message& message,
+                                       BaseTextGenerator* generator) const {
+  if (generator == nullptr) {
+    return;
+  }
+  const Descriptor* descriptor = message.GetDescriptor();
   if (descriptor->full_name() == internal::kAnyFullTypeName && expand_any_ &&
       PrintAny(message, generator)) {
     return;
   }
+  const Reflection* reflection = message.GetReflection();
   std::vector<const FieldDescriptor*> fields;
   if (descriptor->options().map_entry()) {
     fields.push_back(descriptor->field(0));
@@ -2460,7 +2467,7 @@
 void TextFormat::Printer::PrintField(const Message& message,
                                      const Reflection* reflection,
                                      const FieldDescriptor* field,
-                                     TextGenerator* generator) const {
+                                     BaseTextGenerator* generator) const {
   if (use_short_repeated_primitives_ && field->is_repeated() &&
       field->cpp_type() != FieldDescriptor::CPPTYPE_STRING &&
       field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
@@ -2508,7 +2515,7 @@
       printer->PrintMessageEnd(sub_message, field_index, count,
                                single_line_mode_, generator);
     } else {
-      generator->PrintMaybeWithMarker(": ");
+      generator->PrintMaybeWithMarker(MarkerToken(), ": ");
       // Write the field value.
       PrintFieldValue(message, reflection, field, field_index, generator);
       if (single_line_mode_) {
@@ -2528,12 +2535,12 @@
 
 void TextFormat::Printer::PrintShortRepeatedField(
     const Message& message, const Reflection* reflection,
-    const FieldDescriptor* field, TextGenerator* generator) const {
+    const FieldDescriptor* field, BaseTextGenerator* generator) const {
   // Print primitive repeated field in short form.
   int size = reflection->FieldSize(message, field);
   PrintFieldName(message, /*field_index=*/-1, /*field_count=*/size, reflection,
                  field, generator);
-  generator->PrintMaybeWithMarker(": ", "[");
+  generator->PrintMaybeWithMarker(MarkerToken(), ": ", "[");
   for (int i = 0; i < size; i++) {
     if (i > 0) generator->PrintLiteral(", ");
     PrintFieldValue(message, reflection, field, i, generator);
@@ -2549,7 +2556,7 @@
                                          int field_index, int field_count,
                                          const Reflection* reflection,
                                          const FieldDescriptor* field,
-                                         TextGenerator* generator) const {
+                                         BaseTextGenerator* generator) const {
   // if use_field_number_ is true, prints field number instead
   // of field name.
   if (use_field_number_) {
@@ -2566,7 +2573,7 @@
                                           const Reflection* reflection,
                                           const FieldDescriptor* field,
                                           int index,
-                                          TextGenerator* generator) const {
+                                          BaseTextGenerator* generator) const {
   GOOGLE_DCHECK(field->is_repeated() || (index == -1))
       << "Index must be -1 for non-repeated fields";
 
@@ -2687,7 +2694,7 @@
 }
 
 void TextFormat::Printer::PrintUnknownFields(
-    const UnknownFieldSet& unknown_fields, TextGenerator* generator,
+    const UnknownFieldSet& unknown_fields, BaseTextGenerator* generator,
     int recursion_budget) const {
   for (int i = 0; i < unknown_fields.field_count(); i++) {
     const UnknownField& field = unknown_fields.field(i);
@@ -2696,7 +2703,7 @@
     switch (field.type()) {
       case UnknownField::TYPE_VARINT:
         generator->PrintString(field_number);
-        generator->PrintMaybeWithMarker(": ");
+        generator->PrintMaybeWithMarker(MarkerToken(), ": ");
         generator->PrintString(absl::StrCat(field.varint()));
         if (single_line_mode_) {
           generator->PrintLiteral(" ");
@@ -2706,7 +2713,7 @@
         break;
       case UnknownField::TYPE_FIXED32: {
         generator->PrintString(field_number);
-        generator->PrintMaybeWithMarker(": ", "0x");
+        generator->PrintMaybeWithMarker(MarkerToken(), ": ", "0x");
         generator->PrintString(
             absl::StrCat(absl::Hex(field.fixed32(), absl::kZeroPad8)));
         if (single_line_mode_) {
@@ -2718,7 +2725,7 @@
       }
       case UnknownField::TYPE_FIXED64: {
         generator->PrintString(field_number);
-        generator->PrintMaybeWithMarker(": ", "0x");
+        generator->PrintMaybeWithMarker(MarkerToken(), ": ", "0x");
         generator->PrintString(
             absl::StrCat(absl::Hex(field.fixed64(), absl::kZeroPad16)));
         if (single_line_mode_) {
@@ -2743,9 +2750,9 @@
           // This field is parseable as a Message.
           // So it is probably an embedded message.
           if (single_line_mode_) {
-            generator->PrintMaybeWithMarker(" ", "{ ");
+            generator->PrintMaybeWithMarker(MarkerToken(), " ", "{ ");
           } else {
-            generator->PrintMaybeWithMarker(" ", "{\n");
+            generator->PrintMaybeWithMarker(MarkerToken(), " ", "{\n");
             generator->Indent();
           }
           PrintUnknownFields(embedded_unknown_fields, generator,
@@ -2759,7 +2766,7 @@
         } else {
           // This field is not parseable as a Message (or we ran out of
           // recursion budget). So it is probably just a plain string.
-          generator->PrintMaybeWithMarker(": ", "\"");
+          generator->PrintMaybeWithMarker(MarkerToken(), ": ", "\"");
           generator->PrintString(absl::CEscape(value));
           if (single_line_mode_) {
             generator->PrintLiteral("\" ");
@@ -2772,9 +2779,9 @@
       case UnknownField::TYPE_GROUP:
         generator->PrintString(field_number);
         if (single_line_mode_) {
-          generator->PrintMaybeWithMarker(" ", "{ ");
+          generator->PrintMaybeWithMarker(MarkerToken(), " ", "{ ");
         } else {
-          generator->PrintMaybeWithMarker(" ", "{\n");
+          generator->PrintMaybeWithMarker(MarkerToken(), " ", "{\n");
           generator->Indent();
         }
         // For groups, we recurse without checking the budget. This is OK,
diff --git a/src/google/protobuf/text_format.h b/src/google/protobuf/text_format.h
index fbcbea4..2cd4024 100644
--- a/src/google/protobuf/text_format.h
+++ b/src/google/protobuf/text_format.h
@@ -115,7 +115,22 @@
                                       const FieldDescriptor* field, int index,
                                       std::string* output);
 
+  // Forward declare `Printer` for `BaseTextGenerator::MarkerToken` which
+  // restricts some methods of `BaseTextGenerator` to the class `Printer`.
+  class Printer;
+
   class PROTOBUF_EXPORT BaseTextGenerator {
+   private:
+    // Passkey (go/totw/134#what-about-stdshared-ptr) that allows `Printer`
+    // (but not derived classes) to call `PrintMaybeWithMarker` and its
+    // `Printer::TextGenerator` to overload it.
+    // This prevents users from bypassing the marker generation.
+    class MarkerToken {
+     private:
+      explicit MarkerToken() = default;  // 'explicit' prevents aggregate init.
+      friend class Printer;
+    };
+
    public:
     virtual ~BaseTextGenerator();
 
@@ -133,6 +148,20 @@
     void PrintLiteral(const char (&text)[n]) {
       Print(text, n - 1);  // n includes the terminating zero character.
     }
+
+    // Internal to Printer, access regulated by `MarkerToken`.
+    virtual void PrintMaybeWithMarker(MarkerToken, absl::string_view text) {
+      Print(text.data(), text.size());
+    }
+
+    // Internal to Printer, access regulated by `MarkerToken`.
+    virtual void PrintMaybeWithMarker(MarkerToken, absl::string_view text_head,
+                                      absl::string_view text_tail) {
+      Print(text_head.data(), text_head.size());
+      Print(text_tail.data(), text_tail.size());
+    }
+
+    friend class Printer;
   };
 
   // The default printer that converts scalar values from fields into their
@@ -381,6 +410,14 @@
     bool RegisterMessagePrinter(const Descriptor* descriptor,
                                 const MessagePrinter* printer);
 
+    // Default printing for messages, which allows registered message printers
+    // to fall back to default printing without losing the ability to control
+    // sub-messages or fields.
+    // NOTE: If the passed in `text_generaor` is not actually the current
+    // `TextGenerator`, then no output will be produced.
+    void PrintMessage(const Message& message,
+                      BaseTextGenerator* generator) const;
+
    private:
     friend std::string Message::DebugString() const;
     friend std::string Message::ShortDebugString() const;
@@ -404,6 +441,7 @@
     // Forward declaration of an internal class used to print the text
     // output to the OutputStream (see text_format.cc for implementation).
     class TextGenerator;
+    using MarkerToken = BaseTextGenerator::MarkerToken;
 
     // Forward declaration of an internal class used to print field values for
     // DebugString APIs (see text_format.cc for implementation).
@@ -417,40 +455,40 @@
 
     // Internal Print method, used for writing to the OutputStream via
     // the TextGenerator class.
-    void Print(const Message& message, TextGenerator* generator) const;
+    void Print(const Message& message, BaseTextGenerator* generator) const;
 
     // Print a single field.
     void PrintField(const Message& message, const Reflection* reflection,
                     const FieldDescriptor* field,
-                    TextGenerator* generator) const;
+                    BaseTextGenerator* generator) const;
 
     // Print a repeated primitive field in short form.
     void PrintShortRepeatedField(const Message& message,
                                  const Reflection* reflection,
                                  const FieldDescriptor* field,
-                                 TextGenerator* generator) const;
+                                 BaseTextGenerator* generator) const;
 
     // Print the name of a field -- i.e. everything that comes before the
     // ':' for a single name/value pair.
     void PrintFieldName(const Message& message, int field_index,
                         int field_count, const Reflection* reflection,
                         const FieldDescriptor* field,
-                        TextGenerator* generator) const;
+                        BaseTextGenerator* generator) const;
 
     // Outputs a textual representation of the value of the field supplied on
     // the message supplied or the default value if not set.
     void PrintFieldValue(const Message& message, const Reflection* reflection,
                          const FieldDescriptor* field, int index,
-                         TextGenerator* generator) const;
+                         BaseTextGenerator* generator) const;
 
     // Print the fields in an UnknownFieldSet.  They are printed by tag number
     // only.  Embedded messages are heuristically identified by attempting to
     // parse them (subject to the recursion budget).
     void PrintUnknownFields(const UnknownFieldSet& unknown_fields,
-                            TextGenerator* generator,
+                            BaseTextGenerator* generator,
                             int recursion_budget) const;
 
-    bool PrintAny(const Message& message, TextGenerator* generator) const;
+    bool PrintAny(const Message& message, BaseTextGenerator* generator) const;
 
     const FastFieldValuePrinter* GetFieldPrinter(
         const FieldDescriptor* field) const {
diff --git a/src/google/protobuf/text_format_unittest.cc b/src/google/protobuf/text_format_unittest.cc
index 6dcc044..1bca2df 100644
--- a/src/google/protobuf/text_format_unittest.cc
+++ b/src/google/protobuf/text_format_unittest.cc
@@ -780,15 +780,24 @@
   ~CustomNestedMessagePrinter() override {}
   void Print(const Message& message, bool single_line_mode,
              TextFormat::BaseTextGenerator* generator) const override {
-    generator->PrintLiteral("custom");
+    generator->PrintLiteral("// custom\n");
+    if (printer_ != nullptr) {
+      printer_->PrintMessage(message, generator);
+    }
   }
+
+  void SetPrinter(TextFormat::Printer* printer) { printer_ = printer; }
+
+ private:
+  TextFormat::Printer* printer_ = nullptr;
 };
 
 TEST_F(TextFormatTest, CustomMessagePrinter) {
   TextFormat::Printer printer;
+  auto* custom_printer = new CustomNestedMessagePrinter;
   printer.RegisterMessagePrinter(
       unittest::TestAllTypes::NestedMessage::default_instance().descriptor(),
-      new CustomNestedMessagePrinter);
+      custom_printer);
 
   unittest::TestAllTypes message;
   std::string text;
@@ -797,7 +806,11 @@
 
   message.mutable_optional_nested_message()->set_bb(1);
   EXPECT_TRUE(printer.PrintToString(message, &text));
-  EXPECT_EQ("optional_nested_message {\n  custom}\n", text);
+  EXPECT_EQ("optional_nested_message {\n  // custom\n}\n", text);
+
+  custom_printer->SetPrinter(&printer);
+  EXPECT_TRUE(printer.PrintToString(message, &text));
+  EXPECT_EQ("optional_nested_message {\n  // custom\n  bb: 1\n}\n", text);
 }
 
 TEST_F(TextFormatTest, ParseBasic) {