Avoid reading from files in construction time of the protobuf domain. PiperOrigin-RevId: 792721029
diff --git a/fuzztest/internal/domains/protobuf_domain_impl.h b/fuzztest/internal/domains/protobuf_domain_impl.h index 3086e5f..159e436 100644 --- a/fuzztest/internal/domains/protobuf_domain_impl.h +++ b/fuzztest/internal/domains/protobuf_domain_impl.h
@@ -344,6 +344,31 @@ return caches_->SetFields(descriptor, GetProtobufFields(descriptor)); } + static const std::vector<const FieldDescriptor*>& GetProtobufFields( + const ProtoDescriptor* descriptor) { + ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit); + static absl::NoDestructor<absl::flat_hash_map< + const ProtoDescriptor*, std::vector<const FieldDescriptor*>>> + descriptor_to_fields ABSL_GUARDED_BY(mutex); + { + absl::MutexLock l(&mutex); + auto it = descriptor_to_fields->find(descriptor); + if (it != descriptor_to_fields->end()) return it->second; + } + std::vector<const FieldDescriptor*> fields; + fields.reserve(descriptor->field_count()); + for (int i = 0; i < descriptor->field_count(); ++i) { + fields.push_back(descriptor->field(i)); + } + absl::MutexLock l(&mutex); + if (ShouldEnumerateExtensions(descriptor)) { + descriptor->file()->pool()->FindAllExtensions(descriptor, &fields); + } + auto [it, _] = + descriptor_to_fields->insert({descriptor, std::move(fields)}); + return it->second; + } + private: static bool IsMessageSetFuzzingEnabled() { // TODO(b/413402115): Create protobuf domain API enabling MessageSet fuzzing @@ -377,31 +402,6 @@ return !IsMessageSet(descriptor); } - static const std::vector<const FieldDescriptor*>& GetProtobufFields( - const ProtoDescriptor* descriptor) { - ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit); - static absl::NoDestructor<absl::flat_hash_map< - const ProtoDescriptor*, std::vector<const FieldDescriptor*>>> - descriptor_to_fields ABSL_GUARDED_BY(mutex); - { - absl::MutexLock l(&mutex); - auto it = descriptor_to_fields->find(descriptor); - if (it != descriptor_to_fields->end()) return it->second; - } - std::vector<const FieldDescriptor*> fields; - fields.reserve(descriptor->field_count()); - for (int i = 0; i < descriptor->field_count(); ++i) { - fields.push_back(descriptor->field(i)); - } - absl::MutexLock l(&mutex); - if (ShouldEnumerateExtensions(descriptor)) { - descriptor->file()->pool()->FindAllExtensions(descriptor, &fields); - } - auto [it, _] = - descriptor_to_fields->insert({descriptor, std::move(fields)}); - return it->second; - } - // All caches for the policy that contain cached information about subfields // and can be passed down to the subfield policies recursively. class RecursiveFieldsCaches { @@ -595,7 +595,7 @@ always_set_oneofs_(), uncustomizable_oneofs_(), unset_oneof_fields_(), - fields_cache_(policy_.GetFields(prototype_.Get()->GetDescriptor())) {} + fields_cache_() {} ProtobufDomainUntypedImpl(const ProtobufDomainUntypedImpl& other) : prototype_(other.prototype_), @@ -1490,9 +1490,14 @@ return field; } - auto GetFieldCount() const { return fields_cache_.size(); } + auto GetFieldCount() const { return GetProtobufFields().size(); } const std::vector<const FieldDescriptor*>& GetProtobufFields() const { + if (fields_cache_.empty()) { + absl::MutexLock l(&mutex_); + fields_cache_ = ProtoPolicy<Message>::GetProtobufFields( + prototype_.Get()->GetDescriptor()); + } return fields_cache_; } @@ -2024,7 +2029,7 @@ absl::flat_hash_set<int> always_set_oneofs_; absl::flat_hash_set<int> uncustomizable_oneofs_; absl::flat_hash_set<int> unset_oneof_fields_; - std::vector<const FieldDescriptor*> fields_cache_; + mutable std::vector<const FieldDescriptor*> fields_cache_; }; // Domain for `T` where `T` is a Protobuf message type.