blob: e8fa6400efdc2efc2fd6d6d31b3079f313ec47e8 [file] [log] [blame]
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FUZZTEST_FUZZTEST_INTERNAL_PROTOBUF_DOMAIN_H_
#define FUZZTEST_FUZZTEST_INTERNAL_PROTOBUF_DOMAIN_H_
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/random/bit_gen_ref.h"
#include "absl/random/random.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "./fuzztest/internal/domain.h"
#include "./fuzztest/internal/logging.h"
#include "./fuzztest/internal/meta.h"
#include "./fuzztest/internal/polymorphic_value.h"
#include "./fuzztest/internal/serialization.h"
#include "./fuzztest/internal/type_support.h"
namespace google::protobuf {
// Forward declarations only: this header deliberately avoids #including any
// protobuf header, so the protobuf entities it needs by name are declared
// here instead.
class EnumDescriptor;
// NOTE(review): not referenced in the visible part of this header; presumably
// used further down for enum-typed fields.
template <typename E>
const EnumDescriptor* GetEnumDescriptor();
} // namespace google::protobuf
namespace fuzztest::internal {
// The protobuf types used below are recovered from the `Message` API itself
// (`GetReflection()`, `GetDescriptor()`, `Descriptor::field()`) rather than
// being named directly. That keeps protobuf a soft dependency: this header
// never needs to #include a protobuf header.

// Pointee type of `Message::GetReflection()`; cv-qualifiers are preserved.
template <typename Message>
using ProtobufReflection =
    std::remove_pointer_t<decltype(std::declval<Message>().GetReflection())>;

// Pointee type of `Message::GetDescriptor()`; cv-qualifiers are preserved.
template <typename Message>
using ProtobufDescriptor =
    std::remove_pointer_t<decltype(std::declval<Message>().GetDescriptor())>;

// Pointee type of `field(int)` called through a descriptor pointer;
// cv-qualifiers are preserved.
template <typename Message>
using ProtobufFieldDescriptor = std::remove_pointer_t<
    decltype(std::declval<ProtobufDescriptor<Message>*>()->field(0))>;

// Typed reflection accessor for scalar fields whose C++ type is `V`. Only the
// per-type specializations are defined.
template <typename Message, typename V>
class ProtocolBufferAccess;
// Defines the `ProtocolBufferAccess<Message, cpp_type>` specialization.
// `Camel` is the type suffix of protobuf's reflection methods, so e.g. passing
// `Int32` routes through `GetInt32`, `SetInt32`, `GetRepeatedInt32`,
// `SetRepeatedInt32` and `AddInt32`.
#define FUZZTEST_INTERNAL_PROTO_ACCESS_(cpp_type, Camel)                       \
  template <typename Message>                                                  \
  class ProtocolBufferAccess<Message, cpp_type> {                              \
   public:                                                                     \
    using Reflection = ProtobufReflection<Message>;                            \
    using FieldDescriptor = ProtobufFieldDescriptor<Message>;                  \
    using value_type = cpp_type;                                               \
                                                                               \
    ProtocolBufferAccess(Message* msg, const FieldDescriptor* fd)              \
        : message_(msg), reflection_(msg->GetReflection()), field_(fd) {}      \
                                                                               \
    static auto GetField(const Message& msg, const FieldDescriptor* fd) {      \
      auto* refl = msg.GetReflection();                                        \
      return refl->Get##Camel(msg, fd);                                        \
    }                                                                          \
    static auto GetRepeatedField(const Message& msg,                           \
                                 const FieldDescriptor* fd, int index) {       \
      auto* refl = msg.GetReflection();                                        \
      return refl->GetRepeated##Camel(msg, fd, index);                         \
    }                                                                          \
    void SetField(const cpp_type& value) {                                     \
      reflection_->Set##Camel(message_, field_, value);                        \
    }                                                                          \
    void SetRepeatedField(int index, const cpp_type& value) {                  \
      reflection_->SetRepeated##Camel(message_, field_, index, value);         \
    }                                                                          \
    void AddRepeatedField(const cpp_type& value) {                             \
      reflection_->Add##Camel(message_, field_, value);                        \
    }                                                                          \
                                                                               \
   private:                                                                    \
    Message* message_;                                                         \
    Reflection* reflection_;                                                   \
    const FieldDescriptor* field_;                                             \
  }
FUZZTEST_INTERNAL_PROTO_ACCESS_(int32_t, Int32);
FUZZTEST_INTERNAL_PROTO_ACCESS_(uint32_t, UInt32);
FUZZTEST_INTERNAL_PROTO_ACCESS_(int64_t, Int64);
FUZZTEST_INTERNAL_PROTO_ACCESS_(uint64_t, UInt64);
FUZZTEST_INTERNAL_PROTO_ACCESS_(float, Float);
FUZZTEST_INTERNAL_PROTO_ACCESS_(double, Double);
FUZZTEST_INTERNAL_PROTO_ACCESS_(std::string, String);
FUZZTEST_INTERNAL_PROTO_ACCESS_(bool, Bool);
#undef FUZZTEST_INTERNAL_PROTO_ACCESS_
// Tag types standing in for the "enum" and "message" field kinds when a field
// is dispatched to a visitor by type. They are declared but never defined
// here.
struct ProtoEnumTag;
struct ProtoMessageTag;
// Converts the runtime type of `field` into a compile-time type `T` and calls
//   visitor.VisitSingular<T>(field)   when `field` is not repeated, or
//   visitor.VisitRepeated<T>(field)   when `field` is repeated.
// `T` is:
//   - the matching C++ type for bool, integral, floating point and string
//     fields;
//   - the tag type ProtoEnumTag for enum fields;
//   - the tag type ProtoMessageTag for message fields.
// Whatever the chosen visitor method returns is returned to the caller.
template <typename FieldDescriptor, typename Visitor>
auto VisitProtobufField(const FieldDescriptor* field, Visitor visitor) {
  const bool repeated = field->is_repeated();
  const auto kind = field->cpp_type();
  if (!repeated) {
    switch (kind) {
      case FieldDescriptor::CPPTYPE_BOOL:
        return visitor.template VisitSingular<bool>(field);
      case FieldDescriptor::CPPTYPE_INT32:
        return visitor.template VisitSingular<int32_t>(field);
      case FieldDescriptor::CPPTYPE_UINT32:
        return visitor.template VisitSingular<uint32_t>(field);
      case FieldDescriptor::CPPTYPE_INT64:
        return visitor.template VisitSingular<int64_t>(field);
      case FieldDescriptor::CPPTYPE_UINT64:
        return visitor.template VisitSingular<uint64_t>(field);
      case FieldDescriptor::CPPTYPE_FLOAT:
        return visitor.template VisitSingular<float>(field);
      case FieldDescriptor::CPPTYPE_DOUBLE:
        return visitor.template VisitSingular<double>(field);
      case FieldDescriptor::CPPTYPE_STRING:
        return visitor.template VisitSingular<std::string>(field);
      case FieldDescriptor::CPPTYPE_ENUM:
        return visitor.template VisitSingular<ProtoEnumTag>(field);
      case FieldDescriptor::CPPTYPE_MESSAGE:
        return visitor.template VisitSingular<ProtoMessageTag>(field);
    }
  } else {
    switch (kind) {
      case FieldDescriptor::CPPTYPE_BOOL:
        return visitor.template VisitRepeated<bool>(field);
      case FieldDescriptor::CPPTYPE_INT32:
        return visitor.template VisitRepeated<int32_t>(field);
      case FieldDescriptor::CPPTYPE_UINT32:
        return visitor.template VisitRepeated<uint32_t>(field);
      case FieldDescriptor::CPPTYPE_INT64:
        return visitor.template VisitRepeated<int64_t>(field);
      case FieldDescriptor::CPPTYPE_UINT64:
        return visitor.template VisitRepeated<uint64_t>(field);
      case FieldDescriptor::CPPTYPE_FLOAT:
        return visitor.template VisitRepeated<float>(field);
      case FieldDescriptor::CPPTYPE_DOUBLE:
        return visitor.template VisitRepeated<double>(field);
      case FieldDescriptor::CPPTYPE_STRING:
        return visitor.template VisitRepeated<std::string>(field);
      case FieldDescriptor::CPPTYPE_ENUM:
        return visitor.template VisitRepeated<ProtoEnumTag>(field);
      case FieldDescriptor::CPPTYPE_MESSAGE:
        return visitor.template VisitRepeated<ProtoMessageTag>(field);
    }
  }
}
// Looks up field `number` on `prototype`'s message type. Regular fields found
// through the descriptor take priority. Otherwise the result of the
// reflection's known-extension lookup is returned, which callers must be
// prepared to see as null for unknown numbers.
template <typename Message>
auto GetProtobufField(const Message* prototype, int number) {
  if (auto* regular = prototype->GetDescriptor()->FindFieldByNumber(number)) {
    return regular;
  }
  return prototype->GetReflection()->FindKnownExtensionByNumber(number);
}
// Returns a predicate over `const T*` that accepts everything. It serves as
// the catch-all filter for policy rules.
template <typename T>
auto IncludeAll() {
  return [](const T* /*field*/) { return true; };
}
// Returns a domain transformer that hands back its input domain unchanged.
template <typename T>
std::function<Domain<T>(Domain<T>)> Identity() {
  return [](Domain<T> unchanged) { return unchanged; };
}
// Per-field configuration for the protobuf domain.
//
// Each setting is stored as an ordered list of (filter, value, is_recursive)
// rules. When a field is queried, the most recently added rule whose filter
// accepts the field wins. GetChildrenPolicy() keeps only the rules registered
// with `is_recursive == true`; the others apply only to the policy they were
// added to.
template <typename Message>
class ProtoPolicy {
  using FieldDescriptor = ProtobufFieldDescriptor<Message>;
  using Filter = std::function<bool(const FieldDescriptor*)>;

