Remove synchronization from the protobuf domain and mark the domain objects as thread-unsafe.
PiperOrigin-RevId: 792730547
diff --git a/fuzztest/internal/domains/BUILD b/fuzztest/internal/domains/BUILD
index 6c02773..66c3467 100644
--- a/fuzztest/internal/domains/BUILD
+++ b/fuzztest/internal/domains/BUILD
@@ -167,7 +167,6 @@
hdrs = ["protobuf_domain_impl.h"],
deps = [
":core_domains_impl",
- "@abseil-cpp//absl/base",
"@abseil-cpp//absl/base:core_headers",
"@abseil-cpp//absl/base:no_destructor",
"@abseil-cpp//absl/container:flat_hash_map",
diff --git a/fuzztest/internal/domains/CMakeLists.txt b/fuzztest/internal/domains/CMakeLists.txt
index eac0571..6c44277 100644
--- a/fuzztest/internal/domains/CMakeLists.txt
+++ b/fuzztest/internal/domains/CMakeLists.txt
@@ -171,7 +171,6 @@
"protobuf_domain_impl.h"
DEPS
fuzztest::core_domains_impl
- absl::base
absl::core_headers
absl::no_destructor
absl::flat_hash_map
diff --git a/fuzztest/internal/domains/domain.h b/fuzztest/internal/domains/domain.h
index 801f31e..27cd3c6 100644
--- a/fuzztest/internal/domains/domain.h
+++ b/fuzztest/internal/domains/domain.h
@@ -65,7 +65,9 @@
// `Domain<T>` is the type-erased domain interface.
//
// It can be constructed from any object derived from `DomainBase` that
-// implements the domain methods for the value type `T`.
+// implements the domain methods for the value type `T`. A Domain object is not
+// thread-safe. It's the domain object owner's responsibility to make sure the
+// domain object is not accessed concurrently by multiple threads.
template <typename T>
class Domain {
public:
diff --git a/fuzztest/internal/domains/protobuf_domain_impl.h b/fuzztest/internal/domains/protobuf_domain_impl.h
index cf0566c..aac4f14 100644
--- a/fuzztest/internal/domains/protobuf_domain_impl.h
+++ b/fuzztest/internal/domains/protobuf_domain_impl.h
@@ -27,7 +27,6 @@
#include <vector>
#include "absl/base/attributes.h"
-#include "absl/base/call_once.h"
#include "absl/base/const_init.h"
#include "absl/base/no_destructor.h"
#include "absl/base/thread_annotations.h"
@@ -349,12 +348,13 @@
const ProtoDescriptor* descriptor) {
ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit);
static absl::NoDestructor<absl::flat_hash_map<
- const ProtoDescriptor*, std::vector<const FieldDescriptor*>>>
+ const ProtoDescriptor*,
+ std::unique_ptr<std::vector<const FieldDescriptor*>>>>
descriptor_to_fields ABSL_GUARDED_BY(mutex);
{
absl::MutexLock l(&mutex);
auto it = descriptor_to_fields->find(descriptor);
- if (it != descriptor_to_fields->end()) return it->second;
+ if (it != descriptor_to_fields->end()) return *(it->second);
}
std::vector<const FieldDescriptor*> fields;
fields.reserve(descriptor->field_count());
@@ -365,9 +365,10 @@
if (ShouldEnumerateExtensions(descriptor)) {
descriptor->file()->pool()->FindAllExtensions(descriptor, &fields);
}
- auto [it, _] =
- descriptor_to_fields->insert({descriptor, std::move(fields)});
- return it->second;
+ auto [it, _] = descriptor_to_fields->insert(
+ {descriptor, std::make_unique<std::vector<const FieldDescriptor*>>(
+ std::move(fields))});
+ return *(it->second);
}
private:
@@ -408,12 +409,10 @@
class RecursiveFieldsCaches {
public:
void SetIsFieldFinitelyRecursive(const FieldDescriptor* field, bool value) {
- absl::MutexLock l(&field_to_is_finitely_recursive_mutex_);
field_to_is_finitely_recursive_.insert({field, value});
}
std::optional<bool> IsFieldFinitelyRecursive(const FieldDescriptor* field) {
- absl::ReaderMutexLock l(&field_to_is_finitely_recursive_mutex_);
auto it = field_to_is_finitely_recursive_.find(field);
return it != field_to_is_finitely_recursive_.end()
? std::optional(it->second)
@@ -422,13 +421,11 @@
void SetIsFieldInfinitelyRecursive(const FieldDescriptor* field,
bool value) {
- absl::MutexLock l(&field_to_is_infinitely_recursive_mutex_);
field_to_is_infinitely_recursive_.insert({field, value});
}
std::optional<bool> IsFieldInfinitelyRecursive(
const FieldDescriptor* field) {
- absl::ReaderMutexLock l(&field_to_is_infinitely_recursive_mutex_);
auto it = field_to_is_infinitely_recursive_.find(field);
return it != field_to_is_infinitely_recursive_.end()
? std::optional(it->second)
@@ -437,32 +434,27 @@
const std::vector<const FieldDescriptor*>* GetFields(
const ProtoDescriptor* descriptor) {
- absl::ReaderMutexLock l(&proto_to_fields_mutex_);
auto it = proto_to_fields_.find(descriptor);
- return it != proto_to_fields_.end() ? &it->second : nullptr;
+ return it != proto_to_fields_.end() ? it->second.get() : nullptr;
}
const std::vector<const FieldDescriptor*>& SetFields(
const ProtoDescriptor* descriptor,
std::vector<const FieldDescriptor*> fields) {
- absl::MutexLock l(&proto_to_fields_mutex_);
- auto [it, _] = proto_to_fields_.insert({descriptor, std::move(fields)});
- return it->second;
+ auto [it, _] = proto_to_fields_.insert(
+ {descriptor, std::make_unique<std::vector<const FieldDescriptor*>>(
+ std::move(fields))});
+ return *it->second;
}
private:
- absl::Mutex field_to_is_finitely_recursive_mutex_;
absl::flat_hash_map<const FieldDescriptor*, bool>
- field_to_is_finitely_recursive_
- ABSL_GUARDED_BY(field_to_is_finitely_recursive_mutex_);
- absl::Mutex field_to_is_infinitely_recursive_mutex_;
+ field_to_is_finitely_recursive_;
absl::flat_hash_map<const FieldDescriptor*, bool>
- field_to_is_infinitely_recursive_
- ABSL_GUARDED_BY(field_to_is_infinitely_recursive_mutex_);
- absl::Mutex proto_to_fields_mutex_;
+ field_to_is_infinitely_recursive_;
absl::flat_hash_map<const ProtoDescriptor*,
- std::vector<const FieldDescriptor*>>
- proto_to_fields_ ABSL_GUARDED_BY(proto_to_fields_mutex_);
+ std::unique_ptr<std::vector<const FieldDescriptor*>>>
+ proto_to_fields_;
};
std::shared_ptr<RecursiveFieldsCaches> caches_ = nullptr;
@@ -601,7 +593,6 @@
ProtobufDomainUntypedImpl(const ProtobufDomainUntypedImpl& other)
: prototype_(other.prototype_),
use_lazy_initialization_(other.use_lazy_initialization_) {
- absl::MutexLock l(&other.mutex_);
domains_ = other.domains_;
policy_ = other.policy_;
customized_fields_ = other.customized_fields_;
@@ -1432,7 +1423,6 @@
"` but the field needs a message of type `",
field->message_type()->full_name(), "`.");
}
- absl::MutexLock l(&self.mutex_);
auto res = self.domains_.try_emplace(field->number(),
std::in_place_type<DomainT>,
std::forward<Inner>(domain));
@@ -1493,11 +1483,11 @@
auto GetFieldCount() const { return GetProtobufFields().size(); }
const std::vector<const FieldDescriptor*>& GetProtobufFields() const {
- absl::call_once(fields_cache_once_, [this] {
+ if (!fields_cache_.has_value()) {
fields_cache_ = ProtoPolicy<Message>::GetProtobufFields(
prototype_.Get()->GetDescriptor());
- });
- return fields_cache_;
+ };
+ return *fields_cache_;
}
static auto GetFieldName(const FieldDescriptor* field) {
@@ -1694,7 +1684,6 @@
using DomainT = decltype(GetDefaultDomainForField<T, is_repeated>(field));
// Do the operation under a lock to prevent race conditions in `const`
// methods.
- absl::MutexLock l(&mutex_);
auto it = domains_.find(field->number());
if (it == domains_.end()) {
it = domains_
@@ -2019,19 +2008,14 @@
PrototypePtr<Message> prototype_;
bool use_lazy_initialization_;
- mutable absl::Mutex mutex_;
- mutable absl::flat_hash_map<int, CopyableAny> domains_
- ABSL_GUARDED_BY(mutex_);
+ mutable absl::flat_hash_map<int, CopyableAny> domains_;
ProtoPolicy<Message> policy_;
absl::flat_hash_set<int> customized_fields_;
absl::flat_hash_set<int> always_set_oneofs_;
absl::flat_hash_set<int> uncustomizable_oneofs_;
absl::flat_hash_set<int> unset_oneof_fields_;
-
- // Never access the field directly, always use `GetProtobufFields()`.
- mutable std::vector<const FieldDescriptor*> fields_cache_;
- mutable absl::once_flag fields_cache_once_;
+ mutable std::optional<std::vector<const FieldDescriptor*>> fields_cache_;
};
// Domain for `T` where `T` is a Protobuf message type.