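Store mutation weights in the corpus_type for protobuf domains.

Each protobuf corpus value now caches its mutation weight under the reserved
key -1, so CountNumberOfFields() no longer walks every field. Unset
sub-messages contribute an estimated field count. Init() also bounds the
initialization of non-required sub-messages by a random depth (at most 5)
instead of skipping recursive fields.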

PiperOrigin-RevId: 712897464
diff --git a/fuzztest/internal/domains/protobuf_domain_impl.h b/fuzztest/internal/domains/protobuf_domain_impl.h
index 279bf2d..c00e2fb 100644
--- a/fuzztest/internal/domains/protobuf_domain_impl.h
+++ b/fuzztest/internal/domains/protobuf_domain_impl.h
@@ -311,6 +311,12 @@
     return max;
   }
 
+  // Records that this policy's domain is nested one more sub-message deeper.
+  void IncrementDepth() { depth_ += 1; }
+
+  // Sub-message nesting depth of this policy's domain (0 for the root).
+  int depth() const { return depth_; }
+
  private:
   template <typename T>
   struct FilterToValue {
@@ -345,6 +349,8 @@
   std::vector<FilterToValue<OptionalPolicy>> optional_policies_;
   std::vector<FilterToValue<int64_t>> min_repeated_fields_sizes_;
   std::vector<FilterToValue<int64_t>> max_repeated_fields_sizes_;
+  // Number of sub-message levels between the root domain and this one.
+  int depth_{0};
 
 #define FUZZTEST_INTERNAL_POLICY_MEMBERS(Camel, cpp)                           \
  private:                                                                      \
@@ -453,6 +458,14 @@
     unset_oneof_fields_ = other.unset_oneof_fields_;
   }
 
+  // Returns whether `field` takes part in mutation-field selection (and thus
+  // in the mutation weight). Oneof fields whose policy is kAlwaysNull are
+  // excluded: selecting them would most likely make the mutation a no-op.
+  bool ShouldMutate(const FieldDescriptor* field) const {
+    return !field->containing_oneof() ||
+           GetOneofFieldPolicy(field) != OptionalPolicy::kAlwaysNull;
+  }
+
   corpus_type Init(absl::BitGenRef prng) {
     if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
     FUZZTEST_INTERNAL_CHECK(
@@ -463,7 +475,12 @@
     absl::flat_hash_map<int, int> oneof_to_field;
 
     // TODO(b/241124202): Use a valid proto with minimum size.
+    // Mutation weight of `val`: every field accepted by ShouldMutate() starts
+    // at its estimated weight; initialized sub-messages are refined below.
+    int64_t total_mutation_weight = 0;
     for (const FieldDescriptor* field : GetProtobufFields(descriptor)) {
+      const int64_t weight = EstimatedMutationWeight(field);
+      if (ShouldMutate(field)) total_mutation_weight += weight;
       if (auto* oneof = field->containing_oneof()) {
         if (!oneof_to_field.contains(oneof->index())) {
           oneof_to_field[oneof->index()] = SelectAFieldIndexInOneof(
@@ -472,17 +487,40 @@
         }
         if (oneof_to_field[oneof->index()] != field->index()) continue;
       } else if (!IsRequired(field) && customized_fields_.empty() &&
-                 IsFieldRecursive(field)) {
-        // We avoid initializing non-required recursive fields by default (if
-        // they are not explicitly customized). Otherwise, the initialization
-        // may never terminate. If a proto has only non-required recursive
-        // fields, the initialization will be deterministic, which violates the
-        // assumption on domain Init. However, such cases should be extremely
-        // rare and breaking the assumption would not have severe consequences.
-        continue;
+                 field->message_type() != nullptr) {
+        // Unless fields are explicitly customized, non-required sub-messages
+        // are only initialized up to a random depth; otherwise initializing a
+        // recursive proto may never terminate. The cutoff is drawn uniformly
+        // from [1, kMaxInitDepth] per field, so nothing is initialized here
+        // once this domain's depth reaches kMaxInitDepth. This balances
+        // initialization speed against coverage.
+        constexpr int kMaxInitDepth = 5;
+        const int acceptable_depth =
+            absl::Uniform(absl::IntervalClosedClosed, prng, 1, kMaxInitDepth);
+        if (policy_.depth() >= acceptable_depth) continue;
       }
       VisitProtobufField(field, InitializeVisitor{prng, *this, val});
+      if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
+          ShouldMutate(field)) {
+        // Replace the estimate with the exact weight of the initialized
+        // sub-message. Keep the estimate if nothing was stored for the field:
+        // dereferencing `val.end()` would be undefined behavior.
+        auto val_it = val.find(field->number());
+        if (val_it != val.end()) {
+          total_mutation_weight -= weight;
+          if (field->is_repeated()) {
+            total_mutation_weight +=
+                GetSubDomain<ProtoMessageTag, true>(field).CountNumberOfFields(
+                    val_it->second);
+          } else {
+            total_mutation_weight +=
+                GetSubDomain<ProtoMessageTag, false>(field).CountNumberOfFields(
+                    val_it->second);
+          }
+        }
+      }
     }
+    SetMutationWeight(val, total_mutation_weight);
     return val;
   }
 
@@ -503,6 +534,8 @@
     value_type out(prototype_.Get()->New());
 
     for (auto& [number, data] : value) {
+      // Skip the reserved entry holding the cached mutation weight.
+      if (IsMetadataEntry(number)) continue;
       auto* field = GetField(number);
       VisitProtobufField(field, GetValueVisitor{*out, *this, data});
     }
@@ -585,32 +617,7 @@
   }
 
   uint64_t CountNumberOfFields(const corpus_type& val) {
-    uint64_t total_weight = 0;
-    auto descriptor = prototype_.Get()->GetDescriptor();
-    if (GetFieldCount(descriptor) == 0) return total_weight;
-
-    for (const FieldDescriptor* field : GetProtobufFields(descriptor)) {
-      if (field->containing_oneof() &&
-          GetOneofFieldPolicy(field) == OptionalPolicy::kAlwaysNull) {
-        continue;
-      }
-      ++total_weight;
-
-      if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
-        auto val_it = val.find(field->number());
-        if (val_it == val.end()) continue;
-        if (field->is_repeated()) {
-          total_weight +=
-              GetSubDomain<ProtoMessageTag, true>(field).CountNumberOfFields(
-                  val_it->second);
-        } else {
-          total_weight +=
-              GetSubDomain<ProtoMessageTag, false>(field).CountNumberOfFields(
-                  val_it->second);
-        }
-      }
-    }
-    return total_weight;
+    return static_cast<uint64_t>(GetMutationWeight(val));  // Cached in `val`.
   }
 
   uint64_t MutateSelectedField(
@@ -620,12 +627,13 @@
     uint64_t field_counter = 0;
     auto descriptor = prototype_.Get()->GetDescriptor();
     if (GetFieldCount(descriptor) == 0) return field_counter;
+    // Fast path: if the selected index lies beyond this message's weight, no
+    // field here is selected; return the weight so the caller can skip it.
+    const uint64_t fields_count = CountNumberOfFields(val);
+    if (fields_count < selected_field_index) return fields_count;
 
     for (const FieldDescriptor* field : GetProtobufFields(descriptor)) {
-      if (field->containing_oneof() &&
-          GetOneofFieldPolicy(field) == OptionalPolicy::kAlwaysNull) {
-        continue;
-      }
+      if (!ShouldMutate(field)) continue;
       ++field_counter;
       if (field_counter == selected_field_index) {
         VisitProtobufField(
@@ -634,8 +640,15 @@
       }
 
       if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+        --field_counter;  // Message fields count only via their sub-fields.
         auto val_it = val.find(field->number());
-        if (val_it == val.end()) continue;
+        if (val_it == val.end()) {  // Unset: spans its estimated weight.
+          field_counter += EstimatedMutationWeight(field);
+          if (field_counter < selected_field_index) continue;  // Keep scanning.
+          VisitProtobufField(
+              field, MutateVisitor{prng, metadata, only_shrink, *this, val});
+          return field_counter;
+        }
         if (field->is_repeated()) {
           field_counter +=
               GetSubDomain<ProtoMessageTag, true>(field).MutateSelectedField(
@@ -964,13 +977,30 @@
 
       if (is_present) {
         // Mutate the element
+        // Swap the field's old share of the cached mutation weight for the
+        // share it has after the mutation.
+        AddMutationWeight(
+            val, -static_cast<int64_t>(domain.CountNumberOfFields(it->second)));
         domain.Mutate(it->second, prng, metadata, only_shrink);
+        AddMutationWeight(
+            val, static_cast<int64_t>(domain.CountNumberOfFields(it->second)));
         return;
       }
 
       // Add the element
       if (!only_shrink) {
         InitializeFieldValue<T>(prng, self, field, val);
+        // Only message fields change the cached weight when they become set
+        // (estimated weight -> actual weight); non-message fields weigh 1
+        // either way. Use find() rather than operator[] so that a value that
+        // was not stored is not default-inserted into the corpus.
+        if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+          if (auto new_it = val.find(field->number()); new_it != val.end()) {
+            const int64_t new_weight = static_cast<int64_t>(
+                domain.CountNumberOfFields(new_it->second));
+            AddMutationWeight(val, new_weight - EstimatedMutationWeight(field));
+          }
+        }
       }
     }
 
@@ -982,7 +1000,15 @@
 
       if (!is_present) {
         if (!only_shrink) {
-          val[field->number()] = domain.Init(prng);
+          auto& field_val = val[field->number()] = domain.Init(prng);
+          // Only message fields change the cached weight when they become set
+          // (estimated weight -> actual weight); non-message fields weigh 1
+          // either way.
+          if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+            const int64_t new_weight =
+                static_cast<int64_t>(domain.CountNumberOfFields(field_val));
+            AddMutationWeight(val, new_weight - EstimatedMutationWeight(field));
+          }
         }
       } else if (field->is_map()) {
         // field of synthetic messages of the form:
@@ -1000,6 +1020,10 @@
         // TODO(b/231212420): Improve mutation to not add duplicate keys on the
         // first place. The current hack is very inefficient.
         // Switch the inner domain for maps to use flat_hash_map instead.
+        // Drop the map's current share of the cached mutation weight; the
+        // share of whichever entries are kept is added back afterwards.
+        AddMutationWeight(
+            val, -static_cast<int64_t>(domain.CountNumberOfFields(it->second)));
         corpus_type corpus_copy;
         auto& copy = corpus_copy[field->number()] = it->second;
         domain.Mutate(copy, prng, metadata, only_shrink);
@@ -1012,8 +1033,15 @@
           // The number of entries is the same, so accept the change.
           it->second = std::move(copy);
         }
+        // Re-add the share of the map entries that were kept.
+        AddMutationWeight(
+            val, static_cast<int64_t>(domain.CountNumberOfFields(it->second)));
       } else {
+        AddMutationWeight(
+            val, -static_cast<int64_t>(domain.CountNumberOfFields(it->second)));
         domain.Mutate(it->second, prng, metadata, only_shrink);
+        AddMutationWeight(
+            val, static_cast<int64_t>(domain.CountNumberOfFields(it->second)));
       }
     }
   };
@@ -1325,13 +1349,19 @@
     return descriptor->field_count() + extensions.size();
   }
 
-  static auto GetProtobufFields(const Descriptor* descriptor) {
+  // Returns the regular fields of `descriptor` in declaration order. When
+  // `include_extensions` is true, the extensions known to the descriptor's
+  // pool are appended after them.
+  static auto GetProtobufFields(const Descriptor* descriptor,
+                                bool include_extensions = true) {
     std::vector<const FieldDescriptor*> fields;
     fields.reserve(descriptor->field_count());
     for (int i = 0; i < descriptor->field_count(); ++i) {
       fields.push_back(descriptor->field(i));
     }
-    descriptor->file()->pool()->FindAllExtensions(descriptor, &fields);
+    if (include_extensions) {
+      descriptor->file()->pool()->FindAllExtensions(descriptor, &fields);
+    }
     return fields;
   }
 
@@ -1599,6 +1626,7 @@
               field->message_type()),
           use_lazy_initialization_);
       result.SetPolicy(policy_);
+      result.GetPolicy().IncrementDepth();  // Child domain: one level deeper.
       return Domain<std::unique_ptr<Message>>(result);
     } else {
       return Domain<T>(ArbitraryImpl<T>());
@@ -1673,6 +1701,70 @@
                             /*consider_non_terminating_recursions=*/false);
   }
 
+  // Returns the mutation weight of an unset `field`: 1 for non-message
+  // fields, otherwise an estimate of the fields nested in its message type.
+  // Recursive sub-fields count as 0, extensions are ignored, and estimates
+  // are capped at 1000. Results are memoized by the message's full name.
+  static int64_t EstimatedMutationWeight(const FieldDescriptor* field) {
+    auto descriptor = field->message_type();
+    if (!descriptor) return 1;  // non-message fields
+    // NOTE(review): unsynchronized cache; confirm single-threaded use.
+    static absl::flat_hash_map<std::string, int> field_count_cache;
+    if (auto it = field_count_cache.find(descriptor->full_name());
+        it != field_count_cache.end()) {
+      return it->second;
+    }
+    absl::flat_hash_set<decltype(descriptor)> parents;
+    int result = EstimatedMutationWeight(descriptor, parents);
+    field_count_cache.insert({descriptor->full_name(), result});
+    return result;
+  }
+
+  // Recursive helper; message types already on the path (`parents`) add 0.
+  template <typename Descriptor>
+  static int EstimatedMutationWeight(
+      const Descriptor* descriptor,
+      absl::flat_hash_set<const Descriptor*>& parents) {
+    if (!parents.insert(descriptor).second) return 0;  // Recursive type.
+    constexpr int kMaxFieldCount = 1000;
+    int field_count = 0;
+    for (const FieldDescriptor* field :
+         GetProtobufFields(descriptor, /*include_extensions=*/false)) {
+      const auto* child = field->message_type();
+      field_count += child ? EstimatedMutationWeight(child, parents) : 1;
+      if (field_count > kMaxFieldCount) {  // Avoid expensive estimation.
+        field_count = kMaxFieldCount;
+        break;
+      }
+    }
+    parents.erase(descriptor);
+    return field_count;
+  }
+
+  // Reserved corpus key holding the cached mutation weight. Protobuf field
+  // numbers are always positive, so it never collides with a real field.
+  static constexpr int kMutationWeightKey = -1;
+  static bool IsMetadataEntry(int index) { return index == kMutationWeightKey; }
+
+  // Mutation weight for a proto: each field accepted by ShouldMutate() counts
+  // 1 for a non-message field, the CountNumberOfFields() of its value for a
+  // set message field, and EstimatedMutationWeight() for an unset one.
+  static void SetMutationWeight(corpus_type& val, int64_t field_count) {
+    val[kMutationWeightKey] =
+        GenericDomainCorpusType(std::in_place_type<int64_t>, field_count);
+  }
+
+  static int64_t GetMutationWeight(const corpus_type& val) {
+    auto it = val.find(kMutationWeightKey);
+    FUZZTEST_INTERNAL_CHECK_PRECONDITION(it != val.end(),
+                                         "Invalid corpus value!");
+    return it->second.template GetAs<int64_t>();
+  }
+
+  static void AddMutationWeight(corpus_type& val, int64_t to_add) {
+    SetMutationWeight(val, GetMutationWeight(val) + to_add);
+  }
+
   bool IsOneofRecursive(const OneofDescriptor* oneof,
                         absl::flat_hash_set<const Descriptor*>& parents,
                         const ProtoPolicy<Message>& policy,