 public:
  // Starts with a single catch-all rule that allows every optional field to
  // be null. Repeated-field sizes, default domains and transformers start
  // with no rules.
  ProtoPolicy()
      : optional_policies_({{.filter = IncludeAll<FieldDescriptor>(),
                             .value = OptionalPolicy::kWithNull}}) {}

  // Returns the policy to hand to child (sub-message) domains: a copy of
  // this policy that keeps only the recursive rules.
  ProtoPolicy GetChildrenPolicy() const {
    ProtoPolicy children_policy;
    children_policy.optional_policies_ =
        CopyRecursivePolicyValues(optional_policies_);
    children_policy.min_repeated_fields_sizes_ =
        CopyRecursivePolicyValues(min_repeated_fields_sizes_);
    children_policy.max_repeated_fields_sizes_ =
        CopyRecursivePolicyValues(max_repeated_fields_sizes_);
    children_policy.domains_for_Bool_ =
        CopyRecursivePolicyValues(domains_for_Bool_);
    children_policy.domains_for_Int32_ =
        CopyRecursivePolicyValues(domains_for_Int32_);
    children_policy.domains_for_UInt32_ =
        CopyRecursivePolicyValues(domains_for_UInt32_);
    children_policy.domains_for_Int64_ =
        CopyRecursivePolicyValues(domains_for_Int64_);
    children_policy.domains_for_UInt64_ =
        CopyRecursivePolicyValues(domains_for_UInt64_);
    children_policy.domains_for_Float_ =
        CopyRecursivePolicyValues(domains_for_Float_);
    children_policy.domains_for_Double_ =
        CopyRecursivePolicyValues(domains_for_Double_);
    children_policy.domains_for_String_ =
        CopyRecursivePolicyValues(domains_for_String_);
    children_policy.domains_for_Enum_ =
        CopyRecursivePolicyValues(domains_for_Enum_);
    children_policy.domains_for_Protobuf_ =
        CopyRecursivePolicyValues(domains_for_Protobuf_);
    children_policy.transformers_for_Bool_ =
        CopyRecursivePolicyValues(transformers_for_Bool_);
    children_policy.transformers_for_Int32_ =
        CopyRecursivePolicyValues(transformers_for_Int32_);
    children_policy.transformers_for_UInt32_ =
        CopyRecursivePolicyValues(transformers_for_UInt32_);
    children_policy.transformers_for_Int64_ =
        CopyRecursivePolicyValues(transformers_for_Int64_);
    children_policy.transformers_for_UInt64_ =
        CopyRecursivePolicyValues(transformers_for_UInt64_);
    children_policy.transformers_for_Float_ =
        CopyRecursivePolicyValues(transformers_for_Float_);
    children_policy.transformers_for_Double_ =
        CopyRecursivePolicyValues(transformers_for_Double_);
    children_policy.transformers_for_String_ =
        CopyRecursivePolicyValues(transformers_for_String_);
    children_policy.transformers_for_Enum_ =
        CopyRecursivePolicyValues(transformers_for_Enum_);
    children_policy.transformers_for_Protobuf_ =
        CopyRecursivePolicyValues(transformers_for_Protobuf_);
    return children_policy;
  }

  // Sets the null-policy of every optional field.
  void SetOptionalPolicy(OptionalPolicy optional_policy,
                         bool is_recursive = true) {
    SetOptionalPolicy(IncludeAll<FieldDescriptor>(), optional_policy,
                      is_recursive);
  }
  // Sets the null-policy of the optional fields accepted by `filter`.
  void SetOptionalPolicy(Filter filter, OptionalPolicy optional_policy,
                         bool is_recursive = true) {
    optional_policies_.push_back({.filter = std::move(filter),
                                  .value = optional_policy,
                                  .is_recursive = is_recursive});
  }
  // Sets the minimum size of every repeated field.
  void SetMinRepeatedFieldsSize(int64_t min_size, bool is_recursive = true) {
    SetMinRepeatedFieldsSize(IncludeAll<FieldDescriptor>(), min_size,
                             is_recursive);
  }
  // Sets the minimum size of the repeated fields accepted by `filter`.
  void SetMinRepeatedFieldsSize(Filter filter, int64_t min_size,
                                bool is_recursive = true) {
    min_repeated_fields_sizes_.push_back({.filter = std::move(filter),
                                          .value = min_size,
                                          .is_recursive = is_recursive});
  }
  // Sets the maximum size of every repeated field.
  void SetMaxRepeatedFieldsSize(int64_t max_size, bool is_recursive = true) {
    SetMaxRepeatedFieldsSize(IncludeAll<FieldDescriptor>(), max_size,
                             is_recursive);
  }
  // Sets the maximum size of the repeated fields accepted by `filter`.
  void SetMaxRepeatedFieldsSize(Filter filter, int64_t max_size,
                                bool is_recursive = true) {
    max_repeated_fields_sizes_.push_back({.filter = std::move(filter),
                                          .value = max_size,
                                          .is_recursive = is_recursive});
  }

  // Returns the null-policy for the optional `field`. CHECK-fails if `field`
  // is not optional or if no rule matches it.
  OptionalPolicy GetOptionalPolicy(const FieldDescriptor* field) const {
    FUZZTEST_INTERNAL_CHECK(
        field->is_optional(),
        "GetOptionalPolicy should apply to optional fields only!");
    std::optional<OptionalPolicy> result =
        GetPolicyValue(optional_policies_, field);
    FUZZTEST_INTERNAL_CHECK(result.has_value(), "optional policy is not set!");
    return *result;
  }
  // Returns the configured minimum size for the repeated `field`, if any.
  std::optional<int64_t> GetMinRepeatedFieldSize(
      const FieldDescriptor* field) const {
    FUZZTEST_INTERNAL_CHECK(
        field->is_repeated(),
        "GetMinRepeatedFieldSize should apply to repeated fields only!");
    return GetPolicyValue(min_repeated_fields_sizes_, field);
  }
  // Returns the configured maximum size for the repeated `field`, if any.
  std::optional<int64_t> GetMaxRepeatedFieldSize(
      const FieldDescriptor* field) const {
    FUZZTEST_INTERNAL_CHECK(
        field->is_repeated(),
        "GetMaxRepeatedFieldSize should apply to repeated fields only!");
    return GetPolicyValue(max_repeated_fields_sizes_, field);
  }

 private:
  // One policy rule: `value` applies to the fields accepted by `filter`.
  // Rules with `is_recursive` set are inherited by GetChildrenPolicy().
  template <typename T>
  struct FilterToValue {
    Filter filter;
    T value;
    bool is_recursive = true;
  };

  // Returns the subset of `filter_to_values` that is marked recursive,
  // preserving order.
  template <typename T>
  std::vector<FilterToValue<T>> CopyRecursivePolicyValues(
      const std::vector<FilterToValue<T>>& filter_to_values) const {
    std::vector<FilterToValue<T>> result;
    for (const auto& filter_to_value : filter_to_values) {
      if (filter_to_value.is_recursive) {
        result.push_back(filter_to_value);
      }
    }
    return result;
  }

  // Returns the value of the most recently added rule that accepts `field`,
  // or std::nullopt if none does.
  template <typename T>
  std::optional<T> GetPolicyValue(
      const std::vector<FilterToValue<T>>& filter_to_values,
      const FieldDescriptor* field) const {
    // Rules are appended in registration order, so scan from the back: a
    // later matching rule overrides earlier ones.
    for (auto it = filter_to_values.rbegin(); it != filter_to_values.rend();
         ++it) {
      if (!it->filter(field)) continue;
      if constexpr (std::is_same_v<T, Domain<std::unique_ptr<Message>>>) {
        // Sanity-check a custom message domain by generating one sample and
        // comparing its message type with the field's message type.
        absl::BitGen gen;
        auto domain = it->value;
        auto obj = domain.GetValue(domain.Init(gen));
        auto* descriptor = obj->GetDescriptor();
        FUZZTEST_INTERNAL_CHECK_PRECONDITION(
            descriptor->full_name() == field->message_type()->full_name(),
            "Input domain does not match the expected message type. The "
            "domain produced a message of type `",
            descriptor->full_name(),
            "` but the field needs a message of type `",
            field->message_type()->full_name(), "`.");
      }
      return it->value;
    }
    return std::nullopt;
  }

