Improve Init performance by avoiding recursion analysis.

The new approach works as follows. Initialization is performed only up to a certain (randomly chosen) depth. However, this makes it hard for the mutator to select fields in deeper layers. To get around this, during mutation we assign each unset message field a weight based on its estimated number of sub-fields.

PiperOrigin-RevId: 712897465
diff --git a/domain_tests/arbitrary_domains_protobuf_test.cc b/domain_tests/arbitrary_domains_protobuf_test.cc
index 8ec6150..d44d1e4 100644
--- a/domain_tests/arbitrary_domains_protobuf_test.cc
+++ b/domain_tests/arbitrary_domains_protobuf_test.cc
@@ -133,7 +133,7 @@
       },
       Gt(0));
   EXPECT_THAT(
-      GenerateNonUniqueValues(Arbitrary<TestProtobufWithExtension>(), 1, 5000),
+      GenerateNonUniqueValues(Arbitrary<TestProtobufWithExtension>(), 10, 100),
       AllOf(Contains(has_ext), Contains(has_rep_ext)));
 }
 
@@ -641,25 +641,28 @@
   T v;
   auto corpus_v_uninitialized = domain.FromValue(v);
   EXPECT_TRUE(corpus_v_uninitialized != std::nullopt);
-  EXPECT_EQ(domain.CountNumberOfFields(corpus_v_uninitialized.value()), 26);
+  int fields_count = 28;
+  EXPECT_EQ(domain.CountNumberOfFields(corpus_v_uninitialized.value()),
+            fields_count + /*estimated subfields count*/ 4);
   v.set_allocated_subproto(new SubT());
   auto corpus_v_initizalize_one_optional_proto = domain.FromValue(v);
   EXPECT_TRUE(corpus_v_initizalize_one_optional_proto != std::nullopt);
-  EXPECT_EQ(domain.CountNumberOfFields(
-                corpus_v_initizalize_one_optional_proto.value()),
-            28);
+  EXPECT_EQ(
+      domain.CountNumberOfFields(
+          corpus_v_initizalize_one_optional_proto.value()),
+      fields_count + /*subfields count*/ 2 + /*estimated subfields count*/ 2);
   v.add_rep_subproto();
   auto corpus_v_initizalize_one_repeated_proto_1 = domain.FromValue(v);
   EXPECT_TRUE(corpus_v_initizalize_one_repeated_proto_1 != std::nullopt);
   EXPECT_EQ(domain.CountNumberOfFields(
                 corpus_v_initizalize_one_repeated_proto_1.value()),
-            30);
+            fields_count + /*subfields count*/ 4);
   v.add_rep_subproto();
   auto corpus_v_initizalize_one_repeated_proto_2 = domain.FromValue(v);
   EXPECT_TRUE(corpus_v_initizalize_one_repeated_proto_2 != std::nullopt);
   EXPECT_EQ(domain.CountNumberOfFields(
                 corpus_v_initizalize_one_repeated_proto_2.value()),
-            32);
+            fields_count + /*subfields count*/ 6);
 }
 
 auto FieldNameHasSubstr(absl::string_view field_name) {
diff --git a/fuzztest/internal/domains/protobuf_domain_impl.h b/fuzztest/internal/domains/protobuf_domain_impl.h
index 279bf2d..b79bcbc 100644
--- a/fuzztest/internal/domains/protobuf_domain_impl.h
+++ b/fuzztest/internal/domains/protobuf_domain_impl.h
@@ -311,6 +311,10 @@
     return max;
   }
 
