blob: e452d9614b5df46781cc8bb8989827243d1bdeab [file] [log] [blame]
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_PROTOBUF_DOMAIN_IMPL_H_
#define FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_PROTOBUF_DOMAIN_IMPL_H_
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/random/bit_gen_ref.h"
#include "absl/random/random.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "./fuzztest/domain_core.h"
#include "./fuzztest/internal/any.h"
#include "./fuzztest/internal/domains/arbitrary_impl.h"
#include "./fuzztest/internal/domains/container_of_impl.h"
#include "./fuzztest/internal/domains/domain_base.h"
#include "./fuzztest/internal/domains/element_of_impl.h"
#include "./fuzztest/internal/domains/in_range_impl.h"
#include "./fuzztest/internal/domains/map_impl.h"
#include "./fuzztest/internal/domains/optional_of_impl.h"
#include "./fuzztest/internal/logging.h"
#include "./fuzztest/internal/meta.h"
#include "./fuzztest/internal/serialization.h"
#include "./fuzztest/internal/type_support.h"
namespace google::protobuf {
class EnumDescriptor;
template <typename E>
const EnumDescriptor* GetEnumDescriptor();
} // namespace google::protobuf
namespace fuzztest::internal {
// Sniff the API to get the types we need without naming them directly.
// This allows for a soft dependency on proto without having to #include its
// headers.
//
// Each alias below is derived from the return type of an accessor called on
// `Message`; `std::remove_pointer_t` strips the pointer, and any
// cv-qualification of the pointee is kept as-is.

// Pointee type of `Message::GetReflection()`.
template <typename Message>
using ProtobufReflection =
std::remove_pointer_t<decltype(std::declval<Message>().GetReflection())>;
// Pointee type of `Message::GetDescriptor()`.
template <typename Message>
using ProtobufDescriptor =
std::remove_pointer_t<decltype(std::declval<Message>().GetDescriptor())>;
// Pointee type of `Message::GetDescriptor()->field(0)`.
template <typename Message>
using ProtobufFieldDescriptor = std::remove_pointer_t<
decltype(std::declval<Message>().GetDescriptor()->field(0))>;
// Pointee type of `Message::GetDescriptor()->oneof_decl(0)`.
template <typename Message>
using ProtobufOneofDescriptor = std::remove_pointer_t<
decltype(std::declval<Message>().GetDescriptor()->oneof_decl(0))>;
// Typed accessor for a single field of `Message`, going through the message's
// reflection object. The primary template is only declared; one
// specialization per supported value type `V` is generated by the macro below.
template <typename Message, typename V>
class ProtocolBufferAccess;
// Generates the `ProtocolBufferAccess<Message, cpp_type>` specialization.
// `Camel` is the type suffix of the reflection methods that are forwarded to:
//   Get##Camel / Set##Camel                  - singular field access
//   GetRepeated##Camel / SetRepeated##Camel  - indexed repeated access
//   Add##Camel                               - append to a repeated field
// The getters are static and take the message explicitly; the setters operate
// on the message/field pair captured by the constructor.
#define FUZZTEST_INTERNAL_PROTO_ACCESS_(cpp_type, Camel) \
template <typename Message> \
class ProtocolBufferAccess<Message, cpp_type> { \
public: \
using Reflection = ProtobufReflection<Message>; \
using FieldDescriptor = ProtobufFieldDescriptor<Message>; \
using value_type = cpp_type; \
\
ProtocolBufferAccess(Message* message, const FieldDescriptor* field) \
: message_(message), \
reflection_(message->GetReflection()), \
field_(field) {} \
\
static auto GetField(const Message& message, \
const FieldDescriptor* field) { \
return message.GetReflection()->Get##Camel(message, field); \
} \
void SetField(const cpp_type& value) { \
return reflection_->Set##Camel(message_, field_, value); \
} \
static auto GetRepeatedField(const Message& message, \
const FieldDescriptor* field, int index) { \
return message.GetReflection()->GetRepeated##Camel(message, field, \
index); \
} \
void SetRepeatedField(int index, const cpp_type& value) { \
return reflection_->SetRepeated##Camel(message_, field_, index, value); \
} \
void AddRepeatedField(const cpp_type& value) { \
return reflection_->Add##Camel(message_, field_, value); \
} \
\
private: \
Message* message_; \
Reflection* reflection_; \
const FieldDescriptor* field_; \
}
// One specialization per scalar/string value type handled by the accessor.
FUZZTEST_INTERNAL_PROTO_ACCESS_(int32_t, Int32);
FUZZTEST_INTERNAL_PROTO_ACCESS_(uint32_t, UInt32);
FUZZTEST_INTERNAL_PROTO_ACCESS_(int64_t, Int64);
FUZZTEST_INTERNAL_PROTO_ACCESS_(uint64_t, UInt64);
FUZZTEST_INTERNAL_PROTO_ACCESS_(float, Float);
FUZZTEST_INTERNAL_PROTO_ACCESS_(double, Double);
FUZZTEST_INTERNAL_PROTO_ACCESS_(std::string, String);
FUZZTEST_INTERNAL_PROTO_ACCESS_(bool, Bool);
#undef FUZZTEST_INTERNAL_PROTO_ACCESS_
// Tag types standing in for "enum field" and "message field" in the
// dispatch below. They are only declared here.
struct ProtoEnumTag;
struct ProtoMessageTag;

// Turns the runtime (cardinality, cpp_type) pair of `field` into a statically
// typed call on `visitor`:
//   visitor.VisitSingular<type>(field)  // non-repeated fields
//   visitor.VisitRepeated<type>(field)  // repeated fields
// The `type` argument is chosen as follows:
//   - bool, integral, floating point and string fields: the C++ value type.
//   - enum fields: the tag type ProtoEnumTag.
//   - message fields: the tag type ProtoMessageTag.
// Whatever the selected visitor method returns is handed back to the caller.
template <typename FieldDescriptor, typename Visitor>
auto VisitProtobufField(const FieldDescriptor* field, Visitor visitor) {
  if (!field->is_repeated()) {
    // Singular field: dispatch on the C++ type only.
    switch (field->cpp_type()) {
      case FieldDescriptor::CPPTYPE_BOOL:
        return visitor.template VisitSingular<bool>(field);
      case FieldDescriptor::CPPTYPE_INT32:
        return visitor.template VisitSingular<int32_t>(field);
      case FieldDescriptor::CPPTYPE_UINT32:
        return visitor.template VisitSingular<uint32_t>(field);
      case FieldDescriptor::CPPTYPE_INT64:
        return visitor.template VisitSingular<int64_t>(field);
      case FieldDescriptor::CPPTYPE_UINT64:
        return visitor.template VisitSingular<uint64_t>(field);
      case FieldDescriptor::CPPTYPE_FLOAT:
        return visitor.template VisitSingular<float>(field);
      case FieldDescriptor::CPPTYPE_DOUBLE:
        return visitor.template VisitSingular<double>(field);
      case FieldDescriptor::CPPTYPE_STRING:
        return visitor.template VisitSingular<std::string>(field);
      case FieldDescriptor::CPPTYPE_ENUM:
        return visitor.template VisitSingular<ProtoEnumTag>(field);
      case FieldDescriptor::CPPTYPE_MESSAGE:
        return visitor.template VisitSingular<ProtoMessageTag>(field);
    }
  } else {
    // Repeated field: same type mapping, repeated entry point.
    switch (field->cpp_type()) {
      case FieldDescriptor::CPPTYPE_BOOL:
        return visitor.template VisitRepeated<bool>(field);
      case FieldDescriptor::CPPTYPE_INT32:
        return visitor.template VisitRepeated<int32_t>(field);
      case FieldDescriptor::CPPTYPE_UINT32:
        return visitor.template VisitRepeated<uint32_t>(field);
      case FieldDescriptor::CPPTYPE_INT64:
        return visitor.template VisitRepeated<int64_t>(field);
      case FieldDescriptor::CPPTYPE_UINT64:
        return visitor.template VisitRepeated<uint64_t>(field);
      case FieldDescriptor::CPPTYPE_FLOAT:
        return visitor.template VisitRepeated<float>(field);
      case FieldDescriptor::CPPTYPE_DOUBLE:
        return visitor.template VisitRepeated<double>(field);
      case FieldDescriptor::CPPTYPE_STRING:
        return visitor.template VisitRepeated<std::string>(field);
      case FieldDescriptor::CPPTYPE_ENUM:
        return visitor.template VisitRepeated<ProtoEnumTag>(field);
      case FieldDescriptor::CPPTYPE_MESSAGE:
        return visitor.template VisitRepeated<ProtoMessageTag>(field);
    }
  }
}
// Resolves field `number` of `prototype` to its field descriptor.
// The message descriptor is consulted first; if it does not know the number,
// the reflection object is asked for a known extension with that number.
// The result of that second lookup is returned as-is (it may be null).
template <typename Message>
auto GetProtobufField(const Message* prototype, int number) {
  auto* result = prototype->GetDescriptor()->FindFieldByNumber(number);
  const bool is_regular_field = result != nullptr;
  if (!is_regular_field) {
    result = prototype->GetReflection()->FindKnownExtensionByNumber(number);
  }
  return result;
}
// A boolean test over a `const T*` (used with field descriptors).
template <typename T>
using Predicate = std::function<bool(const T*)>;

// Predicate that accepts every field.
template <typename T>
Predicate<T> IncludeAll() {
  return [](const T* /*field*/) -> bool { return true; };
}

// Predicate that accepts exactly the fields reporting `is_optional()`.
template <typename T>
Predicate<T> IsOptional() {
  return [](const T* candidate) -> bool { return candidate->is_optional(); };
}

// Predicate that accepts exactly the fields reporting `is_repeated()`.
template <typename T>
Predicate<T> IsRepeated() {
  return [](const T* candidate) -> bool { return candidate->is_repeated(); };
}

// Conjunction of two predicates. `rhs` is only evaluated when `lhs` accepts
// the field.
template <typename T>
Predicate<T> And(Predicate<T> lhs, Predicate<T> rhs) {
  return [first = std::move(lhs),
          second = std::move(rhs)](const T* field) -> bool {
    if (!first(field)) return false;
    return second(field);
  };
}
// Returns a domain transformer that hands back the domain it is given,
// unchanged.
template <typename T>
std::function<Domain<T>(Domain<T>)> Identity() {
  return [](Domain<T> unchanged) -> Domain<T> { return unchanged; };
}
// Collection of filter -> value rules that customize how fields of `Message`
// are handled: optional-ness, repeated field size bounds, and per-type default
// domains / domain transformers.
//
// Every setter appends a rule; lookups scan the rules from the most recently
// added to the oldest and return the value of the first rule whose filter
// accepts the field. Later rules therefore override earlier ones.
template <typename Message>
class ProtoPolicy {
  using FieldDescriptor = ProtobufFieldDescriptor<Message>;
  using Filter = std::function<bool(const FieldDescriptor*)>;

 public:
  // Starts with a single catch-all rule: every field is `kWithNull`.
  ProtoPolicy()
      : optional_policies_({{.filter = IncludeAll<FieldDescriptor>(),
                             .value = OptionalPolicy::kWithNull}}) {}

  // Applies `optional_policy` to all fields.
  void SetOptionalPolicy(OptionalPolicy optional_policy) {
    SetOptionalPolicy(IncludeAll<FieldDescriptor>(), optional_policy);
  }

  // Applies `optional_policy` to the fields accepted by `filter`.
  // For repeated fields accepted by `filter` the policy is additionally
  // mirrored as a size bound: `kAlwaysNull` adds a maximum size of 0 and
  // `kWithoutNull` adds a minimum size of 1.
  void SetOptionalPolicy(Filter filter, OptionalPolicy optional_policy) {
    if (optional_policy == OptionalPolicy::kAlwaysNull) {
      SetMaxRepeatedFieldsSize(And(IsRepeated<FieldDescriptor>(), filter), 0);
    } else if (optional_policy == OptionalPolicy::kWithoutNull) {
      SetMinRepeatedFieldsSize(And(IsRepeated<FieldDescriptor>(), filter), 1);
    }
    optional_policies_.push_back(
        {.filter = std::move(filter), .value = optional_policy});
  }