  std::vector<FilterToValue<OptionalPolicy>> optional_policies_;
  std::vector<FilterToValue<int64_t>> min_repeated_fields_sizes_;
  std::vector<FilterToValue<int64_t>> max_repeated_fields_sizes_;

// For a field type `Camel` with C++ type `cpp`, declares:
//   - private rule lists for default domains and domain transformers;
//   - SetDefaultDomainFor<Camel>s / SetDomainTransformerFor<Camel>s, which
//     append a rule and record the caller's `is_recursive` flag (it used to
//     be ignored and always stored as `true`, so non-recursive per-field
//     customizations leaked into child policies);
//   - GetDefaultDomainFor<Camel>s / GetDomainTransformerFor<Camel>s, which
//     return the winning rule's value for a field.
// The macro leaves the access level at `public:`.
#define FUZZTEST_INTERNAL_POLICY_MEMBERS(Camel, cpp)                           \
 private:                                                                      \
  std::vector<FilterToValue<Domain<cpp>>> domains_for_##Camel##_;              \
  std::vector<FilterToValue<std::function<Domain<cpp>(Domain<cpp>)>>>          \
      transformers_for_##Camel##_;                                             \
                                                                               \
 public:                                                                       \
  void SetDefaultDomainFor##Camel##s(                                          \
      Domain<MakeDependentType<cpp, Message>> domain,                          \
      bool is_recursive = true) {                                              \
    SetDefaultDomainFor##Camel##s(IncludeAll<FieldDescriptor>(),               \
                                  std::move(domain), is_recursive);            \
  }                                                                            \
  void SetDefaultDomainFor##Camel##s(                                          \
      Filter filter, Domain<MakeDependentType<cpp, Message>> domain,           \
      bool is_recursive = true) {                                              \
    domains_for_##Camel##_.push_back({.filter = std::move(filter),             \
                                      .value = std::move(domain),              \
                                      .is_recursive = is_recursive});          \
  }                                                                            \
  void SetDomainTransformerFor##Camel##s(                                      \
      Filter filter,                                                           \
      std::function<Domain<MakeDependentType<cpp, Message>>(                   \
          Domain<MakeDependentType<cpp, Message>>)>                            \
          transformer,                                                         \
      bool is_recursive = true) {                                              \
    transformers_for_##Camel##_.push_back({.filter = std::move(filter),        \
                                           .value = std::move(transformer),    \
                                           .is_recursive = is_recursive});     \
  }                                                                            \
  std::optional<Domain<MakeDependentType<cpp, Message>>>                       \
  GetDefaultDomainFor##Camel##s(const FieldDescriptor* field) const {          \
    return GetPolicyValue(domains_for_##Camel##_, field);                      \
  }                                                                            \
  std::optional<std::function<Domain<MakeDependentType<cpp, Message>>(         \
      Domain<MakeDependentType<cpp, Message>>)>>                               \
  GetDomainTransformerFor##Camel##s(const FieldDescriptor* field) const {      \
    return GetPolicyValue(transformers_for_##Camel##_, field);                 \
  }

  FUZZTEST_INTERNAL_POLICY_MEMBERS(Bool, bool)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(Int32, int32_t)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(UInt32, uint32_t)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(Int64, int64_t)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(UInt64, uint64_t)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(Float, float)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(Double, double)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(String, std::string)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(Enum, int)
  FUZZTEST_INTERNAL_POLICY_MEMBERS(Protobuf, std::unique_ptr<Message>)
};
// Domain for std::unique_ptr<Message>, where the prototype is accepted as a
// constructor argument.
template <typename Message>
class ProtobufDomainUntypedImpl
: public DomainBase<ProtobufDomainUntypedImpl<Message>> {
using Descriptor = ProtobufDescriptor<Message>;
using FieldDescriptor = ProtobufFieldDescriptor<Message>;
public:
// Corpus representation: field number -> corpus value produced by that
// field's sub-domain. GetValue() only sets the fields that have an entry.
using corpus_type = absl::flat_hash_map<int, GenericDomainCorpusType>;
// Generated values are heap-allocated messages created from the prototype.
using value_type = std::unique_ptr<Message>;
// NOTE(review): presumably signals to the domain framework that corpus_type
// differs from value_type, so GetValue()/FromValue() conversions are needed.
static constexpr bool has_custom_corpus_type = true;
// Creates a domain whose values are messages of `prototype`'s type, with a
// default policy and no per-field customizations.
// NOTE(review): `prototype` is not copied here; presumably it must outlive
// this domain.
explicit ProtobufDomainUntypedImpl(const Message* prototype)
    : prototype_{prototype}, policy_{}, are_fields_customized_{false} {}
// Copies the configuration and the lazily built sub-domain cache of `other`.
// The cache and the rest of the mutable state are read while holding
// `other.mutex_`.
ProtobufDomainUntypedImpl(const ProtobufDomainUntypedImpl& other) {
  prototype_ = other.prototype_;
  absl::MutexLock lock(&other.mutex_);
  policy_ = other.policy_;
  are_fields_customized_ = other.are_fields_customized_;
  domains_ = other.domains_;
}
// Seeds the corpus entry of the singular `field` from its sub-domain. If
// `field` belongs to a oneof, every other member of that oneof is removed
// from `val`: only one member can be set, so keeping the others is useless
// and mutating them would have no effect.
template <typename T>
static void InitializeFieldValue(absl::BitGenRef prng,
                                 const ProtobufDomainUntypedImpl& self,
                                 const FieldDescriptor* field,
                                 corpus_type& val) {
  auto& sub_domain = self.GetSubDomain<T, false>(field);
  val[field->number()] = sub_domain.Init(prng);
  const auto* oneof = field->containing_oneof();
  if (oneof == nullptr) return;
  const int kept_index = field->index_in_oneof();
  for (int i = 0; i < oneof->field_count(); ++i) {
    if (i == kept_index) continue;
    val.erase(oneof->field(i)->number());
  }
}
struct InitializeVisitor {
absl::BitGenRef prng;
ProtobufDomainUntypedImpl& self;
corpus_type& val;
template <typename T>
void VisitSingular(const FieldDescriptor* field) {
InitializeFieldValue<T>(prng, self, field, val);
}
template <typename T>
void VisitRepeated(const FieldDescriptor* field) {
auto& domain = self.GetSubDomain<T, true>(field);
val[field->number()] = domain.Init(prng);
}
};
// Builds an initial corpus by seeding every field of the prototype's type.
//
// Non-required recursive fields are skipped unless some field has been
// customized, because initializing them could otherwise never terminate. As a
// consequence, a message whose fields are all non-required and recursive
// initializes deterministically. That breaks the usual expectation of
// Init(), but such messages should be extremely rare and the consequences are
// mild.
template <typename PRNG>
corpus_type Init(PRNG& prng) {
  FUZZTEST_INTERNAL_CHECK(
      are_fields_customized_ || !IsNonTerminatingRecursive(),
      "Cannot set recursive fields by default.");
  // TODO(b/241124202): Use a valid proto with minimum size.
  corpus_type val;
  const auto* descriptor = prototype_->GetDescriptor();
  const int field_count = descriptor->field_count();
  for (int index = 0; index < field_count; ++index) {
    const auto* field = descriptor->field(index);
    const bool skip_recursive_field = !field->is_required() &&
                                      !are_fields_customized_ &&
                                      IsFieldRecursive(field);
    if (skip_recursive_field) continue;
    VisitProtobufField(field, InitializeVisitor{prng, *this, val});
  }
  return val;
}
// Field visitor that mutates one field's entry in the corpus `val`.
// A missing entry is created (unless `only_shrink` is set); an existing entry
// is mutated through the field's sub-domain.
struct MutateVisitor {
  absl::BitGenRef prng;
  // Forwarded to the sub-domain Mutate() calls. When true, missing fields
  // are not added.
  bool only_shrink;
  ProtobufDomainUntypedImpl& self;
  corpus_type& val;
  template <typename T>
  void VisitSingular(const FieldDescriptor* field) {
    auto& domain = self.GetSubDomain<T, false>(field);
    auto it = val.find(field->number());
    const bool is_present = it != val.end();
    if (is_present) {
      // Mutate the element
      domain.Mutate(it->second, prng, only_shrink);
      return;
    }
    // Add the element (oneof-aware, see InitializeFieldValue).
    if (!only_shrink) {
      InitializeFieldValue<T>(prng, self, field, val);
    }
  }
  template <typename T>
  void VisitRepeated(const FieldDescriptor* field) {
    auto& domain = self.GetSubDomain<T, true>(field);
    auto it = val.find(field->number());
    const bool is_present = it != val.end();
    if (!is_present) {
      // No entry yet: seed the whole container unless only shrinking.
      if (!only_shrink) {
        val[field->number()] = domain.Init(prng);
      }
    } else if (field->is_map()) {
      // A map field is represented as a repeated field of synthetic messages
      // of the form:
      //
      // message {
      // optional key_type key = 1;
      // optional value_type value = 2;
      // }
      //
      // The generic code is not doing dedup and it would only happen some
      // time after GetValue when the map field is synchronized between
      // reflection and codegen. Let's do it eagerly to drop dead entries so
      // that we don't keep mutating them later.
      //
      // The mutation is applied to a copy. The copy is accepted only if no
      // entry was lost when deduplicating through a serialization roundtrip;
      // otherwise the mutation is discarded.
      //
      // TODO(b/231212420): Improve mutation to not add duplicate keys on the
      // first place. The current hack is very inefficient.
      // Switch the inner domain for maps to use flat_hash_map instead.
      corpus_type corpus_copy;
      auto& copy = corpus_copy[field->number()] = it->second;
      domain.Mutate(copy, prng, only_shrink);
      auto v = self.GetValue(corpus_copy);
      // We need to roundtrip through serialization to really dedup. The
      // reflection API alone doesn't cut it.
      v->ParsePartialFromString(v->SerializeAsString());
      if (v->GetReflection()->FieldSize(*v, field) ==
          domain.GetValue(copy).size()) {
        // The number of entries is the same, so accept the change.
        it->second = std::move(copy);
      }
    } else {
      domain.Mutate(it->second, prng, only_shrink);
    }
  }
};
// Returns the number of mutation targets in `val`. Each field of the
// prototype's type counts once. Each message field that has an entry in `val`
// also adds the count reported by its sub-domain for that entry.
uint64_t CountNumberOfFields(corpus_type& val) {
  uint64_t weight = 0;
  auto* descriptor = prototype_->GetDescriptor();
  for (int index = 0; index < descriptor->field_count(); ++index) {
    const FieldDescriptor* field = descriptor->field(index);
    ++weight;
    if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) continue;
    auto entry = val.find(field->number());
    if (entry == val.end()) continue;
    if (!field->is_repeated()) {
      weight += GetSubDomain<ProtoMessageTag, false>(field).CountNumberOfFields(
          entry->second);
    } else {
      weight += GetSubDomain<ProtoMessageTag, true>(field).CountNumberOfFields(
          entry->second);
    }
  }
  return weight;
}
// Mutates the field at the 1-based position `selected_field_index`. Positions
// follow the same order and weights as CountNumberOfFields(): each field of
// this message takes one position, and each message field present in `val` is
// followed by the positions of its sub-message's fields.
//
// Returns the number of positions consumed. Once the selected field has been
// mutated this reaches `selected_field_index`. Otherwise it is the full count
// for this message. NOTE(review): this presumably relies on sub-domains'
// MutateSelectedField following the same convention.
template <typename PRNG>
uint64_t MutateSelectedField(corpus_type& val, PRNG& prng, bool only_shrink,
                             uint64_t selected_field_index) {
  uint64_t field_counter = 0;
  auto* descriptor = prototype_->GetDescriptor();
  if (descriptor->field_count() == 0) return field_counter;
  for (int i = 0; i < descriptor->field_count(); ++i) {
    FieldDescriptor* field = descriptor->field(i);
    ++field_counter;
    if (field_counter == selected_field_index) {
      // The selected position is this field itself: mutate it as a whole.
      VisitProtobufField(field, MutateVisitor{prng, only_shrink, *this, val});
      return field_counter;
    }
    if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
      auto val_it = val.find(field->number());
      if (val_it == val.end()) continue;
      // Recurse into the sub-message. The index is rebased so that 1 is the
      // first position after this field.
      if (field->is_repeated()) {
        field_counter +=
            GetSubDomain<ProtoMessageTag, true>(field).MutateSelectedField(
                val_it->second, prng, only_shrink,
                selected_field_index - field_counter);
      } else {
        field_counter +=
            GetSubDomain<ProtoMessageTag, false>(field).MutateSelectedField(
                val_it->second, prng, only_shrink,
                selected_field_index - field_counter);
      }
    }
    // Stop once the selected position has been consumed by a sub-message.
    if (field_counter >= selected_field_index) return field_counter;
  }
  return field_counter;
}
// Mutates exactly one field of `val`. The field is chosen by drawing a 1-based
// position uniformly from the positions counted by CountNumberOfFields(),
// nested sub-message fields included. A message type without fields is left
// untouched.
template <typename PRNG>
void Mutate(corpus_type& val, PRNG& prng, bool only_shrink) {
  if (prototype_->GetDescriptor()->field_count() == 0) return;
  // TODO(JunyangShao): Maybe make CountNumberOfFields static.
  const uint64_t total_weight = CountNumberOfFields(val);
  const uint64_t selected = absl::Uniform(absl::IntervalClosedClosed, prng,
                                          uint64_t{1}, total_weight);
  MutateSelectedField(val, prng, only_shrink, selected);
}
// Field visitor used by GetValue(): writes the value decoded from one corpus
// entry (`data`) into `message` through reflection.
struct GetValueVisitor {
  Message& message;
  const ProtobufDomainUntypedImpl& self;
  const GenericDomainCorpusType& data;