+  void IncrementDepth() { ++depth_; }  // Applied to each sub-domain's policy.
+  void ResetDepth() { depth_ = 0; }
+  int depth() const { return depth_; }  // Nesting depth; 0 by default.
+
  private:
   template <typename T>
   struct FilterToValue {
@@ -345,6 +349,7 @@
   std::vector<FilterToValue<OptionalPolicy>> optional_policies_;
   std::vector<FilterToValue<int64_t>> min_repeated_fields_sizes_;
   std::vector<FilterToValue<int64_t>> max_repeated_fields_sizes_;
+  int depth_ = 0;
 
 #define FUZZTEST_INTERNAL_POLICY_MEMBERS(Camel, cpp)                           \
  private:                                                                      \
@@ -453,6 +458,15 @@
     unset_oneof_fields_ = other.unset_oneof_fields_;
   }
 
+  // Randomly decides if this depth is too deep to keep initializing optional
+  // fields; the cut-off is uniform in [1, kMaxInitDepth] (speed vs. coverage).
+  bool IsTooDeep(absl::BitGenRef prng) const {
+    constexpr int kMaxInitDepth = 5;
+    const int acceptable_depth =
+        absl::Uniform(absl::IntervalClosedClosed, prng, 1, kMaxInitDepth);
+    return policy_.depth() >= acceptable_depth;
+  }
+
   corpus_type Init(absl::BitGenRef prng) {
     if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
     FUZZTEST_INTERNAL_CHECK(
@@ -466,23 +480,21 @@
     for (const FieldDescriptor* field : GetProtobufFields(descriptor)) {
       if (auto* oneof = field->containing_oneof()) {
         if (!oneof_to_field.contains(oneof->index())) {
-          oneof_to_field[oneof->index()] = SelectAFieldIndexInOneof(
-              oneof, prng,
-              /*non_recursive_only=*/customized_fields_.empty());
+          oneof_to_field[oneof->index()] =
+              SelectAFieldIndexInOneof(oneof, prng);
         }
         if (oneof_to_field[oneof->index()] != field->index()) continue;
-      } else if (!IsRequired(field) && customized_fields_.empty() &&
-                 IsFieldRecursive(field)) {
-        // We avoid initializing non-required recursive fields by default (if
-        // they are not explicitly customized). Otherwise, the initialization
-        // may never terminate. If a proto has only non-required recursive
-        // fields, the initialization will be deterministic, which violates the
-        // assumption on domain Init. However, such cases should be extremely
-        // rare and breaking the assumption would not have severe consequences.
-        continue;
+      } else if (!IsRequired(field) && customized_fields_.empty()) {
+        // We avoid initializing non-required message fields after some depth
+        // (when no fields are explicitly customized). Otherwise, the
+        // initialization may never terminate.
+        if (field->message_type() != nullptr && IsTooDeep(prng)) continue;
       }
       VisitProtobufField(field, InitializeVisitor{prng, *this, val});
     }
+    // Depth only matters for initialization; mutation starts from depth 0.
+    // NOTE(review): later Inits here also see depth 0; confirm they terminate.
+    policy_.ResetDepth();
     return val;
   }
 
@@ -598,7 +610,10 @@
 
       if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
         auto val_it = val.find(field->number());
-        if (val_it == val.end()) continue;
+        if (val_it == val.end()) {
+          total_weight += EstimatedMutationWeight(field);
+          continue;
+        }
         if (field->is_repeated()) {
           total_weight +=
               GetSubDomain<ProtoMessageTag, true>(field).CountNumberOfFields(
@@ -635,7 +650,13 @@
 
       if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
         auto val_it = val.find(field->number());
-        if (val_it == val.end()) continue;
+        if (val_it == val.end()) {  // Unset field owns its estimated range.
+          field_counter += EstimatedMutationWeight(field);
+          if (field_counter < selected_field_index) continue;
+          VisitProtobufField(
+              field, MutateVisitor{prng, metadata, only_shrink, *this, val});
+          return field_counter;
+        }
         if (field->is_repeated()) {
           field_counter +=
               GetSubDomain<ProtoMessageTag, true>(field).MutateSelectedField(
@@ -933,12 +954,13 @@
 
   template <typename OneofDescriptor>
   int SelectAFieldIndexInOneof(const OneofDescriptor* oneof,
-                               absl::BitGenRef prng, bool non_recursive_only) {
+                               absl::BitGenRef prng) {
     std::vector<int> fields;
     for (int i = 0; i < oneof->field_count(); ++i) {
       OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i));
       if (policy == OptionalPolicy::kAlwaysNull) continue;
-      if (non_recursive_only && IsFieldRecursive(oneof->field(i))) continue;
+      // We avoid initializing non-required oneof fields after some depth.
+      if (policy == OptionalPolicy::kWithNull && IsTooDeep(prng)) continue;
       fields.push_back(i);
     }
     if (fields.empty()) {  // This can happen if all fields are unset.
@@ -1325,13 +1347,16 @@
     return descriptor->field_count() + extensions.size();
   }
 
-  static auto GetProtobufFields(const Descriptor* descriptor) {
+  static auto GetProtobufFields(const Descriptor* descriptor,
+                                bool include_extensions = true) {
     std::vector<const FieldDescriptor*> fields;
     fields.reserve(descriptor->field_count());
     for (int i = 0; i < descriptor->field_count(); ++i) {
       fields.push_back(descriptor->field(i));
     }
-    descriptor->file()->pool()->FindAllExtensions(descriptor, &fields);
+    // Pool-registered extensions, if requested, follow the declared fields.
+    if (!include_extensions) return fields;
+    descriptor->file()->pool()->FindAllExtensions(descriptor, &fields);
     return fields;
   }
 
@@ -1599,6 +1624,7 @@
               field->message_type()),
           use_lazy_initialization_);
       result.SetPolicy(policy_);
+      result.GetPolicy().IncrementDepth();
       return Domain<std::unique_ptr<Message>>(result);
     } else {
       return Domain<T>(ArbitraryImpl<T>());
@@ -1673,6 +1699,46 @@
                             /*consider_non_terminating_recursions=*/false);
   }
 
+  // Estimates the number of sub-fields of `field` (1 for non-message fields);
+  // used as the mutation weight of unset message fields. Recursive sub-fields
+  // count as 0, extensions are skipped, and each level is capped at 1000.
+  // Memoized by full name. NOTE(review): no locking; confirm single-threaded.
+  static int64_t EstimatedMutationWeight(const FieldDescriptor* field) {
+    const auto* descriptor = field->message_type();
+    if (descriptor == nullptr) return 1;  // non-message fields
+    static absl::flat_hash_map<std::string, int> field_count_cache;
+    if (auto it = field_count_cache.find(descriptor->full_name());
+        it != field_count_cache.end()) {
+      return it->second;
+    }
+    absl::flat_hash_set<decltype(descriptor)> parents;
+    const int result = EstimatedMutationWeight(descriptor, parents);
+    field_count_cache.emplace(descriptor->full_name(), result);
+    return result;
+  }
+
+  template <typename Descriptor>
+  static int EstimatedMutationWeight(  // Leaf-field count of `descriptor`.
+      const Descriptor* descriptor,
+      absl::flat_hash_set<const Descriptor*>& parents) {  // Current path.
+    int field_count = 0;
+    if (parents.contains(descriptor)) return field_count;  // Cycle: 0.
+    parents.insert(descriptor);
+    for (const FieldDescriptor* field :
+         GetProtobufFields(descriptor, /*include_extensions=*/false)) {
+      const auto* child = field->message_type();
+      field_count += child ? EstimatedMutationWeight(child, parents) : 1;
+      constexpr int kMaxFieldCount = 1000;
+      // Cap the estimate and stop scanning to bound the traversal cost.
+      if (field_count > kMaxFieldCount) {
+        field_count = kMaxFieldCount;
+        break;
+      }
+    }
+    parents.erase(descriptor);
+    return field_count;
+  }
+
   bool IsOneofRecursive(const OneofDescriptor* oneof,
                         absl::flat_hash_set<const Descriptor*>& parents,
                         const ProtoPolicy<Message>& policy,