  // Adds a minimum-size rule for all fields.
  void SetMinRepeatedFieldsSize(int64_t min_size) {
    SetMinRepeatedFieldsSize(IncludeAll<FieldDescriptor>(), min_size);
  }

  // Adds a minimum-size rule for the fields accepted by `filter`.
  void SetMinRepeatedFieldsSize(Filter filter, int64_t min_size) {
    min_repeated_fields_sizes_.push_back(
        {.filter = std::move(filter), .value = min_size});
  }

  // Adds a maximum-size rule for all fields.
  void SetMaxRepeatedFieldsSize(int64_t max_size) {
    SetMaxRepeatedFieldsSize(IncludeAll<FieldDescriptor>(), max_size);
  }

  // Adds a maximum-size rule for the fields accepted by `filter`.
  void SetMaxRepeatedFieldsSize(Filter filter, int64_t max_size) {
    max_repeated_fields_sizes_.push_back(
        {.filter = std::move(filter), .value = max_size});
  }

  // Returns the optional policy in effect for `field`. Checks that `field` is
  // optional and that some rule matches it.
  OptionalPolicy GetOptionalPolicy(const FieldDescriptor* field) const {
    FUZZTEST_INTERNAL_CHECK(
        field->is_optional(),
        "GetOptionalPolicy should apply to optional fields only!");
    std::optional<OptionalPolicy> result =
        GetPolicyValue(optional_policies_, field);
    FUZZTEST_INTERNAL_CHECK(result.has_value(), "optional policy is not set!");
    return *result;
  }

  // Returns the minimum size in effect for the repeated `field`, or nullopt
  // when no rule matches. Checks that `field` is repeated.
  std::optional<int64_t> GetMinRepeatedFieldSize(
      const FieldDescriptor* field) const {
    FUZZTEST_INTERNAL_CHECK(
        field->is_repeated(),
        "GetMinRepeatedFieldSize should apply to repeated fields only!");
    return GetPolicyValue(min_repeated_fields_sizes_, field);
  }

  // Returns the maximum size in effect for the repeated `field`, or nullopt
  // when no rule matches. Checks that `field` is repeated.
  std::optional<int64_t> GetMaxRepeatedFieldSize(
      const FieldDescriptor* field) const {
    FUZZTEST_INTERNAL_CHECK(
        field->is_repeated(),
        "GetMaxRepeatedFieldSize should apply to repeated fields only!");
    return GetPolicyValue(max_repeated_fields_sizes_, field);
  }

 private:
  // One rule: `value` applies to every field accepted by `filter`.
  template <typename T>
  struct FilterToValue {
    Filter filter;
    T value;
  };

  // Returns the value of the newest rule in `filter_to_values` whose filter
  // accepts `field`, or nullopt if there is none.
  template <typename T>
  std::optional<T> GetPolicyValue(
      const std::vector<FilterToValue<T>>& filter_to_values,
      const FieldDescriptor* field) const {
    // Newest rules sit at the back, so walk the list in reverse.
    for (auto rule = filter_to_values.rbegin();
         rule != filter_to_values.rend(); ++rule) {
      if (!rule->filter(field)) continue;
      if constexpr (std::is_same_v<T, Domain<std::unique_ptr<Message>>>) {
        // For message domains, generate one sample value and verify that the
        // message type it produces is the one the field expects.
        absl::BitGen gen;
        auto domain = rule->value;
        auto obj = domain.GetValue(domain.Init(gen));
        auto* descriptor = obj->GetDescriptor();
        FUZZTEST_INTERNAL_CHECK_PRECONDITION(
            descriptor->full_name() == field->message_type()->full_name(),
            "Input domain does not match the expected message type. The "
            "domain produced a message of type `",
            descriptor->full_name(),
            "` but the field needs a message of type `",
            field->message_type()->full_name(), "`.");
      }
      return rule->value;
    }
    return std::nullopt;
  }

  std::vector<FilterToValue<OptionalPolicy>> optional_policies_;
  std::vector<FilterToValue<int64_t>> min_repeated_fields_sizes_;
  std::vector<FilterToValue<int64_t>> max_repeated_fields_sizes_;

// For each value kind this generates two rule lists (default domains and
// domain transformers) together with their setters and getters:
//   SetDefaultDomainFor<Camel>s, SetDomainTransformerFor<Camel>s,
//   GetDefaultDomainFor<Camel>s, GetDomainTransformerFor<Camel>s.
#define FUZZTEST_INTERNAL_POLICY_MEMBERS(Camel, cpp)                           \
 private:                                                                      \
  std::vector<FilterToValue<Domain<cpp>>> domains_for_##Camel##_;              \
  std::vector<FilterToValue<std::function<Domain<cpp>(Domain<cpp>)>>>          \
      transformers_for_##Camel##_;                                             \
                                                                               \
 public:                                                                       \
  void SetDefaultDomainFor##Camel##s(                                          \
      Domain<MakeDependentType<cpp, Message>> domain) {                        \
    domains_for_##Camel##_.push_back({.filter = IncludeAll<FieldDescriptor>(), \
                                      .value = std::move(domain)});            \
  }                                                                            \
  void SetDefaultDomainFor##Camel##s(                                          \
      Filter filter, Domain<MakeDependentType<cpp, Message>> domain) {         \
    domains_for_##Camel##_.push_back(                                          \
        {.filter = std::move(filter), .value = std::move(domain)});            \
  }                                                                            \
  void SetDomainTransformerFor##Camel##s(                                      \
      Filter filter, std::function<Domain<MakeDependentType<cpp, Message>>(    \
                         Domain<MakeDependentType<cpp, Message>>)>             \
                         transformer) {                                        \
    transformers_for_##Camel##_.push_back(                                     \
        {.filter = std::move(filter), .value = std::move(transformer)});       \
  }                                                                            \
  std::optional<Domain<MakeDependentType<cpp, Message>>>                       \
      GetDefaultDomainFor##Camel##s(const FieldDescriptor* field) const {      \
    return GetPolicyValue(domains_for_##Camel##_, field);                      \
  }                                                                            \
  std::optional<std::function<Domain<MakeDependentType<cpp, Message>>(         \
      Domain<MakeDependentType<cpp, Message>>)>>                               \
      GetDomainTransformerFor##Camel##s(const FieldDescriptor* field) const {  \
    return GetPolicyValue(transformers_for_##Camel##_, field);                 \
  }

  FUZZTEST_INTERNAL_POLICY_MEMBERS(Bool, bool)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(Int32, int32_t)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(UInt32, uint32_t)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(Int64, int64_t)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(UInt64, uint64_t)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(Float, float)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(Double, double)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(String, std::string)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(Enum, int)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(Protobuf, std::unique_ptr<Message>)
};
// Handle to a prototype object that is either supplied directly or produced
// on demand by a factory. With a factory, the factory is invoked the first
// time `Get()` is called on this handle and the resulting pointer is cached.
// Copies made before that first call carry the factory and an empty cache.
template <typename Prototype>
class PrototypePtr {
 public:
  // Defers creation: `prototype_factory` runs on the first `Get()`.
  PrototypePtr(std::function<const Prototype*()> prototype_factory)
      : prototype_factory_(std::move(prototype_factory)) {}

  // Wraps an existing prototype, which must not be null.
  PrototypePtr(const Prototype* prototype) : prototype_(prototype) {
    FUZZTEST_INTERNAL_CHECK_PRECONDITION(prototype != nullptr,
                                         "Prototype should not be nullptr");
  }

  PrototypePtr(const PrototypePtr<Prototype>& other) = default;
  PrototypePtr& operator=(const PrototypePtr<Prototype>& other) = default;

  // Returns the prototype, invoking the factory first if nothing is cached.
  const Prototype* Get() const {
    if (prototype_ == nullptr) {
      prototype_ = prototype_factory_();
    }
    return prototype_;
  }

 private:
  std::function<const Prototype*()> prototype_factory_;
  // Cache filled by `Get()`; mutable so that `Get()` can stay const.
  mutable const Prototype* prototype_ = nullptr;
};
// Domain for std::unique_ptr<Message>, where the prototype is accepted as a
// constructor argument.
template <typename Message>
class ProtobufDomainUntypedImpl
: public DomainBase<ProtobufDomainUntypedImpl<Message>,
std::unique_ptr<Message>,
absl::flat_hash_map<int, GenericDomainCorpusType>> {
using Descriptor = ProtobufDescriptor<Message>;
using FieldDescriptor = ProtobufFieldDescriptor<Message>;
using OneofDescriptor = ProtobufOneofDescriptor<Message>;
public:
using typename ProtobufDomainUntypedImpl::DomainBase::corpus_type;
using typename ProtobufDomainUntypedImpl::DomainBase::value_type;
// Creates a domain for messages of the type described by `prototype`.
// `use_lazy_initialization` is stored as-is; the policy, the customized-field
// set and the oneof bookkeeping members all start value-initialized.
explicit ProtobufDomainUntypedImpl(PrototypePtr<Message> prototype,
bool use_lazy_initialization)
: prototype_(std::move(prototype)),
use_lazy_initialization_(use_lazy_initialization),
policy_(),
customized_fields_(),
always_set_oneofs_(),
uncustomizable_oneofs_(),
oneof_fields_policies_() {}
// Copy constructor. The prototype handle and the lazy-initialization flag are
// copied in the initializer list; everything else is copied in the body while
// holding `other.mutex_`.
ProtobufDomainUntypedImpl(const ProtobufDomainUntypedImpl& other)
    : prototype_(other.prototype_),
      use_lazy_initialization_(other.use_lazy_initialization_) {
  absl::MutexLock source_lock(&other.mutex_);
  domains_ = other.domains_;
  policy_ = other.policy_;
  customized_fields_ = other.customized_fields_;
  always_set_oneofs_ = other.always_set_oneofs_;
  uncustomizable_oneofs_ = other.uncustomizable_oneofs_;
  oneof_fields_policies_ = other.oneof_fields_policies_;
}
// Stores a freshly initialized corpus value for the singular `field` in
// `val`, using the field's sub-domain from `self`. If `field` is a member of
// a oneof, the corpus entries of all other members of that oneof are removed.
template <typename T>
static void InitializeFieldValue(absl::BitGenRef prng,
                                 const ProtobufDomainUntypedImpl& self,
                                 const FieldDescriptor* field,
                                 corpus_type& val) {
  auto& field_domain = self.GetSubDomain<T, false>(field);
  val[field->number()] = field_domain.Init(prng);

  auto* oneof = field->containing_oneof();
  if (!oneof) return;

  // Clear the other parts of the oneof. They are unnecessary to
  // have and mutating them would have no effect.
  for (int i = 0; i < oneof->field_count(); ++i) {
    if (i == field->index_in_oneof()) continue;
    val.erase(oneof->field(i)->number());
  }
}
// Visitor for `VisitProtobufField` that writes an initial corpus value for the
// visited field into `val`.
struct InitializeVisitor {
  absl::BitGenRef prng;
  ProtobufDomainUntypedImpl& self;
  corpus_type& val;

  // Singular fields go through `InitializeFieldValue`, which also takes care
  // of oneof siblings.
  template <typename T>
  void VisitSingular(const FieldDescriptor* field) {
    InitializeFieldValue<T>(prng, self, field, val);
  }

  // Repeated fields are initialized directly from their sub-domain.
  template <typename T>
  void VisitRepeated(const FieldDescriptor* field) {
    auto& repeated_domain = self.GetSubDomain<T, true>(field);
    val[field->number()] = repeated_domain.Init(prng);
  }
};
template <typename OneofDescriptor>
int SelectAFieldIndexInOneof(const OneofDescriptor* oneof,
absl::BitGenRef prng, bool non_recursive_only) {
std::vector<int> fields;
for (int i = 0; i < oneof->field_count(); ++i) {
OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i));
if (policy == OptionalPolicy::kAlwaysNull) continue;
if (non_recursive_only && IsFieldRecursive(oneof->field(i))) continue;
fields.push_back(i);
}
if (fields.empty()) { // This can happen if all fields are unset.
return -1;
}
uint64_t selected =
absl::Uniform(absl::IntervalClosedOpen, prng, size_t{0}, fields.size());
return oneof->field(fields[selected])->index();
}
// For every oneof of `descriptor` that is listed in `always_set_oneofs_`,
// switches each member whose oneof policy is currently `kWithNull` to
// `kWithoutNull`. Members with any other policy are left untouched.
void SetOneofFieldsPoliciesToWithoutNullWhereNeeded(
    const ProtobufDescriptor<Message>* descriptor) {
  for (int i = 0; i < descriptor->oneof_decl_count(); ++i) {
    auto* oneof = descriptor->oneof_decl(i);
    if (always_set_oneofs_.contains(oneof->index())) {
      for (int j = 0; j < oneof->field_count(); ++j) {
        if (GetOneofFieldPolicy(oneof->field(j)) != OptionalPolicy::kWithNull) {
          continue;
        }
        SetOneofFieldPolicy(oneof->field(j), OptionalPolicy::kWithoutNull);
      }
    }
  }
}
// Produces an initial corpus value.
// If `MaybeGetRandomSeed` yields a seed, that seed is returned. Otherwise each
// field of the message is visited and initialized, except that:
//   - for each oneof, only the one member chosen by
//     `SelectAFieldIndexInOneof` is initialized;
//   - non-required recursive fields are skipped while no field has been
//     customized.
corpus_type Init(absl::BitGenRef prng) {
  if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
  FUZZTEST_INTERNAL_CHECK(
      !customized_fields_.empty() || !IsNonTerminatingRecursive(),
      "Cannot set recursive fields by default.");
  const auto* descriptor = prototype_.Get()->GetDescriptor();
  SetOneofFieldsPoliciesToWithoutNullWhereNeeded(descriptor);

  corpus_type val;
  // oneof index -> index of the member selected for it (or -1 for "none").
  absl::flat_hash_map<int, int> oneof_to_field;
  // TODO(b/241124202): Use a valid proto with minimum size.
  for (int i = 0; i < descriptor->field_count(); ++i) {
    const auto* field = descriptor->field(i);
    if (auto* oneof = field->containing_oneof()) {
      // The selection is made once per oneof, when its first member is seen.
      auto chosen = oneof_to_field.find(oneof->index());
      if (chosen == oneof_to_field.end()) {
        const int selected = SelectAFieldIndexInOneof(
            oneof, prng,
            /*non_recursive_only=*/customized_fields_.empty());
        chosen = oneof_to_field.emplace(oneof->index(), selected).first;
      }
      if (chosen->second != field->index()) continue;
    } else if (!IsRequired(field) && customized_fields_.empty() &&
               IsFieldRecursive(field)) {
      // We avoid initializing non-required recursive fields by default (if
      // they are not explicitly customized). Otherwise, the initialization
      // may never terminate. If a proto has only non-required recursive
      // fields, the initialization will be deterministic, which violates the
      // assumption on domain Init. However, such cases should be extremely
      // rare and breaking the assumption would not have severe consequences.
      continue;
    }
    VisitProtobufField(field, InitializeVisitor{prng, *this, val});
  }
  return val;
}
// Visitor for `VisitProtobufField` that mutates the corpus entry of the
// visited field inside `val`. Absent entries are created (unless
// `only_shrink` is set); present entries are mutated through the field's
// sub-domain.
struct MutateVisitor {
  absl::BitGenRef prng;
  bool only_shrink;
  ProtobufDomainUntypedImpl& self;
  corpus_type& val;