  // Sets the singular `field`, or clears it when the decoded optional value
  // is empty. Message values are handed over to `message` (ownership moves).
  template <typename T>
  void VisitSingular(const FieldDescriptor* field) {
    auto& sub_domain = self.GetSubDomain<T, false>(field);
    auto decoded = sub_domain.GetValue(data);
    auto* reflection = message.GetReflection();
    if (!decoded.has_value()) {
      reflection->ClearField(&message, field);
      return;
    }
    if constexpr (std::is_same_v<T, ProtoMessageTag>) {
      reflection->SetAllocatedMessage(&message, decoded->release(), field);
    } else if constexpr (std::is_same_v<T, ProtoEnumTag>) {
      reflection->SetEnumValue(&message, field, *decoded);
    } else {
      ProtocolBufferAccess<Message, T>(&message, field).SetField(*decoded);
    }
  }

  // Appends every decoded element to the repeated `field`. Message elements
  // are handed over to `message` (ownership moves).
  template <typename T>
  void VisitRepeated(const FieldDescriptor* field) {
    auto& sub_domain = self.GetSubDomain<T, true>(field);
    auto&& elements = sub_domain.GetValue(data);
    auto* reflection = message.GetReflection();
    if constexpr (std::is_same_v<T, ProtoMessageTag>) {
      for (auto& element : elements) {
        reflection->AddAllocatedMessage(&message, field, element.release());
      }
    } else if constexpr (std::is_same_v<T, ProtoEnumTag>) {
      for (const auto& element : elements) {
        reflection->AddEnumValue(&message, field, element);
      }
    } else {
      for (const auto& element : elements) {
        ProtocolBufferAccess<Message, T>(&message, field)
            .AddRepeatedField(element);
      }
    }
  }
};
// Materializes a message from `value`. The result starts as a fresh message
// created by the prototype, and every field that has a corpus entry is then
// populated.
value_type GetValue(const corpus_type& value) const {
  value_type result(prototype_->New());
  for (const auto& entry : value) {
    const int number = entry.first;
    auto* field = GetProtobufField(prototype_, number);
    VisitProtobufField(field, GetValueVisitor{*result, *this, entry.second});
  }
  return result;
}
std::optional<corpus_type> FromValue(const value_type& value) const {
return FromValue(*value);
}
struct FromValueVisitor {
const Message& message;
corpus_type& out;
const ProtobufDomainUntypedImpl& self;
// TODO(sbenzaquen): We might want to try avoid these copies. On the other hand,
// FromValue is not called much so it might be ok.
template <typename T>
bool VisitSingular(const FieldDescriptor* field) {
auto& domain = self.GetSubDomain<T, false>(field);
typename std::decay_t<decltype(domain)>::value_type inner_value;
auto* reflection = message.GetReflection();
if constexpr (std::is_same_v<T, ProtoMessageTag>) {
const auto& child = reflection->GetMessage(message, field);
inner_value = std::unique_ptr<Message>(child.New());
(*inner_value)->CopyFrom(child);
} else if constexpr (std::is_same_v<T, ProtoEnumTag>) {
inner_value = reflection->GetEnum(message, field)->number();
} else {
inner_value =
ProtocolBufferAccess<Message, T>::GetField(message, field);
}
auto inner = domain.FromValue(inner_value);
if (!inner) return false;
out[field->number()] = *std::move(inner);
return true;
}
template <typename T>
bool VisitRepeated(const FieldDescriptor* field) {
auto& domain = self.GetSubDomain<T, true>(field);
typename std::decay_t<decltype(domain)>::value_type inner_value;
auto* reflection = message.GetReflection();
const int size = reflection->FieldSize(message, field);
for (int i = 0; i < size; ++i) {
if constexpr (std::is_same_v<T, ProtoMessageTag>) {
const auto& child = reflection->GetRepeatedMessage(message, field, i);
auto* copy = child.New();
copy->CopyFrom(child);
inner_value.emplace_back(copy);
} else if constexpr (std::is_same_v<T, ProtoEnumTag>) {
inner_value.push_back(
reflection->GetRepeatedEnum(message, field, i)->number());
} else {
inner_value.push_back(
ProtocolBufferAccess<Message, T>::GetRepeatedField(message, field,
i));
}
}
auto inner = domain.FromValue(inner_value);
if (!inner) return false;
out[field->number()] = *std::move(inner);
return true;
}
};
std::optional<corpus_type> FromValue(const Message& value) const {
corpus_type ret;
auto* reflection = value.GetReflection();
std::vector<const FieldDescriptor*> fields;
reflection->ListFields(value, &fields);
for (auto field : fields) {
if (!VisitProtobufField(field, FromValueVisitor{value, ret, *this}))
return std::nullopt;
}
return ret;
}
// Field visitor used by ParseCorpus(): parses `obj` with the field's
// sub-domain and stores the result (possibly std::nullopt) in `out`.
struct ParseVisitor {
  const ProtobufDomainUntypedImpl& self;
  const IRObject& obj;
  std::optional<GenericDomainCorpusType>& out;

