Improve Init performance by avoiding recursion analysis.
The new approach works as follows. The initialization is performed up to a certain depth. However, this makes it hard for the mutator to select fields in deeper layers. To get around this, during mutation, we assign a weight to non-present proto fields based on their estimated number of sub-fields.
PiperOrigin-RevId: 712897465
diff --git a/domain_tests/arbitrary_domains_protobuf_test.cc b/domain_tests/arbitrary_domains_protobuf_test.cc
index 8ec6150..d44d1e4 100644
--- a/domain_tests/arbitrary_domains_protobuf_test.cc
+++ b/domain_tests/arbitrary_domains_protobuf_test.cc
@@ -133,7 +133,7 @@
},
Gt(0));
EXPECT_THAT(
- GenerateNonUniqueValues(Arbitrary<TestProtobufWithExtension>(), 1, 5000),
+ GenerateNonUniqueValues(Arbitrary<TestProtobufWithExtension>(), 10, 100),
AllOf(Contains(has_ext), Contains(has_rep_ext)));
}
@@ -641,25 +641,28 @@
T v;
auto corpus_v_uninitialized = domain.FromValue(v);
EXPECT_TRUE(corpus_v_uninitialized != std::nullopt);
- EXPECT_EQ(domain.CountNumberOfFields(corpus_v_uninitialized.value()), 26);
+ int fields_count = 28;
+ EXPECT_EQ(domain.CountNumberOfFields(corpus_v_uninitialized.value()),
+ fields_count + /*estimated subfields count*/ 4);
v.set_allocated_subproto(new SubT());
auto corpus_v_initizalize_one_optional_proto = domain.FromValue(v);
EXPECT_TRUE(corpus_v_initizalize_one_optional_proto != std::nullopt);
- EXPECT_EQ(domain.CountNumberOfFields(
- corpus_v_initizalize_one_optional_proto.value()),
- 28);
+ EXPECT_EQ(
+ domain.CountNumberOfFields(
+ corpus_v_initizalize_one_optional_proto.value()),
+ fields_count + /*subfields count*/ 2 + /*estimated subfields count*/ 2);
v.add_rep_subproto();
auto corpus_v_initizalize_one_repeated_proto_1 = domain.FromValue(v);
EXPECT_TRUE(corpus_v_initizalize_one_repeated_proto_1 != std::nullopt);
EXPECT_EQ(domain.CountNumberOfFields(
corpus_v_initizalize_one_repeated_proto_1.value()),
- 30);
+ fields_count + /*subfields count*/ 4);
v.add_rep_subproto();
auto corpus_v_initizalize_one_repeated_proto_2 = domain.FromValue(v);
EXPECT_TRUE(corpus_v_initizalize_one_repeated_proto_2 != std::nullopt);
EXPECT_EQ(domain.CountNumberOfFields(
corpus_v_initizalize_one_repeated_proto_2.value()),
- 32);
+ fields_count + /*subfields count*/ 6);
}
auto FieldNameHasSubstr(absl::string_view field_name) {
diff --git a/fuzztest/internal/domains/protobuf_domain_impl.h b/fuzztest/internal/domains/protobuf_domain_impl.h
index 279bf2d..b79bcbc 100644
--- a/fuzztest/internal/domains/protobuf_domain_impl.h
+++ b/fuzztest/internal/domains/protobuf_domain_impl.h
@@ -311,6 +311,10 @@
return max;
}
+ void IncrementDepth() { ++depth_; }  // One level deeper (sub-message domain).
+ void ResetDepth() { depth_ = 0; }  // Back to the top level.
+ int depth() const { return depth_; }  // Current nesting depth.
+
private:
template <typename T>
struct FilterToValue {
@@ -345,6 +349,7 @@
std::vector<FilterToValue<OptionalPolicy>> optional_policies_;
std::vector<FilterToValue<int64_t>> min_repeated_fields_sizes_;
std::vector<FilterToValue<int64_t>> max_repeated_fields_sizes_;
+ int depth_ = 0;
#define FUZZTEST_INTERNAL_POLICY_MEMBERS(Camel, cpp) \
private: \
@@ -453,6 +458,15 @@
unset_oneof_fields_ = other.unset_oneof_fields_;
}
+ // We initialize to at most depth 5 (the limit is drawn uniformly from [1, 5]).
+ // This gives a good balance between initialization speed and coverage.
+ bool IsTooDeep(absl::BitGenRef prng) {
+ constexpr int kMaxInitDepth = 5;
+ int acceptable_depth = absl::Uniform(absl::IntervalClosedClosed, prng,
+ int64_t{1}, kMaxInitDepth);
+ return policy_.depth() >= acceptable_depth;
+ }
+
corpus_type Init(absl::BitGenRef prng) {
if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
FUZZTEST_INTERNAL_CHECK(
@@ -466,23 +480,21 @@
for (const FieldDescriptor* field : GetProtobufFields(descriptor)) {
if (auto* oneof = field->containing_oneof()) {
if (!oneof_to_field.contains(oneof->index())) {
- oneof_to_field[oneof->index()] = SelectAFieldIndexInOneof(
- oneof, prng,
- /*non_recursive_only=*/customized_fields_.empty());
+ oneof_to_field[oneof->index()] =
+ SelectAFieldIndexInOneof(oneof, prng);
}
if (oneof_to_field[oneof->index()] != field->index()) continue;
- } else if (!IsRequired(field) && customized_fields_.empty() &&
- IsFieldRecursive(field)) {
- // We avoid initializing non-required recursive fields by default (if
- // they are not explicitly customized). Otherwise, the initialization
- // may never terminate. If a proto has only non-required recursive
- // fields, the initialization will be deterministic, which violates the
- // assumption on domain Init. However, such cases should be extremely
- // rare and breaking the assumption would not have severe consequences.
- continue;
+ } else if (!IsRequired(field) && customized_fields_.empty()) {
+ // We avoid initializing non-required message fields after some depth (if
+ // they are not explicitly customized). Otherwise, the initialization may
+ // never terminate.
+ if (field->message_type() != nullptr && IsTooDeep(prng)) continue;
}
VisitProtobufField(field, InitializeVisitor{prng, *this, val});
}
+ // Depth is used only for initialization. When mutating, we should start
+ // from the beginning.
+ policy_.ResetDepth();
return val;
}
@@ -598,7 +610,10 @@
if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
auto val_it = val.find(field->number());
- if (val_it == val.end()) continue;
+ if (val_it == val.end()) {
+ total_weight += EstimatedMutationWeight(field);
+ continue;
+ }
if (field->is_repeated()) {
total_weight +=
GetSubDomain<ProtoMessageTag, true>(field).CountNumberOfFields(
@@ -635,7 +650,13 @@
if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
auto val_it = val.find(field->number());
- if (val_it == val.end()) continue;
+ if (val_it == val.end()) {  // Unset message: occupies its estimated weight.
+ field_counter += EstimatedMutationWeight(field);
+ if (field_counter < selected_field_index) continue;  // Selection is later.
+ VisitProtobufField(
+ field, MutateVisitor{prng, metadata, only_shrink, *this, val});
+ return field_counter;
+ }
if (field->is_repeated()) {
field_counter +=
GetSubDomain<ProtoMessageTag, true>(field).MutateSelectedField(
@@ -933,12 +954,13 @@
template <typename OneofDescriptor>
int SelectAFieldIndexInOneof(const OneofDescriptor* oneof,
- absl::BitGenRef prng, bool non_recursive_only) {
+ absl::BitGenRef prng) {
std::vector<int> fields;
for (int i = 0; i < oneof->field_count(); ++i) {
OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i));
if (policy == OptionalPolicy::kAlwaysNull) continue;
- if (non_recursive_only && IsFieldRecursive(oneof->field(i))) continue;
+ // We avoid initializing non-required oneof fields after some depth.
+ if (policy == OptionalPolicy::kWithNull && IsTooDeep(prng)) continue;
fields.push_back(i);
}
if (fields.empty()) { // This can happen if all fields are unset.
@@ -1325,13 +1347,16 @@
return descriptor->field_count() + extensions.size();
}
- static auto GetProtobufFields(const Descriptor* descriptor) {
+ static auto GetProtobufFields(const Descriptor* descriptor,
+ bool include_extensions = true) {
std::vector<const FieldDescriptor*> fields;
fields.reserve(descriptor->field_count());
for (int i = 0; i < descriptor->field_count(); ++i) {
fields.push_back(descriptor->field(i));
}
- descriptor->file()->pool()->FindAllExtensions(descriptor, &fields);
+ if (include_extensions) {
+ descriptor->file()->pool()->FindAllExtensions(descriptor, &fields);
+ }
return fields;
}
@@ -1599,6 +1624,7 @@
field->message_type()),
use_lazy_initialization_);
result.SetPolicy(policy_);
+ result.GetPolicy().IncrementDepth();
return Domain<std::unique_ptr<Message>>(result);
} else {
return Domain<T>(ArbitraryImpl<T>());
@@ -1673,6 +1699,46 @@
/*consider_non_terminating_recursions=*/false);
}
+ // Returns an estimate of the number of sub-fields of a given field, to be
+ // used as a mutation-weight heuristic. For simplicity, recursive sub-fields
+ // are ignored. For efficiency, proto extensions are skipped, the estimate is
+ // capped at 1000, and results are cached by the message type's full name.
+ static int64_t EstimatedMutationWeight(const FieldDescriptor* field) {
+ auto descriptor = field->message_type();
+ if (!descriptor) return 1; // non-message fields
+ static absl::flat_hash_map<std::string, int> kFieldCountCache;
+ if (auto it = kFieldCountCache.find(descriptor->full_name());
+ it != kFieldCountCache.end()) {
+ return it->second;
+ }
+ absl::flat_hash_set<decltype(descriptor)> parents;
+ int result = EstimatedMutationWeight(descriptor, parents);
+ kFieldCountCache.insert({descriptor->full_name(), result});
+ return result;
+ }
+
+ template <typename Descriptor>
+ static int EstimatedMutationWeight(
+ const Descriptor* descriptor,
+ absl::flat_hash_set<const Descriptor*>& parents) {
+ int field_count = 0;
+ if (parents.contains(descriptor)) return field_count;
+ parents.insert(descriptor);
+ for (const FieldDescriptor* field :
+ GetProtobufFields(descriptor, /*include_extensions=*/false)) {
+ const auto* child = field->message_type();
+ field_count += child ? EstimatedMutationWeight(child, parents) : 1;
+ constexpr int kMaxFieldCount = 1000;
+ // Cap the estimate and stop early to avoid expensive estimation.
+ if (field_count > kMaxFieldCount) {
+ field_count = kMaxFieldCount;
+ break;
+ }
+ }
+ parents.erase(descriptor);
+ return field_count;
+ }
+
bool IsOneofRecursive(const OneofDescriptor* oneof,
absl::flat_hash_set<const Descriptor*>& parents,
const ProtoPolicy<Message>& policy,