  template <typename T>
  void VisitSingular(const FieldDescriptor* field) {
    auto& domain = self.GetSubDomain<T, false>(field);
    auto entry = val.find(field->number());
    if (entry == val.end()) {
      // Nothing to mutate yet: add the element, but never while shrinking.
      if (!only_shrink) {
        InitializeFieldValue<T>(prng, self, field, val);
      }
      return;
    }
    // Mutate the element
    domain.Mutate(entry->second, prng, only_shrink);
  }

  template <typename T>
  void VisitRepeated(const FieldDescriptor* field) {
    auto& domain = self.GetSubDomain<T, true>(field);
    auto entry = val.find(field->number());
    if (entry == val.end()) {
      if (!only_shrink) {
        val[field->number()] = domain.Init(prng);
      }
      return;
    }
    if (!field->is_map()) {
      domain.Mutate(entry->second, prng, only_shrink);
      return;
    }
    // field of synthetic messages of the form:
    //
    // message {
    //   optional key_type key = 1;
    //   optional value_type value = 2;
    // }
    //
    // The generic code is not doing dedup and it would only happen some
    // time after GetValue when the map field is synchronized between
    // reflection and codegen. Let's do it eagerly to drop dead entries so
    // that we don't keep mutating them later.
    //
    // TODO(b/231212420): Improve mutation to not add duplicate keys on the
    // first place. The current hack is very inefficient.
    // Switch the inner domain for maps to use flat_hash_map instead.
    //
    // The mutation is applied to a copy held in a scratch corpus that
    // contains only this field.
    corpus_type corpus_copy;
    auto& copy = corpus_copy[field->number()] = entry->second;
    domain.Mutate(copy, prng, only_shrink);
    auto v = self.GetValue(corpus_copy);
    // We need to roundtrip through serialization to really dedup. The
    // reflection API alone doesn't cut it.
    v->ParsePartialFromString(v->SerializePartialAsString());
    if (v->GetReflection()->FieldSize(*v, field) ==
        domain.GetValue(copy).size()) {
      // The number of entries is the same, so accept the change.
      entry->second = std::move(copy);
    }
  }
};
// Returns the mutation weight of `val`: one unit per field of the message
// (members of a oneof whose policy is `kAlwaysNull` are not counted), plus,
// for every message-typed field that has an entry in `val`, the weight its
// sub-domain reports for that entry.
uint64_t CountNumberOfFields(const corpus_type& val) {
  uint64_t total_weight = 0;
  auto* descriptor = prototype_.Get()->GetDescriptor();
  for (int i = 0; i < descriptor->field_count(); ++i) {
    FieldDescriptor* field = descriptor->field(i);
    const bool never_set =
        field->containing_oneof() &&
        GetOneofFieldPolicy(field) == OptionalPolicy::kAlwaysNull;
    if (never_set) continue;

    ++total_weight;

    // Only message fields that are present contribute nested weight.
    if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) continue;
    auto val_it = val.find(field->number());
    if (val_it == val.end()) continue;
    if (field->is_repeated()) {
      total_weight +=
          GetSubDomain<ProtoMessageTag, true>(field).CountNumberOfFields(
              val_it->second);
    } else {
      total_weight +=
          GetSubDomain<ProtoMessageTag, false>(field).CountNumberOfFields(
              val_it->second);
    }
  }
  return total_weight;
}
// Mutates the field that sits at 1-based position `selected_field_index` in
// the traversal order also used by CountNumberOfFields: each counted field of
// this message takes one position, immediately followed by the positions of
// the nested fields of a present message-typed value.
// Returns the number of positions consumed by this call, so that a caller
// iterating over its own fields can advance its counter past this sub-tree.
uint64_t MutateSelectedField(corpus_type& val, absl::BitGenRef prng,
bool only_shrink,
uint64_t selected_field_index) {
uint64_t field_counter = 0;
auto* descriptor = prototype_.Get()->GetDescriptor();
if (descriptor->field_count() == 0) return field_counter;
for (int i = 0; i < descriptor->field_count(); ++i) {
FieldDescriptor* field = descriptor->field(i);
// Oneof members that must always be null take no position, mirroring
// CountNumberOfFields.
if (field->containing_oneof() &&
GetOneofFieldPolicy(field) == OptionalPolicy::kAlwaysNull) {
continue;
}
++field_counter;
if (field_counter == selected_field_index) {
// The selected position is this very field: mutate it and stop.
VisitProtobufField(field, MutateVisitor{prng, only_shrink, *this, val});
return field_counter;
}
if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
auto val_it = val.find(field->number());
if (val_it == val.end()) continue;
// Descend into the nested value; the index is rebased by the positions
// consumed so far.
if (field->is_repeated()) {
field_counter +=
GetSubDomain<ProtoMessageTag, true>(field).MutateSelectedField(
val_it->second, prng, only_shrink,
selected_field_index - field_counter);
} else {
field_counter +=
GetSubDomain<ProtoMessageTag, false>(field).MutateSelectedField(
val_it->second, prng, only_shrink,
selected_field_index - field_counter);
}
}
// Once the counter has reached the selected index there is nothing left
// to visit.
if (field_counter >= selected_field_index) return field_counter;
}
return field_counter;
}
// Mutates `val` in place by picking one position uniformly from
// [1, CountNumberOfFields(val)] and mutating the field at that position.
// Does nothing for message types without fields.
void Mutate(corpus_type& val, absl::BitGenRef prng, bool only_shrink) {
  if (prototype_.Get()->GetDescriptor()->field_count() == 0) return;
  // TODO(JunyangShao): Maybe make CountNumberOfFields static.
  const uint64_t total_weight = CountNumberOfFields(val);
  const uint64_t selected_weight = absl::Uniform(
      absl::IntervalClosedClosed, prng, uint64_t{1}, total_weight);
  MutateSelectedField(val, prng, only_shrink, selected_weight);
}
// Visitor for `VisitProtobufField` that converts the corpus entry `data` of
// the visited field into a user value and stores it in `message` through the
// reflection API.
struct GetValueVisitor {
  Message& message;
  const ProtobufDomainUntypedImpl& self;
  const GenericDomainCorpusType& data;

  template <typename T>
  void VisitSingular(const FieldDescriptor* field) {
    auto& domain = self.GetSubDomain<T, false>(field);
    auto value = domain.GetValue(data);
    if (value.has_value()) {
      if constexpr (std::is_same_v<T, ProtoMessageTag>) {
        // The owning pointer is released and passed to SetAllocatedMessage.
        message.GetReflection()->SetAllocatedMessage(&message,
                                                     value->release(), field);
      } else if constexpr (std::is_same_v<T, ProtoEnumTag>) {
        message.GetReflection()->SetEnumValue(&message, field, *value);
      } else {
        ProtocolBufferAccess<Message, T>(&message, field).SetField(*value);
      }
      return;
    }
    // A null value means "field not set", which is invalid for required
    // fields.
    FUZZTEST_INTERNAL_CHECK_PRECONDITION(
        !field->is_required(), "required field '",
        std::string(field->name()), "' cannot have null values.");
    message.GetReflection()->ClearField(&message, field);
  }

  template <typename T>
  void VisitRepeated(const FieldDescriptor* field) {
    auto& domain = self.GetSubDomain<T, true>(field);
    auto elements = domain.GetValue(data);
    if constexpr (std::is_same_v<T, ProtoMessageTag>) {
      // Each owning pointer is released and passed to AddAllocatedMessage.
      for (auto& element : elements) {
        message.GetReflection()->AddAllocatedMessage(&message, field,
                                                     element.release());
      }
    } else if constexpr (std::is_same_v<T, ProtoEnumTag>) {
      for (const auto& element : elements) {
        message.GetReflection()->AddEnumValue(&message, field, element);
      }
    } else {
      for (const auto& element : elements) {
        ProtocolBufferAccess<Message, T>(&message, field)
            .AddRepeatedField(element);
      }
    }
  }
};
// Builds a new message from the prototype and fills in every field that has
// an entry in the corpus value. Field numbers are resolved with
// `GetProtobufField`, so known extensions are covered as well.
value_type GetValue(const corpus_type& value) const {
  value_type out(prototype_.Get()->New());
  for (const auto& entry : value) {
    auto* field = GetProtobufField(prototype_.Get(), entry.first);
    VisitProtobufField(field, GetValueVisitor{*out, *this, entry.second});
  }
  return out;
}
std::optional<corpus_type> FromValue(const value_type& value) const {
return FromValue(*value);
}
// Visitor for `VisitProtobufField` that reads the visited field from
// `message`, converts it to a corpus value with the field's sub-domain and
// stores the result in `out`. Both visit methods return false when the
// sub-domain cannot convert the value.
struct FromValueVisitor {
  const Message& message;
  corpus_type& out;
  const ProtobufDomainUntypedImpl& self;