  template <typename T>
  void VisitSingular(const FieldDescriptor* field) {
    auto& sub_domain = self.GetSubDomain<T, false>(field);
    out = sub_domain.ParseCorpus(obj);
  }

  template <typename T>
  void VisitRepeated(const FieldDescriptor* field) {
    auto& sub_domain = self.GetSubDomain<T, true>(field);
    out = sub_domain.ParseCorpus(obj);
  }
};
// Parses the IR produced by SerializeCorpus(): a list of two-element
// [field_number, field_ir] entries. Returns std::nullopt on any structural
// error, unknown field number, or sub-domain parse failure.
std::optional<corpus_type> ParseCorpus(const IRObject& obj) const {
  const auto entries = obj.Subs();
  if (!entries) return std::nullopt;
  corpus_type out;
  for (const auto& entry : *entries) {
    const auto kv = entry.Subs();
    if (!kv || kv->size() != 2) return std::nullopt;
    const auto number = (*kv)[0].GetScalar<int>();
    if (!number) return std::nullopt;
    const auto* field = GetProtobufField(prototype_, *number);
    if (field == nullptr) return std::nullopt;
    std::optional<GenericDomainCorpusType> parsed;
    VisitProtobufField(field, ParseVisitor{*this, (*kv)[1], parsed});
    if (!parsed) return std::nullopt;
    out[*number] = *std::move(parsed);
  }
  return out;
}
// Field visitor used by SerializeCorpus(): serializes one field's corpus value
// with the field's sub-domain into `out`.
struct SerializeVisitor {
  const ProtobufDomainUntypedImpl& self;
  const GenericDomainCorpusType& corpus_value;
  IRObject& out;

  template <typename T>
  void VisitSingular(const FieldDescriptor* field) {
    auto& sub_domain = self.GetSubDomain<T, false>(field);
    out = sub_domain.SerializeCorpus(corpus_value);
  }

