Store mutation weights in the corpus_type for protobuf domains.
PiperOrigin-RevId: 712897464
diff --git a/fuzztest/internal/domains/protobuf_domain_impl.h b/fuzztest/internal/domains/protobuf_domain_impl.h
index 279bf2d..c00e2fb 100644
--- a/fuzztest/internal/domains/protobuf_domain_impl.h
+++ b/fuzztest/internal/domains/protobuf_domain_impl.h
@@ -311,6 +311,10 @@
return max;
}
+ // Bumps the recorded nesting depth by one; called right after this policy is
+ // handed to a sub-message domain.
+ void IncrementDepth() { depth_ += 1; }
+
+ // Current nesting depth recorded by this policy (starts at 0).
+ int depth() const { return this->depth_; }
+
private:
template <typename T>
struct FilterToValue {
@@ -345,6 +349,7 @@
std::vector<FilterToValue<OptionalPolicy>> optional_policies_;
std::vector<FilterToValue<int64_t>> min_repeated_fields_sizes_;
std::vector<FilterToValue<int64_t>> max_repeated_fields_sizes_;
+ // Nesting depth of the message domain this policy belongs to. Starts at 0 and
+ // is bumped by IncrementDepth() when the policy is set on a sub-message
+ // domain; Init compares it against a random limit to decide whether to
+ // initialize non-required message fields.
+ int depth_ = 0;
#define FUZZTEST_INTERNAL_POLICY_MEMBERS(Camel, cpp) \
private: \
@@ -453,6 +458,13 @@
unset_oneof_fields_ = other.unset_oneof_fields_;
}
+ // Returns whether `field` takes part in mutation-weight accounting and in
+ // mutation field selection. A oneof member whose policy is
+ // OptionalPolicy::kAlwaysNull is meant to stay unset, so selecting it for
+ // mutation would most likely be a no-op; such fields are excluded. Every
+ // other field, including the remaining members of a oneof, is kept.
+ bool ShouldMutate(const FieldDescriptor* field) const {
+   if (field->containing_oneof() == nullptr) return true;
+   return GetOneofFieldPolicy(field) != OptionalPolicy::kAlwaysNull;
+ }
+
corpus_type Init(absl::BitGenRef prng) {
if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
FUZZTEST_INTERNAL_CHECK(
@@ -463,7 +475,10 @@
absl::flat_hash_map<int, int> oneof_to_field;
// TODO(b/241124202): Use a valid proto with minimum size.
+ int64_t total_mutation_weight = 0;
for (const FieldDescriptor* field : GetProtobufFields(descriptor)) {
+ int64_t weight = EstimatedMutationWeight(field);
+ if (ShouldMutate(field)) total_mutation_weight += weight;
if (auto* oneof = field->containing_oneof()) {
if (!oneof_to_field.contains(oneof->index())) {
oneof_to_field[oneof->index()] = SelectAFieldIndexInOneof(
@@ -472,17 +487,33 @@
}
if (oneof_to_field[oneof->index()] != field->index()) continue;
} else if (!IsRequired(field) && customized_fields_.empty() &&
- IsFieldRecursive(field)) {
- // We avoid initializing non-required recursive fields by default (if
- // they are not explicitly customized). Otherwise, the initialization
- // may never terminate. If a proto has only non-required recursive
- // fields, the initialization will be deterministic, which violates the
- // assumption on domain Init. However, such cases should be extremely
- // rare and breaking the assumption would not have severe consequences.
- continue;
+ field->message_type() != nullptr) {
+ // We avoid initializing non-required fields after some depth if they
+ // are not explicitly customized). Otherwise, the initialization may
+ // never terminate. We initialize to at most depth 5. This gives a good
+ // balanced between initialization speed and coverage.
+ constexpr int kMaxInitDepth = 5;
+ int acceptable_depth = absl::Uniform(absl::IntervalClosedClosed, prng,
+ int64_t{1}, kMaxInitDepth);
+ if (policy_.depth() >= acceptable_depth) continue;
}
VisitProtobufField(field, InitializeVisitor{prng, *this, val});
+ if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
+ ShouldMutate(field)) {
+ total_mutation_weight -= weight;
+ auto val_it = val.find(field->number());
+ if (field->is_repeated()) {
+ total_mutation_weight +=
+ GetSubDomain<ProtoMessageTag, true>(field).CountNumberOfFields(
+ val_it->second);
+ } else {
+ total_mutation_weight +=
+ GetSubDomain<ProtoMessageTag, false>(field).CountNumberOfFields(
+ val_it->second);
+ }
+ }
}
+ SetMutationWeight(val, total_mutation_weight);
return val;
}
@@ -503,6 +534,7 @@
value_type out(prototype_.Get()->New());
for (auto& [number, data] : value) {
+ if (IsMetadataEntry(number)) continue;
auto* field = GetField(number);
VisitProtobufField(field, GetValueVisitor{*out, *this, data});
}
@@ -585,32 +617,7 @@
}
uint64_t CountNumberOfFields(const corpus_type& val) {
- uint64_t total_weight = 0;
- auto descriptor = prototype_.Get()->GetDescriptor();
- if (GetFieldCount(descriptor) == 0) return total_weight;
-
- for (const FieldDescriptor* field : GetProtobufFields(descriptor)) {
- if (field->containing_oneof() &&
- GetOneofFieldPolicy(field) == OptionalPolicy::kAlwaysNull) {
- continue;
- }
- ++total_weight;
-
- if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- auto val_it = val.find(field->number());
- if (val_it == val.end()) continue;
- if (field->is_repeated()) {
- total_weight +=
- GetSubDomain<ProtoMessageTag, true>(field).CountNumberOfFields(
- val_it->second);
- } else {
- total_weight +=
- GetSubDomain<ProtoMessageTag, false>(field).CountNumberOfFields(
- val_it->second);
- }
- }
- }
- return total_weight;
+ // The weight is no longer recomputed by walking the message; it is read from
+ // the metadata entry that Init and the mutation paths keep up to date.
+ // NOTE(review): GetMutationWeight returns int64_t, so a negative stored weight
+ // would wrap around when converted to this function's uint64_t return type.
+ return GetMutationWeight(val);
}
uint64_t MutateSelectedField(
@@ -620,12 +627,11 @@
uint64_t field_counter = 0;
auto descriptor = prototype_.Get()->GetDescriptor();
if (GetFieldCount(descriptor) == 0) return field_counter;
+ int64_t fields_count = CountNumberOfFields(val);
+ if (fields_count < selected_field_index) return fields_count;
for (const FieldDescriptor* field : GetProtobufFields(descriptor)) {
- if (field->containing_oneof() &&
- GetOneofFieldPolicy(field) == OptionalPolicy::kAlwaysNull) {
- continue;
- }
+ if (!ShouldMutate(field)) continue;
++field_counter;
if (field_counter == selected_field_index) {
VisitProtobufField(
@@ -634,8 +640,15 @@
}
if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ --field_counter;
auto val_it = val.find(field->number());
- if (val_it == val.end()) continue;
+ if (val_it == val.end()) {
+ field_counter += EstimatedMutationWeight(field);
+ if (field_counter >= selected_field_index) continue;
+ VisitProtobufField(
+ field, MutateVisitor{prng, metadata, only_shrink, *this, val});
+ return field_counter;
+ }
if (field->is_repeated()) {
field_counter +=
GetSubDomain<ProtoMessageTag, true>(field).MutateSelectedField(
@@ -964,13 +977,18 @@
if (is_present) {
// Mutate the element
+ AddMutationWeight(val, -domain.CountNumberOfFields(it->second));
domain.Mutate(it->second, prng, metadata, only_shrink);
+ AddMutationWeight(val, domain.CountNumberOfFields(it->second));
return;
}
// Add the element
if (!only_shrink) {
+ AddMutationWeight(val, -EstimatedMutationWeight(field));
InitializeFieldValue<T>(prng, self, field, val);
+ AddMutationWeight(val,
+ domain.CountNumberOfFields(val[field->number()]));
}
}
@@ -982,7 +1000,9 @@
if (!is_present) {
if (!only_shrink) {
- val[field->number()] = domain.Init(prng);
+ AddMutationWeight(val, -EstimatedMutationWeight(field));
+ auto& field_val = val[field->number()] = domain.Init(prng);
+ AddMutationWeight(val, domain.CountNumberOfFields(field_val));
}
} else if (field->is_map()) {
// field of synthetic messages of the form:
@@ -1000,6 +1020,7 @@
// TODO(b/231212420): Improve mutation to not add duplicate keys on the
// first place. The current hack is very inefficient.
// Switch the inner domain for maps to use flat_hash_map instead.
+ AddMutationWeight(val, -domain.CountNumberOfFields(it->second));
corpus_type corpus_copy;
auto& copy = corpus_copy[field->number()] = it->second;
domain.Mutate(copy, prng, metadata, only_shrink);
@@ -1012,8 +1033,11 @@
// The number of entries is the same, so accept the change.
it->second = std::move(copy);
}
+ AddMutationWeight(val, domain.CountNumberOfFields(it->second));
} else {
+ AddMutationWeight(val, -domain.CountNumberOfFields(it->second));
domain.Mutate(it->second, prng, metadata, only_shrink);
+ AddMutationWeight(val, domain.CountNumberOfFields(it->second));
}
}
};
@@ -1325,13 +1349,16 @@
return descriptor->field_count() + extensions.size();
}
+ // Collects the fields of `descriptor`: its declared fields in index order,
+ // followed, when `include_extensions` is true (the default), by every
+ // extension of it that the descriptor's pool reports via FindAllExtensions.
- static auto GetProtobufFields(const Descriptor* descriptor) {
+ static auto GetProtobufFields(const Descriptor* descriptor,
+ bool include_extensions = true) {
std::vector<const FieldDescriptor*> fields;
fields.reserve(descriptor->field_count());
for (int i = 0; i < descriptor->field_count(); ++i) {
fields.push_back(descriptor->field(i));
}
- descriptor->file()->pool()->FindAllExtensions(descriptor, &fields);
+ if (include_extensions) {
+ descriptor->file()->pool()->FindAllExtensions(descriptor, &fields);
+ }
return fields;
}
@@ -1599,6 +1626,7 @@
field->message_type()),
use_lazy_initialization_);
result.SetPolicy(policy_);
+ result.GetPolicy().IncrementDepth();
return Domain<std::unique_ptr<Message>>(result);
} else {
return Domain<T>(ArbitraryImpl<T>());
@@ -1673,6 +1701,70 @@
/*consider_non_terminating_recursions=*/false);
}
+ // Returns an estimate of the mutation weight of `field`, used as a heuristic
+ // while the field is unset: 1 for a non-message field, otherwise the number of
+ // sub-fields of its message type as computed by the descriptor-based overload.
+ // For simplicity, recursive sub-messages contribute nothing. For efficiency,
+ // proto extensions are skipped and the count is capped for protos with a
+ // large number of fields.
+ static int64_t EstimatedMutationWeight(const FieldDescriptor* field) {
+   const auto* descriptor = field->message_type();
+   if (descriptor == nullptr) return 1;  // Non-message fields.
+   // The estimate depends only on the message type, so it is memoized by the
+   // type's full name. The cache is thread_local so that callers on different
+   // threads never share (and race on) the same unsynchronized hash map.
+   static thread_local absl::flat_hash_map<std::string, int64_t>
+       field_count_cache;
+   if (auto it = field_count_cache.find(descriptor->full_name());
+       it != field_count_cache.end()) {
+     return it->second;
+   }
+   absl::flat_hash_set<decltype(descriptor)> parents;
+   const int64_t result = EstimatedMutationWeight(descriptor, parents);
+   // try_emplace constructs the std::string key in place, which also works
+   // when full_name() returns a string_view instead of a std::string.
+   field_count_cache.try_emplace(descriptor->full_name(), result);
+   return result;
+ }
+
+ // Counts the non-extension fields reachable from `descriptor`. A message-typed
+ // field contributes the count of its own sub-fields; a message type that is
+ // already on the current path (`parents`) contributes 0, which stops
+ // recursion. Counting stops as soon as the running total exceeds
+ // kMaxFieldCount, and the result is clamped to that value.
+ template <typename Descriptor>
+ static int EstimatedMutationWeight(
+     const Descriptor* descriptor,
+     absl::flat_hash_set<const Descriptor*>& parents) {
+   // Avoid expensive estimation operations.
+   constexpr int kMaxFieldCount = 1000;
+   // Already being counted further up the stack: treat as recursive.
+   if (!parents.insert(descriptor).second) return 0;
+   int total = 0;
+   for (const FieldDescriptor* sub_field :
+        GetProtobufFields(descriptor, /*include_extensions=*/false)) {
+     const auto* sub_message = sub_field->message_type();
+     if (sub_message == nullptr) {
+       ++total;
+     } else {
+       total += EstimatedMutationWeight(sub_message, parents);
+     }
+     if (total > kMaxFieldCount) {
+       total = kMaxFieldCount;
+       break;
+     }
+   }
+   parents.erase(descriptor);
+   return total;
+ }
+
+ // Returns true for the reserved corpus key that stores the cached mutation
+ // weight instead of a field value. Corpus entries are otherwise keyed by
+ // protobuf field number, and field numbers start at 1, so -1 cannot clash.
+ static bool IsMetadataEntry(int index) {
+   constexpr int kMutationWeightKey = -1;
+   return kMutationWeightKey == index;
+ }
+
+ // Mutation weight for a proto is the number of fields it contributes to
+ // mutation field selection. Every field accepted by ShouldMutate is counted,
+ // including all members of a oneof rather than only the one that is set. A
+ // set message field contributes the weight reported by its sub-domain; an
+ // unset one contributes EstimatedMutationWeight(field). The weight is stored
+ // in the corpus under the reserved key -1.
+ static void SetMutationWeight(corpus_type& val, int64_t field_count) {
+   val[-1] = GenericDomainCorpusType(std::in_place_type<int64_t>, field_count);
+ }
+
+ // Reads the mutation weight cached under the reserved key -1 of `val`. Fails
+ // the precondition check when that entry is missing.
+ // NOTE(review): every code path that creates a corpus_type (not only Init)
+ // must add this entry, or this check fires — confirm.
+ static int64_t GetMutationWeight(const corpus_type& val) {
+   const auto entry = val.find(-1);
+   const bool has_entry = entry != val.end();
+   FUZZTEST_INTERNAL_CHECK_PRECONDITION(has_entry, "Invalid corpus value!");
+   return entry->second.template GetAs<int64_t>();
+ }
+
+ // Adjusts the cached mutation weight of `val` by `to_add`, which may be
+ // negative (callers subtract a field's old weight before adding the new one).
+ static void AddMutationWeight(corpus_type& val, int64_t to_add) {
+   const int64_t updated_weight = GetMutationWeight(val) + to_add;
+   SetMutationWeight(val, updated_weight);
+ }
+
bool IsOneofRecursive(const OneofDescriptor* oneof,
absl::flat_hash_set<const Descriptor*>& parents,
const ProtoPolicy<Message>& policy,