  // TODO(sbenzaquen): We might want to try avoid these copies. On the other hand,
  // FromValue is not called much so it might be ok.
  template <typename T>
  bool VisitSingular(const FieldDescriptor* field) {
    auto& domain = self.GetSubDomain<T, false>(field);
    value_type_t<std::decay_t<decltype(domain)>> inner_value;
    auto* reflection = message.GetReflection();
    if constexpr (std::is_same_v<T, ProtoMessageTag>) {
      // Sub-messages are deep-copied into a freshly created message.
      const auto& child = reflection->GetMessage(message, field);
      inner_value = std::unique_ptr<Message>(child.New());
      (*inner_value)->CopyFrom(child);
    } else if constexpr (std::is_same_v<T, ProtoEnumTag>) {
      inner_value = reflection->GetEnum(message, field)->number();
    } else {
      inner_value = ProtocolBufferAccess<Message, T>::GetField(message, field);
    }
    auto converted = domain.FromValue(inner_value);
    if (converted) {
      out[field->number()] = *std::move(converted);
      return true;
    }
    return false;
  }

  template <typename T>
  bool VisitRepeated(const FieldDescriptor* field) {
    auto& domain = self.GetSubDomain<T, true>(field);
    value_type_t<std::decay_t<decltype(domain)>> inner_value;
    auto* reflection = message.GetReflection();
    const int element_count = reflection->FieldSize(message, field);
    for (int index = 0; index < element_count; ++index) {
      if constexpr (std::is_same_v<T, ProtoMessageTag>) {
        // Each element is deep-copied into a freshly created message.
        const auto& child =
            reflection->GetRepeatedMessage(message, field, index);
        auto* copy = child.New();
        copy->CopyFrom(child);
        inner_value.emplace_back(copy);
      } else if constexpr (std::is_same_v<T, ProtoEnumTag>) {
        inner_value.push_back(
            reflection->GetRepeatedEnum(message, field, index)->number());
      } else {
        inner_value.push_back(
            ProtocolBufferAccess<Message, T>::GetRepeatedField(message, field,
                                                               index));
      }
    }
    auto converted = domain.FromValue(inner_value);
    if (converted) {
      out[field->number()] = *std::move(converted);
      return true;
    }
    return false;
  }
};
std::optional<corpus_type> FromValue(const Message& value) const {
corpus_type ret;
auto* reflection = value.GetReflection();
std::vector<const FieldDescriptor*> fields;
reflection->ListFields(value, &fields);
for (auto field : fields) {
if (!VisitProtobufField(field, FromValueVisitor{value, ret, *this}))
return std::nullopt;
}
return ret;
}
// Visitor for `VisitProtobufField` that parses `obj` with the sub-domain of
// the visited field and stores the (possibly empty) result in `out`.
struct ParseVisitor {
  const ProtobufDomainUntypedImpl& self;
  const IRObject& obj;
  std::optional<GenericDomainCorpusType>& out;

  template <typename T>
  void VisitSingular(const FieldDescriptor* field) {
    auto& singular_domain = self.GetSubDomain<T, false>(field);
    out = singular_domain.ParseCorpus(obj);
  }

  template <typename T>
  void VisitRepeated(const FieldDescriptor* field) {
    auto& repeated_domain = self.GetSubDomain<T, true>(field);
    out = repeated_domain.ParseCorpus(obj);
  }
};
// Parses a serialized corpus value. `obj` must hold a list of
// (field number, field payload) pairs. Parsing fails with nullopt when
//   - `obj` or one of its entries does not have that shape,
//   - a field number is not known to the prototype,
//   - a payload is rejected by the sub-domain of its field, or
//   - a field that is absent from `obj` has a sub-domain that rejects the
//     "unset" payload (an empty list for repeated fields, the list `[0]`
//     otherwise).
std::optional<corpus_type> ParseCorpus(const IRObject& obj) const {
  auto subs = obj.Subs();
  if (!subs) return std::nullopt;

  corpus_type out;
  absl::flat_hash_set<int> present_fields;
  for (const auto& sub : *subs) {
    auto pair_subs = sub.Subs();
    if (!pair_subs || pair_subs->size() != 2) return std::nullopt;
    auto number = (*pair_subs)[0].GetScalar<int>();
    if (!number) return std::nullopt;
    auto* field = GetProtobufField(prototype_.Get(), *number);
    if (!field) return std::nullopt;
    present_fields.insert(field->number());
    std::optional<GenericDomainCorpusType> inner_parsed;
    VisitProtobufField(field,
                       ParseVisitor{*this, (*pair_subs)[1], inner_parsed});
    if (!inner_parsed) return std::nullopt;
    out[*number] = *std::move(inner_parsed);
  }

  // Fields without an entry are checked against their "unset" payload; the
  // parse result is only used as a pass/fail signal and is not stored.
  const auto* descriptor = prototype_.Get()->GetDescriptor();
  for (int field_index = 0; field_index < descriptor->field_count();
       ++field_index) {
    const FieldDescriptor* field = descriptor->field(field_index);
    if (present_fields.contains(field->number())) continue;
    IRObject unset_value;
    if (field->is_repeated()) {
      unset_value = IRObject(std::vector<IRObject>{});
    } else {
      unset_value = IRObject(std::vector<IRObject>{IRObject(0)});
    }
    std::optional<GenericDomainCorpusType> inner_parsed;
    VisitProtobufField(field, ParseVisitor{*this, unset_value, inner_parsed});
    if (!inner_parsed) return std::nullopt;
  }
  return out;
}
// Visitor for `VisitProtobufField` that serializes `corpus_value` with the
// sub-domain of the visited field and writes the result to `out`.
struct SerializeVisitor {
  const ProtobufDomainUntypedImpl& self;
  const GenericDomainCorpusType& corpus_value;
  IRObject& out;

  template <typename T>
  void VisitSingular(const FieldDescriptor* field) {
    auto& singular_domain = self.GetSubDomain<T, false>(field);
    out = singular_domain.SerializeCorpus(corpus_value);
  }

  template <typename T>
  void VisitRepeated(const FieldDescriptor* field) {
    auto& repeated_domain = self.GetSubDomain<T, true>(field);
    out = repeated_domain.SerializeCorpus(corpus_value);
  }
};
// Serializes a corpus value as a list of (field number, field payload) pairs,
// one pair per entry of `v`. Each payload is produced by the sub-domain of
// the corresponding field. Every field number in `v` must resolve to a field
// of the prototype.
IRObject SerializeCorpus(const corpus_type& v) const {
  IRObject out;
  auto& subs = out.MutableSubs();
  for (const auto& [number, inner] : v) {
    auto* field = GetProtobufField(prototype_.Get(), number);
    FUZZTEST_INTERNAL_CHECK(field, "Field not found by number: ", number);
    auto& pair_subs = subs.emplace_back().MutableSubs();
    pair_subs.emplace_back(number);
    // The payload slot is appended first and then filled in by the visitor.
    auto& payload = pair_subs.emplace_back();
    VisitProtobufField(field, SerializeVisitor{*this, inner, payload});
  }
  return out;
}
// Visitor for `VisitProtobufField` that asks the sub-domain of the visited
// field to validate a corpus value and stores the verdict in `out`. When the
// field has no corpus value, the field's "unset" corpus value is validated
// instead.
struct ValidateVisitor {
  const ProtobufDomainUntypedImpl& self;
  // nullopt indicates that the field is not set.
  const std::optional<GenericDomainCorpusType>& corpus_value;
  bool& out;

  template <typename T>
  void VisitSingular(const FieldDescriptor* field) {
    out = Validate<T, /*is_repeated=*/false>(field);
  }

  template <typename T>
  void VisitRepeated(const FieldDescriptor* field) {
    out = Validate<T, /*is_repeated=*/true>(field);
  }

 private:
  // Shared implementation of both visit methods.
  template <typename T, bool is_repeated>
  bool Validate(const FieldDescriptor* field) {
    const GenericDomainCorpusType value =
        corpus_value.has_value() ? *corpus_value
                                 : GetUnsetCorpusValue<T, is_repeated>(field);
    return self.GetSubDomain<T, is_repeated>(field).ValidateCorpusValue(value);
  }

  // Parses the "unset" payload (an empty list for repeated fields, the list
  // `[0]` otherwise) with the field's sub-domain and returns the resulting
  // corpus value. The parse is checked to succeed.
  template <typename T, bool is_repeated>
  GenericDomainCorpusType GetUnsetCorpusValue(const FieldDescriptor* field) {
    IRObject unset_value;
    if constexpr (is_repeated) {
      unset_value = IRObject(std::vector<IRObject>{});
    } else {
      unset_value = IRObject(std::vector<IRObject>{IRObject(0)});
    }
    std::optional<GenericDomainCorpusType> parsed =
        self.GetSubDomain<T, is_repeated>(field).ParseCorpus(unset_value);
    FUZZTEST_INTERNAL_CHECK(parsed.has_value(),
                            "Invalid unset value for field.");
    return *parsed;
  }
};
bool ValidateCorpusValue(const corpus_type& corpus_value) const {
for (int field_index = 0;
field_index < prototype_.Get()->GetDescriptor()->field_count();
++field_index) {
const FieldDescriptor* field =
prototype_.Get()->GetDescriptor()->field(field_index);
auto field_number_value = corpus_value.find(field->number());
auto inner_corpus_value = (field_number_value != corpus_value.end())
? std::optional(field_number_value->second)
: std::nullopt;
bool result;
VisitProtobufField(field,
ValidateVisitor{*this, inner_corpus_value, result});
if (!result) return false;
}
return true;
}
// Returns the printer to use for values of this domain: a value-initialized
// `ProtobufPrinter`.
auto GetPrinter() const { return ProtobufPrinter{}; }
// Field visitor used by `WithField` to register a caller-provided domain for a
// single field.
//
// `domain` is the user domain (possibly held by reference, depending on
// `Inner`), and `self` is the message domain whose per-field table is updated.
template <typename Inner>
struct WithFieldVisitor {
  Inner domain;
  ProtobufDomainUntypedImpl& self;

  // Returns the descriptor of the message carried by corpus value `val`, or
  // nullptr when the value carries no message (empty repeated value or null
  // optional).
  template <bool is_repeated, typename T>
  Descriptor* GetDescriptor(const T& val) const {
    auto v = domain.GetValue(val);
    if constexpr (is_repeated) {
      if (v.empty()) return nullptr;
      return v[0]->GetDescriptor();
    } else {
      if (!v.has_value()) return nullptr;
      return (*v)->GetDescriptor();
    }
  }