  template <typename T>
  void VisitRepeated(const FieldDescriptor* field) {
    auto& sub_domain = self.GetSubDomain<T, true>(field);
    out = sub_domain.SerializeCorpus(corpus_value);
  }
};
// Serializes `v` as a list of [field_number, field_ir] entries, one per corpus
// entry. CHECK-fails if an entry refers to an unknown field number.
IRObject SerializeCorpus(const corpus_type& v) const {
  IRObject result;
  auto& entries = result.MutableSubs();
  for (const auto& [number, inner] : v) {
    auto* field = GetProtobufField(prototype_, number);
    FUZZTEST_INTERNAL_CHECK(field, "Field not found by number: ", number);
    auto& kv = entries.emplace_back().MutableSubs();
    kv.emplace_back(number);
    IRObject& field_ir = kv.emplace_back();
    VisitProtobufField(field, SerializeVisitor{*this, inner, field_ir});
  }
  return result;
}
// Returns the printer used to render values of this domain.
auto GetPrinter() const { return ProtobufPrinter{}; }
// Visitor used by WithField() to install a user-provided domain for a single
// field. It checks that the domain fits the field's type. It then records the
// domain's settings as policy rules that match only this field (by full name),
// each registered with `/*is_recursive=*/false`.
template <typename Inner>
struct WithFieldVisitor {
  // The user-provided domain. WithField() instantiates this visitor with a
  // reference type, so this refers to the caller's domain object.
  Inner domain;
  ProtobufDomainUntypedImpl& self;
  // `T` is the visitation type of the field, `DomainT` the type of the
  // field's default domain, and `is_repeated` whether the field is repeated.
  template <typename T, typename DomainT, bool is_repeated>
  void ApplyDomain(const FieldDescriptor* field) {
    if constexpr (!std::is_constructible_v<DomainT, Inner>) {
      // Type mismatch: the condition below is always false here, so this
      // always reports the error.
      FUZZTEST_INTERNAL_CHECK_PRECONDITION(
          (std::is_constructible_v<DomainT, Inner>),
          "Input domain does not match field `", field->full_name(),
          "` type.");
    } else {
      if constexpr (std::is_same_v<T, ProtoMessageTag>) {
        // Verify that the type matches.
        absl::BitGen gen;
        auto inner_domain = domain.Inner();
        auto obj = inner_domain.GetValue(inner_domain.Init(gen));
        auto* descriptor = obj->GetDescriptor();
        FUZZTEST_INTERNAL_CHECK_PRECONDITION(
            descriptor->full_name() == field->message_type()->full_name(),
            "Input domain does not match the expected message type. The "
            "domain produced a message of type `",
            descriptor->full_name(),
            "` but the field needs a message of type `",
            field->message_type()->full_name(), "`.");
      }
      // Matches only this exact field, identified by its full name.
      auto field_filter =
          [full_name = field->full_name()](const FieldDescriptor* field) {
            return field->full_name() == full_name;
          };
      if constexpr (is_repeated) {
        // Size bounds come from the user's container domain.
        self.GetPolicy().SetMinRepeatedFieldsSize(
            field_filter, domain.min_size(), /*is_recursive=*/false);
        self.GetPolicy().SetMaxRepeatedFieldsSize(
            field_filter, domain.max_size(), /*is_recursive=*/false);
      } else {
        if (field->is_required()) {
          // A required field must always be set, so only a domain without
          // null values is accepted.
          FUZZTEST_INTERNAL_CHECK_PRECONDITION(
              domain.policy() == OptionalPolicy::kWithoutNull,
              "required field '", field->full_name(),
              "' cannot have null values.");
        } else {
          self.GetPolicy().SetOptionalPolicy(field_filter, domain.policy(),
                                             /*is_recursive=*/false);
        }
      }
// For the field's type, installs `domain.Inner()` as the field's base domain
// and an Identity transformer. Policy lookups pick the most recently added
// matching rule, so these override rules registered earlier for this field.
#define FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(Camel, cpp, TAG) \
  if constexpr (std::is_same_v<T, TAG>) {                           \
    self.GetPolicy().SetDefaultDomainFor##Camel##s(                  \
        field_filter, domain.Inner(), /*is_recursive=*/false);       \
    self.GetPolicy().SetDomainTransformerFor##Camel##s(              \
        field_filter, Identity<cpp>(), /*is_recursive=*/false);      \
  }
      FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(Bool, bool, bool)
      FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(Int32, int32_t, int32_t)
      FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(UInt32, uint32_t, uint32_t)
      FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(Int64, int64_t, int64_t)
      FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(UInt64, uint64_t, uint64_t)
      FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(Float, float, float)
      FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(Double, double, double)
      FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(String, std::string,
                                                  std::string)
      FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(Enum, int, ProtoEnumTag)
      FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(
          Protobuf, std::unique_ptr<Message>, ProtoMessageTag)
    }
  }
  // Applies the user domain to a singular field.
  template <typename T>
  void VisitSingular(const FieldDescriptor* field) {
    using DomainT = decltype(self.GetDefaultDomainForField<T, false>(field));
    ApplyDomain<T, DomainT, /*is_repeated=*/false>(field);
  }
  // Applies the user domain to a repeated field.
  template <typename T>
  void VisitRepeated(const FieldDescriptor* field) {
    using DomainT = decltype(self.GetDefaultDomainForField<T, true>(field));
    ApplyDomain<T, DomainT, /*is_repeated=*/true>(field);
  }
};
// Customizes the domain of the field named `field_name` with `domain` and
// marks this domain as customized (Init() consults that flag before skipping
// recursive fields). CHECK-fails if the field does not exist or belongs to a
// oneof.
template <typename Inner>
void WithField(std::string_view field_name, Inner&& domain) {
  const std::string name(field_name);
  const auto* field = prototype_->GetDescriptor()->FindFieldByName(name);
  FUZZTEST_INTERNAL_CHECK_PRECONDITION(field != nullptr,
                                       "Invalid field name '", name, "'.");
  FUZZTEST_INTERNAL_CHECK_PRECONDITION(
      !field->containing_oneof(),
      "Customizing the domain for oneof fields is not supported yet.");
  // The visitor holds `domain` by reference and is passed as a temporary.
  VisitProtobufField(
      field, WithFieldVisitor<Inner&&>{std::forward<Inner>(domain), *this});
  are_fields_customized_ = true;
}
// Replaces the whole policy. `policy` is a by-value sink, so it is moved into
// place instead of being deep-copied a second time (its rule lists hold
// std::function filters and domains).
void SetPolicy(ProtoPolicy<Message> policy) { policy_ = std::move(policy); }
// Mutable access to this domain's policy; WithField() records its per-field
// rules through this accessor.
ProtoPolicy<Message>& GetPolicy() { return policy_; }
// Returns the policy-aware base domain for the field named `field_name`,
// ignoring its optional/repeated/required rule. CHECK-fails if no such field
// exists.
template <typename T>
auto GetFieldTypeDefaultDomain(absl::string_view field_name) const {
  const std::string name(field_name);
  const auto* field = prototype_->GetDescriptor()->FindFieldByName(name);
  FUZZTEST_INTERNAL_CHECK_PRECONDITION(field != nullptr,
                                       "Invalid field name '", name, "'.");
  return GetBaseDomainForFieldType<T>(field, /*use_policy=*/true);
}
private:
// Returns the cached domain for `field`. On first use, the domain is built
// with `GetDomainForField` (policy applied) and stored in `domains_`, keyed by
// field number.
template <typename T, bool is_repeated>
auto& GetSubDomain(const FieldDescriptor* field) const {
  // `GetDefaultDomainForField` and `GetDomainForField` are the same
  // instantiation, so this names the stored domain's type.
  using DomainT = decltype(GetDefaultDomainForField<T, is_repeated>(field));
  // `domains_` is mutated from `const` methods; serialize all access to it.
  absl::MutexLock lock(&mutex_);
  const int number = field->number();
  auto entry = domains_.find(number);
  if (entry == domains_.end()) {
    entry = domains_
                .try_emplace(number, std::in_place,
                             GetDomainForField<T, is_repeated>(field))
                .first;
  }
  // NOTE(review): the reference is used after `mutex_` is released; this
  // relies on the stored domain's address staying stable across later
  // insertions into `domains_` (presumably PolymorphicValue heap-allocates).
  return entry->second.template GetAs<DomainT>();
}
// Lifts an element domain `d` into a `Domain<std::vector<T>>`, the shape used
// for repeated fields. `min_size` / `max_size` are applied only when present.
template <typename T>
static Domain<std::vector<T>> ModifyDomainForRepeatedFieldRule(
    const Domain<T>& d, std::optional<int64_t> min_size,
    std::optional<int64_t> max_size) {
  ContainerOfImpl<std::vector<T>, Domain<T>> container(d);
  if (min_size) container.WithMinSize(*min_size);
  if (max_size) container.WithMaxSize(*max_size);
  return container;
}
// Wraps `d` in a `Domain<std::optional<T>>` for an optional field.
// `kWithoutNull` forces a value and `kAlwaysNull` forces null. Any other policy
// leaves the optional domain's own default behavior untouched.
template <typename T>
static Domain<std::optional<T>> ModifyDomainForOptionalFieldRule(
    const Domain<T>& d, OptionalPolicy optional_policy) {
  OptionalOfImpl<std::optional<T>, Domain<T>> optional_domain(d);
  switch (optional_policy) {
    case OptionalPolicy::kWithoutNull:
      optional_domain.SetWithoutNull();
      break;
    case OptionalPolicy::kAlwaysNull:
      optional_domain.SetAlwaysNull();
      break;
    default:
      break;
  }
  return optional_domain;
}
// Wraps `d` for a `required` field: an optional domain that never yields null.
template <typename T>
static Domain<std::optional<T>> ModifyDomainForRequiredFieldRule(
    const Domain<T>& d) {
  OptionalOfImpl<std::optional<T>, Domain<T>> never_null(d);
  never_null.SetWithoutNull();
  return never_null;
}
// Returns the default "base domain" for `field` based solely on its type
// (e.g. int32/string). It ignores the field rule (the
// repeated/optional/required specifiers) and any per-type overrides in
// `policy_`. Sub-message domains do inherit `policy_` itself.
template <typename T>
auto GetBaseDefaultDomainForFieldType(const FieldDescriptor* field) const {
  if constexpr (std::is_same_v<T, std::string>) {
    if (field->type() == FieldDescriptor::TYPE_STRING) {
      // `string` fields can only hold UTF-8. For now, simplify by generating
      // only 7-bit ASCII, which is always valid UTF-8. Other std::string-backed
      // fields (presumably `bytes`) fall through to the unconstrained
      // `ArbitraryImpl<std::string>` below.
      return Domain<T>(ContainerOfImpl<std::string, InRangeImpl<char>>(
          InRangeImpl<char>(char{0}, char{127})));
    }
  }
  if constexpr (std::is_same_v<T, ProtoEnumTag>) {
    // For enums, the domain is the list of numbers declared by the enum type.
    auto* e = field->enum_type();
    const int value_count = e->value_count();
    std::vector<int> values;
    values.reserve(static_cast<size_t>(value_count));  // Single allocation.
    for (int i = 0; i < value_count; ++i) {
      values.push_back(e->value(i)->number());
    }
    // Delay instantiation. The Domain class is not fully defined at this
    // point yet, and neither is ElementOfImpl.
    using LazyInt = MakeDependentType<int, T>;
    return Domain<LazyInt>(ElementOfImpl<LazyInt>(std::move(values)));
  } else if constexpr (std::is_same_v<T, ProtoMessageTag>) {
    // Sub-messages get an untyped domain over the field's message type, built
    // from the prototype's message factory and configured with this policy.
    auto result = ProtobufDomainUntypedImpl(
        prototype_->GetReflection()->GetMessageFactory()->GetPrototype(
            field->message_type()));
    result.SetPolicy(policy_);
    return Domain<std::unique_ptr<Message>>(result);
  } else {
    return Domain<T>(ArbitraryImpl<T>());
  }
}
// Returns the base domain for `field` as type `T` (field rule not applied).
// When `use_policy` is false, this is just the type's default domain. When it
// is true, a per-type default domain from `policy_` (if any) replaces the
// type's default, and a per-type transformer from `policy_` (if any) is then
// applied to the result.
template <typename T>
auto GetBaseDomainForFieldType(const FieldDescriptor* field,
                               bool use_policy) const {
  auto default_domain = GetBaseDefaultDomainForFieldType<T>(field);
  if (!use_policy) {
    return default_domain;
  }
// For the field type `type`, applies the policy's default domain and then its
// transformer (named after `Camel`) to `default_domain`, and returns it.
#define FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Camel, type)         \
  if constexpr (std::is_same_v<T, type>) {                                    \
    auto default_domain_in_policy =                                           \
        policy_.GetDefaultDomainFor##Camel##s(field);                         \
    if (default_domain_in_policy.has_value()) {                               \
      default_domain = std::move(*default_domain_in_policy);                  \
    }                                                                         \
    auto transformer_in_policy =                                              \
        policy_.GetDomainTransformerFor##Camel##s(field);                     \
    if (transformer_in_policy.has_value()) {                                  \
      default_domain = (*transformer_in_policy)(std::move(default_domain));   \
    }                                                                         \
    return default_domain;                                                    \
  }
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Bool, bool)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Int32, int32_t)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(UInt32, uint32_t)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Int64, int64_t)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(UInt64, uint64_t)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Float, float)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Double, double)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(String, std::string)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Enum, ProtoEnumTag)
  FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Protobuf, ProtoMessageTag)
// This is a header: do not leak the helper macro to includers.
#undef FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED
  // Only reached for a `T` not listed above. Reuse the domain already built
  // instead of constructing the default domain a second time.
  return default_domain;
}
template <typename T, bool is_repeated>
auto GetDomainForField(const FieldDescriptor* field,
bool use_policy = true) const {
auto domain = GetBaseDomainForFieldType<T>(field, use_policy);
if constexpr (is_repeated) {
return ModifyDomainForRepeatedFieldRule(
std::move(domain),
use_policy ? policy_.GetMinRepeatedFieldSize(field) : std::nullopt,
use_policy ? policy_.GetMaxRepeatedFieldSize(field) : std::nullopt);
} else if (field->is_required()) {
return ModifyDomainForRequiredFieldRule(std::move(domain));
} else {
return ModifyDomainForOptionalFieldRule(
std::move(domain), use_policy ? policy_.GetOptionalPolicy(field)
: OptionalPolicy::kWithNull);
}
}
// Same as `GetDomainForField`, but with the policy ignored entirely.
template <typename T, bool is_repeated>
auto GetDefaultDomainForField(const FieldDescriptor* field) const {
  constexpr bool kUsePolicy = false;
  return GetDomainForField<T, is_repeated>(field, kUsePolicy);
}
bool IsNonTerminatingRecursive() {
absl::flat_hash_set<decltype(prototype_->GetDescriptor())> parents;
return IsProtoRecursive(prototype_->GetDescriptor(), parents,
policy_.GetChildrenPolicy(),
/*consider_non_terminating_recursions=*/true);
}
bool IsFieldRecursive(const FieldDescriptor* field) {
if (!field->message_type()) return false;
absl::flat_hash_set<decltype(field->message_type())> parents;
return IsProtoRecursive(field->message_type(), parents,
policy_.GetChildrenPolicy(),
/*consider_non_terminating_recursions=*/false);
}
// Depth-first search over message-typed fields starting at `descriptor`.
// Returns true if the search reaches a message type that is already on the
// current path (`parents`).
// - `consider_non_terminating_recursions == true`: only follows fields the
//   policy forces to be set (required, optional with `kWithoutNull`, or
//   repeated with a positive minimum size).
// - Otherwise: follows every field that may be set (required, optional
//   unless `kAlwaysNull`, or repeated unless the maximum size is 0).
// On a `false` result, `parents` is restored to its state on entry.
template <typename Descriptor>
static bool IsProtoRecursive(const Descriptor* descriptor,
                             absl::flat_hash_set<const Descriptor*>& parents,
                             const ProtoPolicy<Message>& policy,
                             bool consider_non_terminating_recursions) {
  // Failing to insert means `descriptor` is already on the path: a cycle.
  if (!parents.insert(descriptor).second) return true;
  const int field_count = descriptor->field_count();
  for (int index = 0; index < field_count; ++index) {
    const auto* field = descriptor->field(index);
    const auto* sub_message = field->message_type();
    if (sub_message == nullptr) continue;  // Only message fields can recurse.
    bool follow_field = false;
    if (field->is_required()) {
      follow_field = true;
    } else if (consider_non_terminating_recursions) {
      if (field->is_optional() &&
          policy.GetOptionalPolicy(field) == OptionalPolicy::kWithoutNull) {
        follow_field = true;
      } else if (field->is_repeated()) {
        const auto min_size = policy.GetMinRepeatedFieldSize(field);
        follow_field = min_size.has_value() && *min_size > 0;
      }
    } else {
      if (field->is_optional() &&
          policy.GetOptionalPolicy(field) != OptionalPolicy::kAlwaysNull) {
        follow_field = true;
      } else if (field->is_repeated()) {
        const auto max_size = policy.GetMaxRepeatedFieldSize(field);
        follow_field = !max_size.has_value() || *max_size > 0;
      }
    }
    if (follow_field &&
        IsProtoRecursive(sub_message, parents, policy,
                         consider_non_terminating_recursions)) {
      return true;
    }
  }
  parents.erase(descriptor);
  return false;
}
// Message whose descriptor drives domain construction. Sub-message
// prototypes are obtained from its reflection's message factory.
const Message* prototype_ = nullptr;
// Serializes access to `domains_`, which `const` methods populate lazily.
mutable absl::Mutex mutex_;
// Per-field domains keyed by field number (see `GetSubDomain`).
mutable absl::flat_hash_map<int, PolymorphicValue<>> domains_
    ABSL_GUARDED_BY(mutex_);
ProtoPolicy<Message> policy_;
// Set to true by `WithField`. Defaults to false so that the value is defined
// regardless of which constructor runs.
bool are_fields_customized_ = false;
};
// Domain for `T` where `T` is a Protobuf message type.
// It is a small wrapper around `ProtobufDomainUntypedImpl` that makes its API
// more convenient. The value type is `T` itself, and the builder-style `With*`
// methods (callable only on rvalues) forward to the untyped implementation's
// policy or per-field customizations.
template <typename T>
class ProtobufDomainImpl : public DomainBase<ProtobufDomainImpl<T>> {
  using Inner = ProtobufDomainUntypedImpl<typename T::Message>;

 public:
  using value_type = T;
  using corpus_type = typename Inner::corpus_type;
  using FieldDescriptor = ProtobufFieldDescriptor<typename T::Message>;
  static constexpr bool has_custom_corpus_type = true;

  template <typename PRNG>
  corpus_type Init(PRNG& prng) {
    return inner_.Init(prng);
  }

  uint64_t CountNumberOfFields(corpus_type& val) {
    return inner_.CountNumberOfFields(val);
  }

  uint64_t MutateNumberOfProtoFields(corpus_type& val) {
    return inner_.MutateNumberOfProtoFields(val);
  }

  template <typename PRNG>
  void Mutate(corpus_type& val, PRNG& prng, bool only_shrink) {
    inner_.Mutate(val, prng, only_shrink);
  }

  // The untyped value owns a generic message (presumably a
  // `std::unique_ptr<Message>`). Move its contents out as the concrete `T`.
  value_type GetValue(const corpus_type& v) const {
    auto inner_v = inner_.GetValue(v);
    return std::move(static_cast<T&>(*inner_v));
  }

  std::optional<corpus_type> FromValue(const value_type& value) const {
    return inner_.FromValue(value);
  }

  auto GetPrinter() const { return ProtobufPrinter{}; }

  std::optional<corpus_type> ParseCorpus(const IRObject& obj) const {
    return inner_.ParseCorpus(obj);
  }

  IRObject SerializeCorpus(const corpus_type& v) const {
    return inner_.SerializeCorpus(v);
  }

  // Provide a conversion to the domain type that `WithProtobufField` (and the
  // other sub-message customizations) expect. Makes it easier on the user.
  operator Domain<std::unique_ptr<typename T::Message>>() const {
    return inner_;
  }

  ProtobufDomainImpl&& Self() && { return std::move(*this); }

  // Optional fields accepted by `filter` (default: all) are always set.
  ProtobufDomainImpl&& WithOptionalFieldsAlwaysSet(
      std::function<bool(const FieldDescriptor*)> filter =
          IncludeAll<FieldDescriptor>()) && {
    inner_.GetPolicy().SetOptionalPolicy(std::move(filter),
                                         OptionalPolicy::kWithoutNull);
    return std::move(*this);
  }

  // Optional fields accepted by `filter` (default: all) are never set.
  ProtobufDomainImpl&& WithOptionalFieldsUnset(
      std::function<bool(const FieldDescriptor*)> filter =
          IncludeAll<FieldDescriptor>()) && {
    inner_.GetPolicy().SetOptionalPolicy(std::move(filter),
                                         OptionalPolicy::kAlwaysNull);
    return std::move(*this);
  }

  ProtobufDomainImpl&& WithRepeatedFieldsMinSize(int64_t min_size) && {
    inner_.GetPolicy().SetMinRepeatedFieldsSize(IncludeAll<FieldDescriptor>(),
                                                min_size);
    return std::move(*this);
  }

  ProtobufDomainImpl&& WithRepeatedFieldsMinSize(
      std::function<bool(const FieldDescriptor*)> filter, int64_t min_size) && {
    inner_.GetPolicy().SetMinRepeatedFieldsSize(std::move(filter), min_size);
    return std::move(*this);
  }

  ProtobufDomainImpl&& WithRepeatedFieldsMaxSize(int64_t max_size) && {
    inner_.GetPolicy().SetMaxRepeatedFieldsSize(IncludeAll<FieldDescriptor>(),
                                                max_size);
    return std::move(*this);
  }