  // Stores `domain` as the domain for `field`.
  //
  // Fails a precondition check when `DomainT` cannot be constructed from
  // `Inner`, when a message-typed domain produces messages of a different
  // type than the field expects, or when a domain was already registered for
  // the field.
  template <typename T, typename DomainT, bool is_repeated>
  void ApplyDomain(const FieldDescriptor* field) {
    if constexpr (!std::is_constructible_v<DomainT, Inner>) {
      FUZZTEST_INTERNAL_CHECK_PRECONDITION(
          (std::is_constructible_v<DomainT, Inner>),
          "Input domain does not match field `", field->full_name(),
          "` type.");
    } else {
      if constexpr (std::is_same_v<T, ProtoMessageTag>) {
        // Verify that the type matches. The first value may carry no message
        // at all, so mutate a bounded number of times until one shows up.
        absl::BitGen gen;
        constexpr int kMaxTry = 10;
        auto val = domain.Init(gen);
        auto descriptor = GetDescriptor<is_repeated>(val);
        for (int i = 0; !descriptor && i < kMaxTry; ++i) {
          domain.Mutate(val, gen, /*only_shrink=*/false);
          descriptor = GetDescriptor<is_repeated>(val);
        }
        // If no message was ever produced the type cannot be verified and the
        // domain is accepted as is.
        FUZZTEST_INTERNAL_CHECK_PRECONDITION(
            !descriptor ||
                descriptor->full_name() == field->message_type()->full_name(),
            "Input domain does not match the expected message type. The "
            "domain produced a message of type `",
            descriptor->full_name(),
            "` but the field needs a message of type `",
            field->message_type()->full_name(), "`.");
      }
      // `domains_` is guarded by `mutex_`.
      absl::MutexLock l(&self.mutex_);
      auto res = self.domains_.try_emplace(field->number(),
                                           std::in_place_type<DomainT>,
                                           std::forward<Inner>(domain));
      FUZZTEST_INTERNAL_CHECK_PRECONDITION(res.second, "Domain for field `",
                                           field->full_name(),
                                           "` has been set multiple times.");
    }
  }

  template <typename T>
  void VisitSingular(const FieldDescriptor* field) {
    using DomainT = decltype(self.GetDefaultDomainForField<T, false>(field));
    ApplyDomain<T, DomainT, /*is_repeated=*/false>(field);
  }

  template <typename T>
  void VisitRepeated(const FieldDescriptor* field) {
    using DomainT = decltype(self.GetDefaultDomainForField<T, true>(field));
    ApplyDomain<T, DomainT, /*is_repeated=*/true>(field);
  }
};
// Registers `domain` as the domain for the field named `field_name` and
// records the field as customized.
// `GetField` fails a precondition check when no such field exists.
template <typename Inner>
void WithField(absl::string_view field_name, Inner&& domain) {
auto* field = GetField(field_name);
// The visitor is created in place so that `domain` is forwarded to it
// without an intermediate copy.
VisitProtobufField(
field, WithFieldVisitor<Inner&&>{std::forward<Inner>(domain), *this});
customized_fields_.insert(field->index());
}
// Looks up the field called `field_name` on the prototype's descriptor.
// Fails a precondition check if the message has no field with that name;
// otherwise returns the field's descriptor.
const FieldDescriptor* GetField(absl::string_view field_name) const {
  const std::string name(field_name);
  const FieldDescriptor* field =
      prototype_.Get()->GetDescriptor()->FindFieldByName(name);
  FUZZTEST_INTERNAL_CHECK_PRECONDITION(field != nullptr,
                                       "Invalid field name '", name, "'.");
  return field;
}
// Applies a nullness `policy` to `field_name` when that field belongs to a
// oneof; does nothing for fields outside a oneof.
//
// A field cannot be "always set" if its oneof has more than one field (that
// fails a precondition check). Only the "always null" policy is recorded as a
// per-field oneof policy here.
void WithOneofField(absl::string_view field_name, OptionalPolicy policy) {
  const FieldDescriptor* field = GetField(field_name);
  if (field->containing_oneof() == nullptr) {
    return;
  }
  FUZZTEST_INTERNAL_CHECK_PRECONDITION(
      policy != OptionalPolicy::kWithoutNull ||
          field->containing_oneof()->field_count() <= 1,
      "Cannot always set oneof field ", field_name,
      " (try using WithOneofAlwaysSet).");
  if (policy != OptionalPolicy::kAlwaysNull) {
    return;
  }
  SetOneofFieldPolicy(field, policy);
}
// Bookkeeping for a oneof member that is customized without an explicit
// nullness setting. Fields outside a oneof are ignored.
//
// Marks the containing oneof as no longer eligible for `WithOneofAlwaysSet`,
// and, if that oneof was already declared always-set, records the
// "without null" policy for this field.
void WithOneofFieldWithoutNullnessConfiguration(
    absl::string_view field_name) {
  const FieldDescriptor* field = GetField(field_name);
  auto* containing_oneof = field->containing_oneof();
  if (containing_oneof == nullptr) return;

  const auto oneof_index = containing_oneof->index();
  uncustomizable_oneofs_.insert(oneof_index);
  if (!always_set_oneofs_.contains(oneof_index)) return;
  SetOneofFieldPolicy(field, OptionalPolicy::kWithoutNull);
}
// Marks the oneof named `oneof_name` as always set.
//
// Fails a precondition check when the name does not match a oneof of the
// message, when the oneof was already marked always-set, or when one of its
// fields has already been customized (the oneof is in
// `uncustomizable_oneofs_`).
void WithOneofAlwaysSet(absl::string_view oneof_name) {
const std::string name(oneof_name);
auto* oneof = prototype_.Get()->GetDescriptor()->FindOneofByName(name);
FUZZTEST_INTERNAL_CHECK_PRECONDITION(oneof != nullptr,
"Invalid oneof name '", name, "'.");
FUZZTEST_INTERNAL_CHECK_PRECONDITION(
!always_set_oneofs_.contains(oneof->index()), "oneof '", name,
"' is AlwaysSet before.");
FUZZTEST_INTERNAL_CHECK_PRECONDITION(
!uncustomizable_oneofs_.contains(oneof->index()),
"WithOneofAlwaysSet(\"", name,
"\") should be called before customizing sub-fields.");
always_set_oneofs_.insert(oneof->index());
}
bool IsOneofAlwaysSet(int oneof_index) {
return always_set_oneofs_.contains(oneof_index);
}
// Field visitor that rebuilds the policy-derived domain for one field, forces
// the requested nullness `policy` on it, and registers the result through
// `WithField`.
struct WithFieldNullnessVisitor {
  ProtobufDomainUntypedImpl& self;
  OptionalPolicy policy;

  // Singular fields: nullness is expressed on the optional wrapper domain.
  template <typename T>
  auto VisitSingular(const FieldDescriptor* field) {
    auto domain = self.GetOuterDomainForField</*is_repeated=*/false>(
        field, self.GetBaseDomainForFieldType<T>(field, /*use_policy=*/true));
    switch (policy) {
      case OptionalPolicy::kAlwaysNull:
        domain.SetAlwaysNull();
        break;
      case OptionalPolicy::kWithoutNull:
        domain.SetWithoutNull();
        break;
      default:
        break;
    }
    self.WithField(field->name(), domain);
  }

  // Repeated fields: "null" means an empty container and "without null" means
  // at least one element.
  template <typename T>
  auto VisitRepeated(const FieldDescriptor* field) {
    auto domain = self.GetOuterDomainForField</*is_repeated=*/true>(
        field, self.GetBaseDomainForFieldType<T>(field, /*use_policy=*/true));
    switch (policy) {
      case OptionalPolicy::kAlwaysNull:
        domain.WithMaxSize(0);
        break;
      case OptionalPolicy::kWithoutNull:
        domain.WithMinSize(1);
        break;
      default:
        break;
    }
    self.WithField(field->name(), domain);
  }
};
// Forces the nullness `policy` on the field named `field_name`.
// Oneof-specific checks and bookkeeping run first (`WithOneofField`), then the
// field's domain is rebuilt with the policy applied and registered.
void WithFieldNullness(absl::string_view field_name, OptionalPolicy policy) {
const FieldDescriptor* field = GetField(field_name);
WithOneofField(field_name, policy);
VisitProtobufField(field, WithFieldNullnessVisitor{*this, policy});
}
// Field visitor that registers a domain with custom size bounds for one
// repeated field. A bound left as nullopt is not applied here.
struct WithRepeatedFieldSizeVisitor {
  ProtobufDomainUntypedImpl& self;
  std::optional<int64_t> min_size;
  std::optional<int64_t> max_size;

  // Size customization makes no sense for a singular field: always fails the
  // precondition check.
  template <typename T>
  auto VisitSingular(const FieldDescriptor* field) {
    FUZZTEST_INTERNAL_CHECK_PRECONDITION(
        false,
        "Customizing repeated field size is not applicable to non-repeated "
        "field ",
        field->name(), ".");
  }