  ProtobufDomainImpl&& WithRepeatedFieldsMaxSize(
      std::function<bool(const FieldDescriptor*)> filter, int64_t max_size) && {
    inner_.GetPolicy().SetMaxRepeatedFieldsSize(std::move(filter), max_size);
    return std::move(*this);
  }

// For a field type `Camel` (C++ element type `cpp`, policy tag `TAG`), defines:
//   - With<Camel>Field / With<Camel>FieldUnset / With<Camel>FieldAlwaysSet:
//     customize one singular field, by name.
//   - WithRepeated<Camel>Field: customize one repeated field, by name.
//   - With<Camel>Fields / With<Camel>FieldsTransformed: set, or transform, the
//     policy's default domain for all fields of this type (optionally
//     restricted by a filter).
// Filters and transformers are sink parameters taken by value, like those of
// `WithOptionalFieldsAlwaysSet`. Named `std::function` objects are therefore
// accepted as well as temporaries.
#define FUZZTEST_INTERNAL_WITH_FIELD(Camel, cpp, TAG)                         \
  using Camel##type = MakeDependentType<cpp, T>;                              \
  ProtobufDomainImpl&& With##Camel##Field(std::string_view field,             \
                                          Domain<Camel##type> domain)&& {     \
    inner_.WithField(                                                         \
        field, OptionalOfImpl<std::optional<Camel##type>, decltype(domain)>(  \
                   std::move(domain)));                                       \
    return std::move(*this);                                                  \
  }                                                                           \
  ProtobufDomainImpl&& With##Camel##FieldUnset(std::string_view field)&& {    \
    auto default_domain =                                                     \
        inner_.template GetFieldTypeDefaultDomain<TAG>(field);                \
    inner_.WithField(                                                         \
        field,                                                                \
        OptionalOfImpl<std::optional<Camel##type>, decltype(default_domain)>( \
            std::move(default_domain))                                        \
            .SetAlwaysNull());                                                \
    return std::move(*this);                                                  \
  }                                                                           \
  ProtobufDomainImpl&& With##Camel##FieldAlwaysSet(                           \
      std::string_view field, Domain<Camel##type> domain)&& {                 \
    inner_.WithField(                                                         \
        field, OptionalOfImpl<std::optional<Camel##type>, decltype(domain)>(  \
                   std::move(domain))                                         \
                   .SetWithoutNull());                                        \
    return std::move(*this);                                                  \
  }                                                                           \
  ProtobufDomainImpl&& With##Camel##FieldAlwaysSet(std::string_view field)&& { \
    return std::move(*this).With##Camel##FieldAlwaysSet(                      \
        field, inner_.template GetFieldTypeDefaultDomain<TAG>(field));        \
  }                                                                           \
  template <template <typename, typename> typename DomainImpl, typename Inner> \
  ProtobufDomainImpl&& WithRepeated##Camel##Field(                            \
      std::string_view field,                                                 \
      DomainImpl<MakeDependentType<std::vector<cpp>, T>, Inner> domain)&& {   \
    inner_.WithField(field, std::move(domain));                               \
    return std::move(*this);                                                  \
  }                                                                           \
  ProtobufDomainImpl&& With##Camel##Fields(Domain<Camel##type> domain)&& {    \
    inner_.GetPolicy().SetDefaultDomainFor##Camel##s(                         \
        IncludeAll<FieldDescriptor>(), std::move(domain));                    \
    return std::move(*this);                                                  \
  }                                                                           \
  ProtobufDomainImpl&& With##Camel##Fields(                                   \
      std::function<bool(const FieldDescriptor*)> filter,                     \
      Domain<Camel##type> domain)&& {                                         \
    inner_.GetPolicy().SetDefaultDomainFor##Camel##s(std::move(filter),       \
                                                     std::move(domain));      \
    return std::move(*this);                                                  \
  }                                                                           \
  ProtobufDomainImpl&& With##Camel##FieldsTransformed(                        \
      std::function<Domain<Camel##type>(Domain<Camel##type>)>                 \
          transformer)&& {                                                    \
    inner_.GetPolicy().SetDomainTransformerFor##Camel##s(                     \
        IncludeAll<FieldDescriptor>(), std::move(transformer));               \
    return std::move(*this);                                                  \
  }                                                                           \
  ProtobufDomainImpl&& With##Camel##FieldsTransformed(                        \
      std::function<bool(const FieldDescriptor*)> filter,                     \
      std::function<Domain<Camel##type>(Domain<Camel##type>)>                 \
          transformer)&& {                                                    \
    inner_.GetPolicy().SetDomainTransformerFor##Camel##s(                     \
        std::move(filter), std::move(transformer));                           \
    return std::move(*this);                                                  \
  }

  FUZZTEST_INTERNAL_WITH_FIELD(Bool, bool, bool)
  FUZZTEST_INTERNAL_WITH_FIELD(Int32, int32_t, int32_t)
  FUZZTEST_INTERNAL_WITH_FIELD(UInt32, uint32_t, uint32_t)
  FUZZTEST_INTERNAL_WITH_FIELD(Int64, int64_t, int64_t)
  FUZZTEST_INTERNAL_WITH_FIELD(UInt64, uint64_t, uint64_t)
  FUZZTEST_INTERNAL_WITH_FIELD(Float, float, float)
  FUZZTEST_INTERNAL_WITH_FIELD(Double, double, double)
  FUZZTEST_INTERNAL_WITH_FIELD(String, std::string, std::string)
  FUZZTEST_INTERNAL_WITH_FIELD(Enum, int, ProtoEnumTag)
  FUZZTEST_INTERNAL_WITH_FIELD(Protobuf, std::unique_ptr<typename T::Message>,
                               ProtoMessageTag)

  // Customizes a repeated sub-message field with a container of typed proto
  // domains. The element domain is re-wrapped as an untyped
  // `Domain<std::unique_ptr<Message>>`, and the container's constraints (e.g.
  // size bounds) are copied over.
  template <typename ProtoDomain>
  ProtobufDomainImpl&& WithRepeatedProtobufField(
      std::string_view field,
      SequenceContainerOfImpl<std::vector<typename ProtoDomain::value_type>,
                              ProtoDomain>
          domain) && {
    auto cast_domain = SequenceContainerOfImpl<
        std::vector<std::unique_ptr<typename T::Message>>,
        Domain<std::unique_ptr<typename T::Message>>>(domain.Inner());
    cast_domain.CopyConstraintsFrom(domain);
    // `cast_domain` is not used afterwards; hand it over instead of copying.
    inner_.WithField(field, std::move(cast_domain));
    return std::move(*this);
  }

#undef FUZZTEST_INTERNAL_WITH_FIELD

 private:
  Inner inner_{&T::default_instance()};
};
// `ArbitraryImpl` for a protobuf message type `T` is exactly the typed
// protobuf domain. It inherits all of its behavior and `With*` builders.
template <typename T>
class ArbitraryImpl<T, std::enable_if_t<is_protocol_buffer_v<T>, void>>
    : public ProtobufDomainImpl<T> {
};
// `ArbitraryImpl` for a protobuf enum type `T`: values are drawn from the
// numbers declared in `T`'s EnumDescriptor.
template <typename T>
class ArbitraryImpl<T, std::enable_if_t<is_protocol_buffer_enum_v<T>>>
    : public DomainBase<ArbitraryImpl<T>> {
 public:
  using value_type = T;

  // Picks a uniformly random declared value by index. Numbers with several
  // (aliased) names are therefore weighted by their name count.
  template <typename PRNG>
  value_type Init(PRNG& prng) {
    const int index = absl::Uniform(prng, 0, descriptor()->value_count());
    return static_cast<T>(descriptor()->value(index)->number());
  }

  // Replaces `val` with a declared number different from it. With
  // `only_shrink`, the new number must be strictly smaller. Leaves `val`
  // unchanged when there is no such number. This covers a single-valued enum,
  // and also an enum whose names all alias the same number (`allow_alias`).
  // Candidates are enumerated up front, rather than re-sampling until the
  // value changes, because re-sampling never terminates in those cases.
  template <typename PRNG>
  void Mutate(value_type& val, PRNG& prng, bool only_shrink) {
    const auto* enum_descriptor = descriptor();
    const int value_count = enum_descriptor->value_count();
    std::vector<int> candidates;
    candidates.reserve(static_cast<size_t>(value_count));
    for (int i = 0; i < value_count; ++i) {
      const int n = enum_descriptor->value(i)->number();
      const bool eligible = only_shrink ? (n < val) : (n != val);
      if (eligible) candidates.push_back(n);
    }
    if (candidates.empty()) return;
    // One entry per eligible declaration index. This is the same distribution
    // as re-sampling `Init` until the number changes.
    const size_t idx = absl::Uniform<size_t>(prng, 0, candidates.size());
    val = static_cast<T>(candidates[idx]);
  }

  auto GetPrinter() const {
    return ProtobufEnumPrinter<decltype(descriptor())>{descriptor()};
  }

 private:
  // `T`'s EnumDescriptor, looked up once and cached in a function-local static.
  auto descriptor() const {
    static const auto* const kDescriptor =
        google::protobuf::GetEnumDescriptor<T>();
    return kDescriptor;
  }
};
} // namespace fuzztest::internal
#endif // FUZZTEST_FUZZTEST_INTERNAL_PROTOBUF_DOMAIN_H_