  // Rebuilds the policy-derived container domain, applies whichever bounds
  // were provided, and registers it for the field.
  template <typename T>
  auto VisitRepeated(const FieldDescriptor* field) {
    auto container_domain = self.GetOuterDomainForField</*is_repeated=*/true>(
        field, self.GetBaseDomainForFieldType<T>(field, /*use_policy=*/true));
    if (min_size) container_domain.WithMinSize(*min_size);
    if (max_size) container_domain.WithMaxSize(*max_size);
    self.WithField(field->name(), container_domain);
  }
};
// Sets size bounds on the repeated field named `field_name`. Either bound may
// be nullopt, in which case it is not applied by this call.
// Fails a precondition check if the field is not repeated.
void WithRepeatedFieldSize(absl::string_view field_name,
std::optional<int64_t> min_size,
std::optional<int64_t> max_size) {
const FieldDescriptor* field = GetField(field_name);
VisitProtobufField(field,
WithRepeatedFieldSizeVisitor{*this, min_size, max_size});
}
// Replaces the message-wide policy.
// Fails a precondition check if any individual field has already been
// customized (see `CheckIfPolicyCanBeUpdated`).
void SetPolicy(ProtoPolicy<Message> policy) {
  CheckIfPolicyCanBeUpdated();
  // `policy` is a by-value sink parameter: move it into place rather than
  // copying it a second time.
  policy_ = std::move(policy);
}
// Returns a mutable reference to the message-wide policy.
// Fails a precondition check if any individual field has already been
// customized, since the caller is expected to modify the policy.
ProtoPolicy<Message>& GetPolicy() {
CheckIfPolicyCanBeUpdated();
return policy_;
}
// Returns the policy-aware base domain (ignoring the optional/repeated field
// rule) for the field named `field_name`.
// `GetField` fails a precondition check when no such field exists.
template <typename T>
auto GetFieldTypeDefaultDomain(absl::string_view field_name) const {
  return GetBaseDomainForFieldType<T>(GetField(field_name),
                                      /*use_policy=*/true);
}
template <bool is_repeated, typename Inner>
auto GetOuterDomainForField(const FieldDescriptor* field, Inner domain,
bool use_policy = true) const {
if constexpr (is_repeated) {
return ModifyDomainForRepeatedFieldRule(
std::move(domain),
use_policy ? policy_.GetMinRepeatedFieldSize(field) : std::nullopt,
use_policy ? policy_.GetMaxRepeatedFieldSize(field) : std::nullopt);
} else if (IsRequired(field)) {
return ModifyDomainForRequiredFieldRule(std::move(domain));
} else {
return ModifyDomainForOptionalFieldRule(
std::move(domain), use_policy ? policy_.GetOptionalPolicy(field)
: OptionalPolicy::kWithNull);
}
}
// Records `policy` as the nullness policy of oneof member `field`, keyed by
// the field's index. A policy already recorded for the field is kept (the
// insertion does not overwrite).
void SetOneofFieldPolicy(const FieldDescriptor* field,
                         OptionalPolicy policy) {
  oneof_fields_policies_.emplace(field->index(), policy);
}
// Returns the nullness policy for oneof member `field`: the per-field policy
// recorded via `SetOneofFieldPolicy` if there is one, otherwise the
// message-wide policy. Must only be called for fields inside a oneof.
OptionalPolicy GetOneofFieldPolicy(const FieldDescriptor* field) const {
  FUZZTEST_INTERNAL_CHECK(
      field->containing_oneof(),
      "GetOneofFieldPolicy should apply to oneof fields only! ",
      field->name());
  if (auto recorded = oneof_fields_policies_.find(field->index());
      recorded != oneof_fields_policies_.end()) {
    return recorded->second;
  }
  return policy_.GetOptionalPolicy(field);
}
private:
// Fails a precondition check if any field has already been customized
// individually (`customized_fields_` is non-empty): message-wide policy
// changes must be made before per-field customizations.
void CheckIfPolicyCanBeUpdated() const {
FUZZTEST_INTERNAL_CHECK_PRECONDITION(
customized_fields_.empty(),
"All singular modifiers (i.e., .With_Field_()) should come after "
"plural modifiers (i.e., .With_Fields_()). Consider reordering .With_ "
"modifiers.");
}
// Returns the domain registered for `field`. When none has been registered
// yet, the policy-derived domain for the field is created, stored, and
// returned.
template <typename T, bool is_repeated>
auto& GetSubDomain(const FieldDescriptor* field) const {
  using DomainT = decltype(GetDefaultDomainForField<T, is_repeated>(field));
  // This `const` method may insert into the mutable `domains_` table, so the
  // lookup and the insertion happen under the lock.
  absl::MutexLock lock(&mutex_);
  const auto field_number = field->number();
  auto entry = domains_.find(field_number);
  if (entry != domains_.end()) {
    return entry->second.template GetAs<DomainT>();
  }
  // Build the domain only on a miss.
  entry = domains_
              .try_emplace(field_number, std::in_place_type<DomainT>,
                           GetDomainForField<T, is_repeated>(field))
              .first;
  return entry->second.template GetAs<DomainT>();
}
// Simple wrapper that converts a Domain<T> into a Domain<vector<T>>.
template <typename T>
static auto ModifyDomainForRepeatedFieldRule(
const Domain<T>& d, std::optional<int64_t> min_size,
std::optional<int64_t> max_size) {
auto result = ContainerOfImpl<std::vector<T>, Domain<T>>(d);
if (min_size.has_value()) {
result.WithMinSize(*min_size);
}
if (max_size.has_value()) {
result.WithMaxSize(*max_size);
}
return result;
}
template <typename T>
static auto ModifyDomainForOptionalFieldRule(const Domain<T>& d,
OptionalPolicy optional_policy) {
auto result = OptionalOfImpl<std::optional<T>, Domain<T>>(d);
if (optional_policy == OptionalPolicy::kWithoutNull) {
result.SetWithoutNull();
} else if (optional_policy == OptionalPolicy::kAlwaysNull) {
result.SetAlwaysNull();
}
return result;
}
// Lifts a value domain `d` to a domain of `std::optional<T>` that is
// configured to never be null, as needed for required fields.
template <typename T>
static auto ModifyDomainForRequiredFieldRule(const Domain<T>& d) {
return OptionalOfImpl<std::optional<T>, Domain<T>>(d).SetWithoutNull();
}
// Returns the default "base domain" for a `field` solely based on its type
// (i.e., int32/string), but ignoring the field rule (i.e., the
// repeated/optional/required specifiers).
//
// The message policy is not consulted for the returned domain itself; for
// message-typed fields it is only passed on to the nested message domain.
template <typename T>
auto GetBaseDefaultDomainForFieldType(const FieldDescriptor* field) const {
if constexpr (std::is_same_v<T, std::string>) {
if (field->type() == FieldDescriptor::TYPE_STRING) {
// Can only use UTF-8. For now, simplify as just ASCII.
return Domain<T>(AsciiString());
}
// Other string-backed field types fall through to the generic
// `ArbitraryImpl<T>` branch at the bottom.
}
if constexpr (std::is_same_v<T, ProtoEnumTag>) {
// For enums, build the list of valid labels.
auto* e = field->enum_type();
std::vector<int> values;
values.reserve(e->value_count());
for (int i = 0; i < e->value_count(); ++i) {
values.push_back(e->value(i)->number());
}
// Delay instantiation. The Domain class is not fully defined at this
// point yet, and neither is ElementOfImpl.
using LazyInt = MakeDependentType<int, T>;
return Domain<LazyInt>(ElementOfImpl<LazyInt>(std::move(values)));
} else if constexpr (std::is_same_v<T, ProtoMessageTag>) {
// Nested message: build a message domain from the prototype of the
// field's message type, inheriting this domain's lazy-initialization
// setting and policy.
auto result = ProtobufDomainUntypedImpl(
prototype_.Get()->GetReflection()->GetMessageFactory()->GetPrototype(
field->message_type()),
use_lazy_initialization_);
result.SetPolicy(policy_);
return Domain<std::unique_ptr<Message>>(result);
} else {
return Domain<T>(ArbitraryImpl<T>());
}
}
// Returns the base domain for `field` (ignoring the optional/repeated rule).
//
// Without `use_policy` this is the type-based default domain. With it, a
// default domain configured in `policy_` for the field replaces the type-based
// default, and a domain transformer configured in `policy_` is then applied to
// whichever domain was selected.
template <typename T>
auto GetBaseDomainForFieldType(const FieldDescriptor* field,
                               bool use_policy) const {
  auto default_domain = GetBaseDefaultDomainForFieldType<T>(field);
  if (!use_policy) {
    return default_domain;
  }
#define FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Camel, type) \
  if constexpr (std::is_same_v<T, type>) { \
    auto default_domain_in_policy = \
        policy_.GetDefaultDomainFor##Camel##s(field); \
    if (default_domain_in_policy.has_value()) { \
      default_domain = std::move(*default_domain_in_policy); \
    } \
    auto transformer_in_policy = \
        policy_.GetDomainTransformerFor##Camel##s(field); \
    if (transformer_in_policy.has_value()) { \
      default_domain = (*transformer_in_policy)(std::move(default_domain)); \
    } \
    return default_domain; \
  }
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Bool, bool)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Int32, int32_t)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(UInt32, uint32_t)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Int64, int64_t)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(UInt64, uint64_t)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Float, float)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Double, double)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(String, std::string)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Enum, ProtoEnumTag)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Protobuf, ProtoMessageTag)
// The helper macro is local to this function; do not let it leak to code that
// includes this header.
#undef FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED
  // Fallback for any `T` not listed above: the type-based default that was
  // already computed.
  return default_domain;
}
template <typename T, bool is_repeated>
auto GetDomainForField(const FieldDescriptor* field,
bool use_policy = true) const {
auto base_domain = GetBaseDomainForFieldType<T>(field, use_policy);
using field_cpptype = value_type_t<std::decay_t<decltype(base_domain)>>;
if constexpr (is_repeated) {
return Domain<std::vector<field_cpptype>>(
GetOuterDomainForField<is_repeated>(field, base_domain, use_policy));
} else {
return Domain<std::optional<field_cpptype>>(
GetOuterDomainForField<is_repeated>(field, base_domain, use_policy));
}
}
// Returns the domain for `field` built without consulting the policy.
// Its return type is also used (via `decltype`) as the storage type of the
// per-field domains.
template <typename T, bool is_repeated>
auto GetDefaultDomainForField(const FieldDescriptor* field) const {
return GetDomainForField<T, is_repeated>(field, /*use_policy=*/false);
}
// Returns true if, under the current policy, the message type of this domain
// is recursive through fields that must always be set.
bool IsNonTerminatingRecursive() {
  const auto* root = prototype_.Get()->GetDescriptor();
  absl::flat_hash_set<decltype(root)> parents;
  return IsProtoRecursive(root, parents, policy_,
                          /*consider_non_terminating_recursions=*/true);
}
// Returns true if `field` is message-typed and its message type is recursive
// through fields that can be set under the current policy. Non-message fields
// are never recursive.
bool IsFieldRecursive(const FieldDescriptor* field) {
  const auto* message_type = field->message_type();
  if (message_type == nullptr) return false;
  absl::flat_hash_set<decltype(message_type)> parents;
  return IsProtoRecursive(message_type, parents, policy_,
                          /*consider_non_terminating_recursions=*/false);
}
// Decides whether `oneof` contributes to a recursion, given the descriptors
// currently being explored in `parents`.
//
// Fields whose policy is "always null" are ignored. In non-terminating mode
// the oneof is recursive only if every remaining field is recursive and may
// not be null; otherwise it is recursive as soon as one remaining field is
// recursive. A oneof with no remaining fields is not recursive.
bool IsOneofRecursive(const OneofDescriptor* oneof,
absl::flat_hash_set<const Descriptor*>& parents,
const ProtoPolicy<Message>& policy,
bool consider_non_terminating_recursions) const {
bool is_oneof_recursive = false;
for (int i = 0; i < oneof->field_count(); ++i) {
const auto* field = oneof->field(i);
const auto field_policy = policy.GetOptionalPolicy(field);
if (field_policy == OptionalPolicy::kAlwaysNull) continue;
// Null for fields that are not message-typed.
const auto* child = field->message_type();
if (consider_non_terminating_recursions) {
is_oneof_recursive =
field_policy != OptionalPolicy::kWithNull && child &&
IsProtoRecursive(child, parents, policy,
consider_non_terminating_recursions);
// A single field that terminates is enough for the oneof to terminate.
if (!is_oneof_recursive) {
return false;
}
} else {
if (child && IsProtoRecursive(child, parents, policy,
consider_non_terminating_recursions)) {
return true;
}
}
}
return is_oneof_recursive;
}
// Depth-first search for a cycle among message types, starting at
// `descriptor`.
//
// `parents` holds the descriptors on the current search path; reaching a
// descriptor that is already in it means a cycle was found. In non-terminating
// mode only fields that must be set under `policy` are followed; otherwise
// every field that can be set is followed. Oneof members are handled through
// `IsOneofRecursive`. `descriptor` is removed from `parents` again before
// returning, except on the immediate cycle hit where it was not added.
template <typename Descriptor>
bool IsProtoRecursive(const Descriptor* descriptor,
absl::flat_hash_set<const Descriptor*>& parents,
const ProtoPolicy<Message>& policy,
bool consider_non_terminating_recursions) const {
if (parents.contains(descriptor)) return true;
parents.insert(descriptor);
for (int i = 0; i < descriptor->oneof_decl_count(); ++i) {
const auto* oneof = descriptor->oneof_decl(i);
if (IsOneofRecursive(oneof, parents, policy,
consider_non_terminating_recursions)) {
parents.erase(descriptor);
return true;
}
}
for (int i = 0; i < descriptor->field_count(); ++i) {
const auto* field = descriptor->field(i);
// Oneof members were covered by the loop above.
if (field->containing_oneof()) continue;
const auto* child = field->message_type();
// Only message-typed fields can lead to a recursion.
if (!child) continue;
if (consider_non_terminating_recursions) {
// Follow the field only if it is guaranteed to be populated.
const bool should_be_set =
IsRequired(field) ||
(field->is_optional() &&
policy.GetOptionalPolicy(field) == OptionalPolicy::kWithoutNull) ||
(field->is_repeated() &&
policy.GetMinRepeatedFieldSize(field).has_value() &&
*policy.GetMinRepeatedFieldSize(field) > 0);
if (!should_be_set) continue;
} else {
// Follow the field unless the policy guarantees it stays empty.
const bool can_be_set =
IsRequired(field) ||
(field->is_optional() &&
policy.GetOptionalPolicy(field) != OptionalPolicy::kAlwaysNull) ||
(field->is_repeated() &&
(!policy.GetMaxRepeatedFieldSize(field).has_value() ||
*policy.GetMaxRepeatedFieldSize(field) > 0));
if (!can_be_set) continue;
}
if (IsProtoRecursive(child, parents, policy,
consider_non_terminating_recursions)) {
parents.erase(descriptor);
return true;
}
}
parents.erase(descriptor);
return false;
}
// Returns true if `field` must always be set: a oneof member whose oneof
// policy is "without null", a field declared required, or the message-typed
// value field of a map entry.
bool IsRequired(const FieldDescriptor* field) const {
  const bool forced_by_oneof_policy =
      field->containing_oneof() &&
      GetOneofFieldPolicy(field) == OptionalPolicy::kWithoutNull;
  if (forced_by_oneof_policy) return true;
  if (field->is_required()) return true;
  return IsMapValueMessage(field);
}
// Returns true if `field` is message-typed and is the field that its
// containing type reports as `map_value()`.
static bool IsMapValueMessage(const FieldDescriptor* field) {
  if (!field->message_type()) return false;
  return field->containing_type()->map_value() == field;
}
PrototypePtr<Message> prototype_;
bool use_lazy_initialization_;
mutable absl::Mutex mutex_;
mutable absl::flat_hash_map<int, CopyableAny> domains_
ABSL_GUARDED_BY(mutex_);
ProtoPolicy<Message> policy_;
absl::flat_hash_set<int> customized_fields_;
absl::flat_hash_set<int> always_set_oneofs_;
absl::flat_hash_set<int> uncustomizable_oneofs_;
absl::flat_hash_map<int, OptionalPolicy> oneof_fields_policies_;
};
// Domain for `T` where `T` is a Protobuf message type.
// It is a small wrapper around `ProtobufDomainUntypedImpl` to make its API more
// convenient.
template <typename T,
typename UntypedImpl = ProtobufDomainUntypedImpl<typename T::Message>>
class ProtobufDomainImpl
: public DomainBase<ProtobufDomainImpl<T>, T, corpus_type_t<UntypedImpl>> {
public:
using typename ProtobufDomainImpl::DomainBase::corpus_type;
using typename ProtobufDomainImpl::DomainBase::value_type;
using FieldDescriptor = ProtobufFieldDescriptor<typename T::Message>;
corpus_type Init(absl::BitGenRef prng) {
if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
return inner_.Init(prng);
}
uint64_t CountNumberOfFields(const corpus_type& val) {
return inner_.CountNumberOfFields(val);
}
uint64_t MutateNumberOfProtoFields(corpus_type& val) {
return inner_.MutateNumberOfProtoFields(val);
}
void Mutate(corpus_type& val, absl::BitGenRef prng, bool only_shrink) {
inner_.Mutate(val, prng, only_shrink);
}
value_type GetValue(const corpus_type& v) const {
auto inner_v = inner_.GetValue(v);
return std::move(static_cast<T&>(*inner_v));
}
std::optional<corpus_type> FromValue(const value_type& value) const {
return inner_.FromValue(value);
}
auto GetPrinter() const { return ProtobufPrinter{}; }
std::optional<corpus_type> ParseCorpus(const IRObject& obj) const {
return inner_.ParseCorpus(obj);
}
IRObject SerializeCorpus(const corpus_type& v) const {
return inner_.SerializeCorpus(v);
}
bool ValidateCorpusValue(const corpus_type& corpus_value) const {
return inner_.ValidateCorpusValue(corpus_value);
}
// Provide a conversion to the type that WithMessageField wants.
// Makes it easier on the user.
operator Domain<std::unique_ptr<typename T::Message>>() const {
return inner_;
}
ProtobufDomainImpl&& Self() && { return std::move(*this); }
ProtobufDomainImpl&& WithFieldsAlwaysSet(
std::function<bool(const FieldDescriptor*)> filter =
IncludeAll<FieldDescriptor>()) && {
inner_.GetPolicy().SetOptionalPolicy(std::move(filter),
OptionalPolicy::kWithoutNull);
return std::move(*this);
}
ProtobufDomainImpl&& WithFieldsUnset(
std::function<bool(const FieldDescriptor*)> filter =
IncludeAll<FieldDescriptor>()) && {
inner_.GetPolicy().SetOptionalPolicy(std::move(filter),
OptionalPolicy::kAlwaysNull);
return std::move(*this);
}
ProtobufDomainImpl&& WithOptionalFieldsAlwaysSet(
std::function<bool(const FieldDescriptor*)> filter =
IncludeAll<FieldDescriptor>()) && {
inner_.GetPolicy().SetOptionalPolicy(
And(IsOptional<FieldDescriptor>(), std::move(filter)),
OptionalPolicy::kWithoutNull);
return std::move(*this);
}
ProtobufDomainImpl&& WithOptionalFieldsUnset(
std::function<bool(const FieldDescriptor*)> filter =
IncludeAll<FieldDescriptor>()) && {
inner_.GetPolicy().SetOptionalPolicy(
And(IsOptional<FieldDescriptor>(), std::move(filter)),
OptionalPolicy::kAlwaysNull);
return std::move(*this);
}
ProtobufDomainImpl&& WithRepeatedFieldsAlwaysSet(
std::function<bool(const FieldDescriptor*)> filter =
IncludeAll<FieldDescriptor>()) && {
inner_.GetPolicy().SetOptionalPolicy(
And(IsRepeated<FieldDescriptor>(), std::move(filter)),
OptionalPolicy::kWithoutNull);
return std::move(*this);
}
ProtobufDomainImpl&& WithRepeatedFieldsUnset(
std::function<bool(const FieldDescriptor*)> filter =
IncludeAll<FieldDescriptor>()) && {
inner_.GetPolicy().SetOptionalPolicy(
And(IsRepeated<FieldDescriptor>(), std::move(filter)),
OptionalPolicy::kAlwaysNull);
return std::move(*this);
}
ProtobufDomainImpl&& WithRepeatedFieldsSize(int64_t size) && {
return std::move(*this)
.WithRepeatedFieldsMinSize(size)
.WithRepeatedFieldsMaxSize(size);
}
ProtobufDomainImpl&& WithRepeatedFieldsSize(
std::function<bool(const FieldDescriptor*)> filter, int64_t size) && {
return std::move(*this)
.WithRepeatedFieldsMinSize(filter, size)
.WithRepeatedFieldsMaxSize(filter, size);
}
ProtobufDomainImpl&& WithRepeatedFieldsMinSize(int64_t min_size) && {
inner_.GetPolicy().SetMinRepeatedFieldsSize(IncludeAll<FieldDescriptor>(),
min_size);
return std::move(*this);
}
ProtobufDomainImpl&& WithRepeatedFieldsMinSize(
std::function<bool(const FieldDescriptor*)> filter, int64_t min_size) && {
inner_.GetPolicy().SetMinRepeatedFieldsSize(std::move(filter), min_size);
return std::move(*this);
}
ProtobufDomainImpl&& WithRepeatedFieldsMaxSize(int64_t max_size) && {
inner_.GetPolicy().SetMaxRepeatedFieldsSize(IncludeAll<FieldDescriptor>(),
max_size);
return std::move(*this);
}
ProtobufDomainImpl&& WithRepeatedFieldsMaxSize(
std::function<bool(const FieldDescriptor*)> filter, int64_t max_size) && {
inner_.GetPolicy().SetMaxRepeatedFieldsSize(std::move(filter), max_size);
return std::move(*this);
}
ProtobufDomainImpl&& WithFieldUnset(absl::string_view field) && {
inner_.WithFieldNullness(field, OptionalPolicy::kAlwaysNull);
return std::move(*this);
}
ProtobufDomainImpl&& WithFieldAlwaysSet(absl::string_view field) && {
inner_.WithFieldNullness(field, OptionalPolicy::kWithoutNull);
return std::move(*this);
}
ProtobufDomainImpl&& WithRepeatedFieldSize(
absl::string_view field_name, std::optional<int64_t> min_size,
std::optional<int64_t> max_size) && {
inner_.WithRepeatedFieldSize(field_name, min_size, max_size);
return std::move(*this);
}
ProtobufDomainImpl&& WithRepeatedFieldMinSize(absl::string_view field_name,
int64_t min_size) && {
inner_.WithRepeatedFieldSize(field_name, min_size, std::nullopt);
return std::move(*this);
}
ProtobufDomainImpl&& WithRepeatedFieldMaxSize(absl::string_view field_name,
int64_t max_size) && {
inner_.WithRepeatedFieldSize(field_name, std::nullopt, max_size);
return std::move(*this);
}
ProtobufDomainImpl&& WithOneofAlwaysSet(absl::string_view oneof_name) && {
inner_.WithOneofAlwaysSet(oneof_name);
return std::move(*this);
}
#define FUZZTEST_INTERNAL_WITH_FIELD(Camel, cpp, TAG) \
using Camel##type = MakeDependentType<cpp, T>; \
ProtobufDomainImpl&& With##Camel##Field(absl::string_view field, \
Domain<Camel##type> domain)&& { \
const FieldDescriptor* descriptor = inner_.GetField(field); \
if (descriptor->is_repeated()) { \
inner_.WithField( \
field, inner_.template GetOuterDomainForField</*is_repeated=*/true>( \
descriptor, std::move(domain))); \
} else { \
inner_.WithOneofFieldWithoutNullnessConfiguration(field); \
inner_.WithField( \
field, \
inner_.template GetOuterDomainForField</*is_repeated=*/false>( \
descriptor, std::move(domain))); \
} \
return std::move(*this); \
} \
ProtobufDomainImpl&& With##Camel##FieldAlwaysSet( \
absl::string_view field, Domain<Camel##type> domain)&& { \
const FieldDescriptor* descriptor = inner_.GetField(field); \
if (descriptor->is_repeated()) { \
inner_.WithField( \
field, \
SequenceContainerOfImpl<std::vector<Camel##type>, decltype(domain)>( \
std::move(domain)) \
.WithMinSize(1)); \
} else { \
inner_.WithOneofField(field, OptionalPolicy::kWithoutNull); \
inner_.WithField( \
field, OptionalOfImpl<std::optional<Camel##type>, decltype(domain)>( \
std::move(domain)) \
.SetWithoutNull()); \
} \
return std::move(*this); \
} \
/* TODO(b/271123298): Remove the following two methods and replace them with \
WithField(Unset/AlwaysSet) */ \
ProtobufDomainImpl&& With##Camel##FieldUnset(absl::string_view field)&& { \
auto default_domain = \
inner_.template GetFieldTypeDefaultDomain<TAG>(field); \
inner_.WithOneofField(field, OptionalPolicy::kAlwaysNull); \
inner_.WithField( \
field, \
OptionalOfImpl<std::optional<Camel##type>, decltype(default_domain)>( \
std::move(default_domain)) \
.SetAlwaysNull()); \
return std::move(*this); \
} \
ProtobufDomainImpl&& With##Camel##FieldAlwaysSet( \
absl::string_view field)&& { \
return std::move(*this).With##Camel##FieldAlwaysSet( \
field, inner_.template GetFieldTypeDefaultDomain<TAG>(field)); \
} \
ProtobufDomainImpl&& WithOptional##Camel##Field( \
absl::string_view field, \
Domain<MakeDependentType<std::optional<cpp>, T>> domain)&& { \
FailIfIsOneof(field); \
inner_.WithField(field, std::move(domain)); \
return std::move(*this); \
} \
ProtobufDomainImpl&& WithRepeated##Camel##Field( \
absl::string_view field, \
Domain<MakeDependentType<std::vector<cpp>, T>> domain)&& { \
inner_.WithField(field, std::move(domain)); \
return std::move(*this); \
} \
ProtobufDomainImpl&& With##Camel##Fields(Domain<Camel##type> domain)&& { \
inner_.GetPolicy().SetDefaultDomainFor##Camel##s( \
IncludeAll<FieldDescriptor>(), std::move(domain)); \
return std::move(*this); \
} \
ProtobufDomainImpl&& With##Camel##Fields( \
std::function<bool(const FieldDescriptor*)>&& filter, \
Domain<Camel##type> domain)&& { \
inner_.GetPolicy().SetDefaultDomainFor##Camel##s(std::move(filter), \
std::move(domain)); \
return std::move(*this); \
} \
ProtobufDomainImpl&& WithOptional##Camel##Fields( \
Domain<Camel##type> domain)&& { \
inner_.GetPolicy().SetDefaultDomainFor##Camel##s( \
IsOptional<FieldDescriptor>(), std::move(domain)); \
return std::move(*this); \
} \
ProtobufDomainImpl&& WithOptional##Camel##Fields( \
std::function<bool(const FieldDescriptor*)>&& filter, \
Domain<Camel##type> domain)&& { \
inner_.GetPolicy().SetDefaultDomainFor##Camel##s( \
And(IsOptional<FieldDescriptor>(), std::move(filter)), \
std::move(domain)); \
return std::move(*this); \
} \
ProtobufDomainImpl&& WithRepeated##Camel##Fields( \
Domain<Camel##type> domain)&& { \
inner_.GetPolicy().SetDefaultDomainFor##Camel##s( \
IsRepeated<FieldDescriptor>(), std::move(domain)); \
return std::move(*this); \
} \
ProtobufDomainImpl&& WithRepeated##Camel##Fields( \
std::function<bool(const FieldDescriptor*)>&& filter, \
Domain<Camel##type> domain)&& { \
inner_.GetPolicy().SetDefaultDomainFor##Camel##s( \
And(IsRepeated<FieldDescriptor>(), std::move(filter)), \
std::move(domain)); \
return std::move(*this); \
} \
ProtobufDomainImpl&& With##Camel##FieldsTransformed( \
std::function<Domain<Camel##type>(Domain<Camel##type>)>&& \
transformer)&& { \
inner_.GetPolicy().SetDomainTransformerFor##Camel##s( \
IncludeAll<FieldDescriptor>(), std::move(transformer)); \
return std::move(*this); \
} \
ProtobufDomainImpl&& With##Camel##FieldsTransformed( \
std::function<bool(const FieldDescriptor*)>&& filter, \
std::function<Domain<Camel##type>(Domain<Camel##type>)>&& \
transformer)&& { \
inner_.GetPolicy().SetDomainTransformerFor##Camel##s( \
std::move(filter), std::move(transformer)); \
return std::move(*this); \
}
FUZZTEST_INTERNAL_WITH_FIELD(Bool, bool, bool)
FUZZTEST_INTERNAL_WITH_FIELD(Int32, int32_t, int32_t)
FUZZTEST_INTERNAL_WITH_FIELD(UInt32, uint32_t, uint32_t)
FUZZTEST_INTERNAL_WITH_FIELD(Int64, int64_t, int64_t)
FUZZTEST_INTERNAL_WITH_FIELD(UInt64, uint64_t, uint64_t)
FUZZTEST_INTERNAL_WITH_FIELD(Float, float, float)
FUZZTEST_INTERNAL_WITH_FIELD(Double, double, double)
FUZZTEST_INTERNAL_WITH_FIELD(String, std::string, std::string)
FUZZTEST_INTERNAL_WITH_FIELD(Enum, int, ProtoEnumTag)
FUZZTEST_INTERNAL_WITH_FIELD(Protobuf, std::unique_ptr<typename T::Message>,
ProtoMessageTag)
#undef FUZZTEST_INTERNAL_WITH_FIELD
// The following methods automatically cast Domain<Proto> to
// Domain<unique_ptr<Message>>
template <typename Protobuf>
ProtobufDomainImpl&& WithProtobufField(absl::string_view field,
Domain<Protobuf> domain) && {
return std::move(*this).WithProtobufField(
field, ToUntypedProtoDomain(std::move(domain)));
}
template <typename Protobuf>
ProtobufDomainImpl&& WithProtobufFieldAlwaysSet(absl::string_view field,
Domain<Protobuf> domain) && {
return std::move(*this).WithProtobufFieldAlwaysSet(
field, ToUntypedProtoDomain(std::move(domain)));
}
template <typename Protobuf>
ProtobufDomainImpl&& WithProtobufFields(Domain<Protobuf> domain) && {
return std::move(*this).WithProtobufFields(
ToUntypedProtoDomain(std::move(domain)));
}
template <typename Protobuf>
ProtobufDomainImpl&& WithProtobufFields(
std::function<bool(const FieldDescriptor*)>&& filter,
Domain<Protobuf> domain) && {
return std::move(*this).WithProtobufFields(
std::move(filter), ToUntypedProtoDomain(std::move(domain)));
}
template <typename Protobuf>
ProtobufDomainImpl&& WithOptionalProtobufFields(Domain<Protobuf> domain) && {
return std::move(*this).WithOptionalProtobufFields(
ToUntypedProtoDomain(std::move(domain)));
}
template <typename Protobuf>
ProtobufDomainImpl&& WithOptionalProtobufFields(
std::function<bool(const FieldDescriptor*)>&& filter,
Domain<Protobuf> domain) && {
return std::move(*this).WithOptionalProtobufFields(
std::move(filter), ToUntypedProtoDomain(std::move(domain)));
}
template <typename Protobuf>
ProtobufDomainImpl&& WithRepeatedProtobufFields(Domain<Protobuf> domain) && {
return std::move(*this).WithRepeatedProtobufFields(
ToUntypedProtoDomain(std::move(domain)));
}
template <typename Protobuf>
ProtobufDomainImpl&& WithRepeatedProtobufFields(
std::function<bool(const FieldDescriptor*)>&& filter,
Domain<Protobuf> domain) && {
return std::move(*this).WithRepeatedProtobufFields(
std::move(filter), ToUntypedProtoDomain(std::move(domain)));
}
template <typename OptionalProtobufDomain>
ProtobufDomainImpl&& WithOptionalProtobufField(
absl::string_view field, OptionalProtobufDomain domain) && {
return std::move(*this).WithOptionalProtobufField(
field, ToOptionalUntypedProtoDomain(std::move(domain)));
}
template <typename RepeatedProtobufDomain>
ProtobufDomainImpl&& WithRepeatedProtobufField(
absl::string_view field, RepeatedProtobufDomain domain) && {
return std::move(*this).WithRepeatedProtobufField(
field, ToRepeatedUntypedProtoDomain(std::move(domain)));
}
private:
void FailIfIsOneof(absl::string_view field) {
const FieldDescriptor* descriptor = inner_.GetField(field);
FUZZTEST_INTERNAL_CHECK_PRECONDITION(
!descriptor->containing_oneof(), "Cannot customize oneof field ", field,
" with WithOptional<Type>Field (try using "
"WithOneofAlwaysSet or WithOptional<Type>Unset).");
}
template <typename Inner>
Domain<std::unique_ptr<typename T::Message>> ToUntypedProtoDomain(
Inner inner_domain) {
return BidiMap(
[](value_type_t<Inner> proto_message)
-> std::unique_ptr<typename T::Message> {
return {std::make_unique<value_type_t<Inner>>(proto_message)};
},
[](const std::unique_ptr<typename T::Message>& proto_message)
-> std::optional<std::tuple<value_type_t<Inner>>> {
return *static_cast<std::add_pointer_t<value_type_t<Inner>>>(
proto_message.get());
},
std::move(inner_domain));
}
template <typename Inner>
Domain<std::optional<std::unique_ptr<typename T::Message>>>
ToOptionalUntypedProtoDomain(Inner inner_domain) {
return BidiMap(
[](value_type_t<Inner> proto_message)
-> std::optional<std::unique_ptr<typename T::Message>> {
if (!proto_message.has_value()) return std::nullopt;
return std::make_unique<
std::remove_reference_t<decltype(*proto_message)>>(
*proto_message);
},
[](const std::optional<std::unique_ptr<typename T::Message>>&
proto_message)
-> std::optional<std::tuple<value_type_t<Inner>>> {
if (!proto_message.has_value()) return std::nullopt;
return *static_cast<
std::add_pointer_t<typename value_type_t<Inner>::value_type>>(
proto_message->get());
},
std::move(inner_domain));
}
// Adapts `inner_domain`, a domain over a container of concrete protobuf
// messages, into a domain over
// `std::vector<std::unique_ptr<typename T::Message>>`.
//
// Forward mapping: moves every generated message onto the heap, preserving
// order.
// Inverse mapping: downcasts every pointee to the concrete message type and
// copies it into a new container; yields `std::nullopt` if any entry is null.
template <typename Inner>
Domain<std::vector<std::unique_ptr<typename T::Message>>>
ToRepeatedUntypedProtoDomain(Inner inner_domain) {
  using Protos = value_type_t<Inner>;
  using Proto = typename Protos::value_type;
  using UntypedProtos = std::vector<std::unique_ptr<typename T::Message>>;
  return BidiMap(
      [](Protos proto_message) -> UntypedProtos {
        UntypedProtos result;
        result.reserve(proto_message.size());
        // The container is this lambda's own copy, so its elements can be
        // moved out instead of being deep-copied one more time.
        for (auto& entry : proto_message) {
          result.push_back(std::make_unique<Proto>(std::move(entry)));
        }
        return result;
      },
      [](const UntypedProtos& proto_message)
          -> std::optional<std::tuple<Protos>> {
        Protos result;
        for (const auto& entry : proto_message) {
          // A null entry holds no message to recover; report that instead of
          // dereferencing it.
          if (entry == nullptr) return std::nullopt;
          result.push_back(
              *static_cast<std::add_pointer_t<Proto>>(entry.get()));
        }
        return result;
      },
      std::move(inner_domain));
}
// Untyped implementation used by this class (e.g. for field lookup in
// `FailIfIsOneof`). Constructed from `T::default_instance()` with lazy
// initialization enabled.
UntypedImpl inner_{&T::default_instance(), /*use_lazy_initialization=*/true};
};
// `ArbitraryImpl` specialization selected when `is_protocol_buffer_v<T>` holds.
// It adds no members of its own and inherits everything from
// `ProtobufDomainImpl<T>`.
template <typename T>
class ArbitraryImpl<T, std::enable_if_t<is_protocol_buffer_v<T>>>
: public ProtobufDomainImpl<T> {};
// `ArbitraryImpl` specialization selected when `is_protocol_buffer_enum_v<T>`
// holds. Values are drawn from the numbers listed in the enum's descriptor.
template <typename T>
class ArbitraryImpl<T, std::enable_if_t<is_protocol_buffer_enum_v<T>>>
    : public DomainBase<ArbitraryImpl<T>> {
 public:
  using typename ArbitraryImpl::DomainBase::value_type;

  // Returns a seed when `MaybeGetRandomSeed` provides one; otherwise picks one
  // of the descriptor's values uniformly by index and returns its number.
  value_type Init(absl::BitGenRef prng) {
    if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
    const int index = absl::Uniform(prng, 0, descriptor()->value_count());
    return static_cast<T>(descriptor()->value(index)->number());
  }

  // Mutates `val` in place.
  //
  // When `only_shrink` is set, `val` is replaced by a randomly chosen declared
  // number that is strictly smaller than the current one, or left unchanged if
  // there is none. Otherwise `val` is re-drawn with `Init` until it differs
  // from its previous value; it is left unchanged when the enum offers no
  // alternative.
  void Mutate(value_type& val, absl::BitGenRef prng, bool only_shrink) {
    const int value_count = descriptor()->value_count();
    if (only_shrink) {
      std::vector<int> numbers;
      for (int i = 0; i < value_count; ++i) {
        if (int n = descriptor()->value(i)->number(); n < val)
          numbers.push_back(n);
      }
      if (numbers.empty()) return;
      size_t idx = absl::Uniform<size_t>(prng, 0, numbers.size());
      val = static_cast<T>(numbers[idx]);
      return;
    }
    if (value_count == 1) return;
    // Enums declared with `allow_alias` may list several values that all share
    // one number. If no declared number differs from `val`, re-drawing from
    // the descriptor can never produce a different value and the loop below
    // would not terminate, so bail out here.
    bool has_other_number = false;
    for (int i = 0; i < value_count; ++i) {
      if (descriptor()->value(i)->number() != val) {
        has_other_number = true;
        break;
      }
    }
    if (!has_other_number) return;
    // Make sure Mutate really mutates.
    const T prev = val;
    do {
      val = Init(prng);
    } while (val == prev);
  }

  // Returns a printer for values of this enum, backed by its descriptor.
  auto GetPrinter() const {
    return ProtobufEnumPrinter<decltype(descriptor())>{descriptor()};
  }

  // Accepts every corpus value.
  bool ValidateCorpusValue(const value_type&) const {
    return true;  // Any number is fine.
  }

 private:
  // Returns the descriptor of `T`, looked up once and cached in a
  // function-local static.
  auto descriptor() const {
    static auto const descriptor_ = google::protobuf::GetEnumDescriptor<T>();
    return descriptor_;
  }
};
} // namespace fuzztest::internal
#endif // FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_PROTOBUF_DOMAIN_IMPL_H_