| // Copyright 2022 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #ifndef FUZZTEST_FUZZTEST_INTERNAL_PROTOBUF_DOMAIN_H_ |
| #define FUZZTEST_FUZZTEST_INTERNAL_PROTOBUF_DOMAIN_H_ |
| |
| #include <cstddef> |
| #include <cstdint> |
| #include <functional> |
| #include <memory> |
| #include <optional> |
| #include <string> |
| #include <type_traits> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/base/thread_annotations.h" |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/container/flat_hash_set.h" |
| #include "absl/random/bit_gen_ref.h" |
| #include "absl/random/random.h" |
| #include "absl/strings/string_view.h" |
| #include "absl/synchronization/mutex.h" |
| #include "absl/types/span.h" |
| #include "./fuzztest/internal/domain.h" |
| #include "./fuzztest/internal/logging.h" |
| #include "./fuzztest/internal/meta.h" |
| #include "./fuzztest/internal/polymorphic_value.h" |
| #include "./fuzztest/internal/serialization.h" |
| #include "./fuzztest/internal/type_support.h" |
| |
// Forward declarations of the protobuf enum-reflection entry points used by
// this header, so that no protobuf header has to be included here.
namespace google::protobuf {

class EnumDescriptor;

// Only declared here; no definition is visible in this header.
template <typename EnumT>
const EnumDescriptor* GetEnumDescriptor();

}  // namespace google::protobuf
| |
| namespace fuzztest::internal { |
| |
// The protobuf helper types are deduced from the accessor signatures of
// `Message` rather than named directly. This keeps the dependency on proto
// "soft": none of its headers have to be #included here.

// Pointee type of `Message::GetReflection()`.
template <typename Message>
using ProtobufReflection = typename std::remove_pointer<
    decltype(std::declval<Message>().GetReflection())>::type;

// Pointee type of `Message::GetDescriptor()`.
template <typename Message>
using ProtobufDescriptor = typename std::remove_pointer<
    decltype(std::declval<Message>().GetDescriptor())>::type;

// Pointee type of `Message::GetDescriptor()->field(i)`.
template <typename Message>
using ProtobufFieldDescriptor = typename std::remove_pointer<
    decltype(std::declval<Message>().GetDescriptor()->field(0))>::type;
| |
| template <typename Message, typename V> |
| class ProtocolBufferAccess; |
| |
| #define FUZZTEST_INTERNAL_PROTO_ACCESS_(cpp_type, Camel) \ |
| template <typename Message> \ |
| class ProtocolBufferAccess<Message, cpp_type> { \ |
| public: \ |
| using Reflection = ProtobufReflection<Message>; \ |
| using FieldDescriptor = ProtobufFieldDescriptor<Message>; \ |
| using value_type = cpp_type; \ |
| \ |
| ProtocolBufferAccess(Message* message, const FieldDescriptor* field) \ |
| : message_(message), \ |
| reflection_(message->GetReflection()), \ |
| field_(field) {} \ |
| \ |
| static auto GetField(const Message& message, \ |
| const FieldDescriptor* field) { \ |
| return message.GetReflection()->Get##Camel(message, field); \ |
| } \ |
| void SetField(const cpp_type& value) { \ |
| return reflection_->Set##Camel(message_, field_, value); \ |
| } \ |
| static auto GetRepeatedField(const Message& message, \ |
| const FieldDescriptor* field, int index) { \ |
| return message.GetReflection()->GetRepeated##Camel(message, field, \ |
| index); \ |
| } \ |
| void SetRepeatedField(int index, const cpp_type& value) { \ |
| return reflection_->SetRepeated##Camel(message_, field_, index, value); \ |
| } \ |
| void AddRepeatedField(const cpp_type& value) { \ |
| return reflection_->Add##Camel(message_, field_, value); \ |
| } \ |
| \ |
| private: \ |
| Message* message_; \ |
| Reflection* reflection_; \ |
| const FieldDescriptor* field_; \ |
| } |
| FUZZTEST_INTERNAL_PROTO_ACCESS_(int32_t, Int32); |
| FUZZTEST_INTERNAL_PROTO_ACCESS_(uint32_t, UInt32); |
| FUZZTEST_INTERNAL_PROTO_ACCESS_(int64_t, Int64); |
| FUZZTEST_INTERNAL_PROTO_ACCESS_(uint64_t, UInt64); |
| FUZZTEST_INTERNAL_PROTO_ACCESS_(float, Float); |
| FUZZTEST_INTERNAL_PROTO_ACCESS_(double, Double); |
| FUZZTEST_INTERNAL_PROTO_ACCESS_(std::string, String); |
| FUZZTEST_INTERNAL_PROTO_ACCESS_(bool, Bool); |
| #undef FUZZTEST_INTERNAL_PROTO_ACCESS_ |
| |
// Tag types used as the `type` template argument for fields that have no
// plain C++ value type of their own. They are only declared here.
struct ProtoEnumTag;     // Stands for enum fields.
struct ProtoMessageTag;  // Stands for message fields.
| |
| // Dynamic to static dispatch visitation. |
| // It will invoke: |
| // visitor.VisitSingular<type>(field) // for singular fields |
| // visitor.VisitRepeated<type>(field) // for repeated fields |
| // where `type` is: |
| // - For bool, integrals, floating point and string: their C++ type. |
| // - For enum: the tag type ProtoEnumTag. |
| // - For message: the tag type ProtoMessageTag. |
| template <typename FieldDescriptor, typename Visitor> |
| auto VisitProtobufField(const FieldDescriptor* field, Visitor visitor) { |
| if (field->is_repeated()) { |
| switch (field->cpp_type()) { |
| case FieldDescriptor::CPPTYPE_BOOL: |
| return visitor.template VisitRepeated<bool>(field); |
| case FieldDescriptor::CPPTYPE_INT32: |
| return visitor.template VisitRepeated<int32_t>(field); |
| case FieldDescriptor::CPPTYPE_UINT32: |
| return visitor.template VisitRepeated<uint32_t>(field); |
| case FieldDescriptor::CPPTYPE_INT64: |
| return visitor.template VisitRepeated<int64_t>(field); |
| case FieldDescriptor::CPPTYPE_UINT64: |
| return visitor.template VisitRepeated<uint64_t>(field); |
| case FieldDescriptor::CPPTYPE_FLOAT: |
| return visitor.template VisitRepeated<float>(field); |
| case FieldDescriptor::CPPTYPE_DOUBLE: |
| return visitor.template VisitRepeated<double>(field); |
| case FieldDescriptor::CPPTYPE_STRING: |
| return visitor.template VisitRepeated<std::string>(field); |
| case FieldDescriptor::CPPTYPE_ENUM: |
| return visitor.template VisitRepeated<ProtoEnumTag>(field); |
| case FieldDescriptor::CPPTYPE_MESSAGE: |
| return visitor.template VisitRepeated<ProtoMessageTag>(field); |
| } |
| } else { |
| switch (field->cpp_type()) { |
| case FieldDescriptor::CPPTYPE_BOOL: |
| return visitor.template VisitSingular<bool>(field); |
| case FieldDescriptor::CPPTYPE_INT32: |
| return visitor.template VisitSingular<int32_t>(field); |
| case FieldDescriptor::CPPTYPE_UINT32: |
| return visitor.template VisitSingular<uint32_t>(field); |
| case FieldDescriptor::CPPTYPE_INT64: |
| return visitor.template VisitSingular<int64_t>(field); |
| case FieldDescriptor::CPPTYPE_UINT64: |
| return visitor.template VisitSingular<uint64_t>(field); |
| case FieldDescriptor::CPPTYPE_FLOAT: |
| return visitor.template VisitSingular<float>(field); |
| case FieldDescriptor::CPPTYPE_DOUBLE: |
| return visitor.template VisitSingular<double>(field); |
| case FieldDescriptor::CPPTYPE_STRING: |
| return visitor.template VisitSingular<std::string>(field); |
| case FieldDescriptor::CPPTYPE_ENUM: |
| return visitor.template VisitSingular<ProtoEnumTag>(field); |
| case FieldDescriptor::CPPTYPE_MESSAGE: |
| return visitor.template VisitSingular<ProtoMessageTag>(field); |
| } |
| } |
| } |
| |
// Resolves field `number` of `prototype`: the message's own descriptor is
// consulted first and, if it has no such field, the known extensions are
// searched. Returns whatever the extension lookup yields (possibly null) when
// the descriptor lookup fails.
template <typename Message>
auto GetProtobufField(const Message* prototype, int number) {
  auto* found = prototype->GetDescriptor()->FindFieldByNumber(number);
  if (!found) {
    found = prototype->GetReflection()->FindKnownExtensionByNumber(number);
  }
  return found;
}
| |
// Returns a predicate over `const T*` that accepts every argument, i.e. a
// filter that never excludes anything.
template <typename T>
auto IncludeAll() {
  return [](const T* /*field*/) { return true; };
}
| |
| template <typename T> |
| std::function<Domain<T>(Domain<T>)> Identity() { |
| return [](Domain<T> domain) -> Domain<T> { return domain; }; |
| } |
| |
| template <typename Message> |
| class ProtoPolicy { |
| using FieldDescriptor = ProtobufFieldDescriptor<Message>; |
| using Filter = std::function<bool(const FieldDescriptor*)>; |
| |
| public: |
| ProtoPolicy() |
| : optional_policies_({{.filter = IncludeAll<FieldDescriptor>(), |
| .value = OptionalPolicy::kWithNull}}) {} |
| |
| ProtoPolicy GetChildrenPolicy() const { |
| ProtoPolicy children_policy; |
| children_policy.optional_policies_ = |
| CopyRecursivePolicyValues(optional_policies_); |
| children_policy.min_repeated_fields_sizes_ = |
| CopyRecursivePolicyValues(min_repeated_fields_sizes_); |
| children_policy.max_repeated_fields_sizes_ = |
| CopyRecursivePolicyValues(max_repeated_fields_sizes_); |
| |
| children_policy.domains_for_Bool_ = |
| CopyRecursivePolicyValues(domains_for_Bool_); |
| children_policy.domains_for_Int32_ = |
| CopyRecursivePolicyValues(domains_for_Int32_); |
| children_policy.domains_for_UInt32_ = |
| CopyRecursivePolicyValues(domains_for_UInt32_); |
| children_policy.domains_for_Int64_ = |
| CopyRecursivePolicyValues(domains_for_Int64_); |
| children_policy.domains_for_UInt64_ = |
| CopyRecursivePolicyValues(domains_for_UInt64_); |
| children_policy.domains_for_Float_ = |
| CopyRecursivePolicyValues(domains_for_Float_); |
| children_policy.domains_for_Double_ = |
| CopyRecursivePolicyValues(domains_for_Double_); |
| children_policy.domains_for_String_ = |
| CopyRecursivePolicyValues(domains_for_String_); |
| children_policy.domains_for_Enum_ = |
| CopyRecursivePolicyValues(domains_for_Enum_); |
| children_policy.domains_for_Protobuf_ = |
| CopyRecursivePolicyValues(domains_for_Protobuf_); |
| children_policy.transformers_for_Bool_ = |
| CopyRecursivePolicyValues(transformers_for_Bool_); |
| children_policy.transformers_for_Int32_ = |
| CopyRecursivePolicyValues(transformers_for_Int32_); |
| children_policy.transformers_for_UInt32_ = |
| CopyRecursivePolicyValues(transformers_for_UInt32_); |
| children_policy.transformers_for_Int64_ = |
| CopyRecursivePolicyValues(transformers_for_Int64_); |
| children_policy.transformers_for_UInt64_ = |
| CopyRecursivePolicyValues(transformers_for_UInt64_); |
| children_policy.transformers_for_Float_ = |
| CopyRecursivePolicyValues(transformers_for_Float_); |
| children_policy.transformers_for_Double_ = |
| CopyRecursivePolicyValues(transformers_for_Double_); |
| children_policy.transformers_for_String_ = |
| CopyRecursivePolicyValues(transformers_for_String_); |
| children_policy.transformers_for_Enum_ = |
| CopyRecursivePolicyValues(transformers_for_Enum_); |
| children_policy.transformers_for_Protobuf_ = |
| CopyRecursivePolicyValues(transformers_for_Protobuf_); |
| return children_policy; |
| } |
| |
| void SetOptionalPolicy(OptionalPolicy optional_policy, |
| bool is_recursive = true) { |
| SetOptionalPolicy(IncludeAll<FieldDescriptor>(), optional_policy, |
| is_recursive); |
| } |
| |
| void SetOptionalPolicy(Filter filter, OptionalPolicy optional_policy, |
| bool is_recursive = true) { |
| optional_policies_.push_back({.filter = std::move(filter), |
| .value = optional_policy, |
| .is_recursive = is_recursive}); |
| } |
| |
| void SetMinRepeatedFieldsSize(int64_t min_size, bool is_recursive = true) { |
| SetMinRepeatedFieldsSize(IncludeAll<FieldDescriptor>(), min_size, |
| is_recursive); |
| } |
| |
| void SetMinRepeatedFieldsSize(Filter filter, int64_t min_size, |
| bool is_recursive = true) { |
| min_repeated_fields_sizes_.push_back({.filter = std::move(filter), |
| .value = min_size, |
| .is_recursive = is_recursive}); |
| } |
| |
| void SetMaxRepeatedFieldsSize(int64_t max_size, bool is_recursive = true) { |
| SetMaxRepeatedFieldsSize(IncludeAll<FieldDescriptor>(), max_size, |
| is_recursive); |
| } |
| |
| void SetMaxRepeatedFieldsSize(Filter filter, int64_t max_size, |
| bool is_recursive = true) { |
| max_repeated_fields_sizes_.push_back({.filter = std::move(filter), |
| .value = max_size, |
| .is_recursive = is_recursive}); |
| } |
| |
| OptionalPolicy GetOptionalPolicy(const FieldDescriptor* field) const { |
| FUZZTEST_INTERNAL_CHECK( |
| field->is_optional(), |
| "GetOptionalPolicy should apply to optional fields only!"); |
| std::optional<OptionalPolicy> result = |
| GetPolicyValue(optional_policies_, field); |
| FUZZTEST_INTERNAL_CHECK(result.has_value(), "optional policy is not set!"); |
| return *result; |
| } |
| |
| std::optional<int64_t> GetMinRepeatedFieldSize( |
| const FieldDescriptor* field) const { |
| FUZZTEST_INTERNAL_CHECK( |
| field->is_repeated(), |
| "GetMinRepeatedFieldSize should apply to repeated fields only!"); |
| return GetPolicyValue(min_repeated_fields_sizes_, field); |
| } |
| |
| std::optional<int64_t> GetMaxRepeatedFieldSize( |
| const FieldDescriptor* field) const { |
| FUZZTEST_INTERNAL_CHECK( |
| field->is_repeated(), |
| "GetMaxRepeatedFieldSize should apply to repeated fields only!"); |
| return GetPolicyValue(max_repeated_fields_sizes_, field); |
| } |
| |
| private: |
| template <typename T> |
| struct FilterToValue { |
| Filter filter; |
| T value; |
| bool is_recursive = true; |
| }; |
| |
| template <typename T> |
| std::vector<FilterToValue<T>> CopyRecursivePolicyValues( |
| const std::vector<FilterToValue<T>>& filter_to_values) const { |
| std::vector<FilterToValue<T>> result; |
| for (const auto& filter_to_value : filter_to_values) { |
| if (filter_to_value.is_recursive) { |
| result.push_back(filter_to_value); |
| } |
| } |
| return result; |
| } |
| |
| template <typename T> |
| std::optional<T> GetPolicyValue( |
| const std::vector<FilterToValue<T>>& filter_to_values, |
| const FieldDescriptor* field) const { |
| // Return the policy that is not overwritten. |
| for (int i = filter_to_values.size() - 1; i >= 0; --i) { |
| if (!filter_to_values[i].filter(field)) continue; |
| if constexpr (std::is_same_v<T, Domain<std::unique_ptr<Message>>>) { |
| absl::BitGen gen; |
| auto domain = filter_to_values[i].value; |
| auto obj = domain.GetValue(domain.Init(gen)); |
| auto* descriptor = obj->GetDescriptor(); |
| FUZZTEST_INTERNAL_CHECK_PRECONDITION( |
| descriptor->full_name() == field->message_type()->full_name(), |
| "Input domain does not match the expected message type. The " |
| "domain produced a message of type `", |
| descriptor->full_name(), |
| "` but the field needs a message of type `", |
| field->message_type()->full_name(), "`."); |
| } |
| return filter_to_values[i].value; |
| } |
| return std::nullopt; |
| } |
| std::vector<FilterToValue<OptionalPolicy>> optional_policies_; |
| std::vector<FilterToValue<int64_t>> min_repeated_fields_sizes_; |
| std::vector<FilterToValue<int64_t>> max_repeated_fields_sizes_; |
| |
| #define FUZZTEST_INTERNAL_POLICY_MEMBERS(Camel, cpp) \ |
| private: \ |
| std::vector<FilterToValue<Domain<cpp>>> domains_for_##Camel##_; \ |
| std::vector<FilterToValue<std::function<Domain<cpp>(Domain<cpp>)>>> \ |
| transformers_for_##Camel##_; \ |
| \ |
| public: \ |
| void SetDefaultDomainFor##Camel##s( \ |
| Domain<MakeDependentType<cpp, Message>> domain, \ |
| bool is_recursive = true) { \ |
| domains_for_##Camel##_.push_back({.filter = IncludeAll<FieldDescriptor>(), \ |
| .value = std::move(domain), \ |
| .is_recursive = true}); \ |
| } \ |
| void SetDefaultDomainFor##Camel##s( \ |
| Filter filter, Domain<MakeDependentType<cpp, Message>> domain, \ |
| bool is_recursive = true) { \ |
| domains_for_##Camel##_.push_back({.filter = std::move(filter), \ |
| .value = std::move(domain), \ |
| .is_recursive = true}); \ |
| } \ |
| void SetDomainTransformerFor##Camel##s( \ |
| Filter filter, \ |
| std::function<Domain<MakeDependentType<cpp, Message>>( \ |
| Domain<MakeDependentType<cpp, Message>>)> \ |
| transformer, \ |
| bool is_recursive = true) { \ |
| transformers_for_##Camel##_.push_back({.filter = std::move(filter), \ |
| .value = std::move(transformer), \ |
| .is_recursive = true}); \ |
| } \ |
| std::optional<Domain<MakeDependentType<cpp, Message>>> \ |
| GetDefaultDomainFor##Camel##s(const FieldDescriptor* field) const { \ |
| return GetPolicyValue(domains_for_##Camel##_, field); \ |
| } \ |
| std::optional<std::function<Domain<MakeDependentType<cpp, Message>>( \ |
| Domain<MakeDependentType<cpp, Message>>)>> \ |
| GetDomainTransformerFor##Camel##s(const FieldDescriptor* field) const { \ |
| return GetPolicyValue(transformers_for_##Camel##_, field); \ |
| } |
| |
| FUZZTEST_INTERNAL_POLICY_MEMBERS(Bool, bool) |
| FUZZTEST_INTERNAL_POLICY_MEMBERS(Int32, int32_t) |
| FUZZTEST_INTERNAL_POLICY_MEMBERS(UInt32, uint32_t) |
| FUZZTEST_INTERNAL_POLICY_MEMBERS(Int64, int64_t) |
| FUZZTEST_INTERNAL_POLICY_MEMBERS(UInt64, uint64_t) |
| FUZZTEST_INTERNAL_POLICY_MEMBERS(Float, float) |
| FUZZTEST_INTERNAL_POLICY_MEMBERS(Double, double) |
| FUZZTEST_INTERNAL_POLICY_MEMBERS(String, std::string) |
| FUZZTEST_INTERNAL_POLICY_MEMBERS(Enum, int) |
| FUZZTEST_INTERNAL_POLICY_MEMBERS(Protobuf, std::unique_ptr<Message>) |
| }; |
| |
| // Domain for std::unique_ptr<Message>, where the prototype is accepted as a |
| // constructor argument. |
| template <typename Message> |
| class ProtobufDomainUntypedImpl |
| : public DomainBase<ProtobufDomainUntypedImpl<Message>> { |
| using Descriptor = ProtobufDescriptor<Message>; |
| using FieldDescriptor = ProtobufFieldDescriptor<Message>; |
| |
| public: |
| using corpus_type = absl::flat_hash_map<int, GenericDomainCorpusType>; |
| using value_type = std::unique_ptr<Message>; |
| static constexpr bool has_custom_corpus_type = true; |
| |
| explicit ProtobufDomainUntypedImpl(const Message* prototype) |
| : prototype_(prototype), policy_(), are_fields_customized_(false) {} |
| |
| ProtobufDomainUntypedImpl(const ProtobufDomainUntypedImpl& other) { |
| prototype_ = other.prototype_; |
| absl::MutexLock l(&other.mutex_); |
| domains_ = other.domains_; |
| policy_ = other.policy_; |
| are_fields_customized_ = other.are_fields_customized_; |
| } |
| |
| template <typename T> |
| static void InitializeFieldValue(absl::BitGenRef prng, |
| const ProtobufDomainUntypedImpl& self, |
| const FieldDescriptor* field, |
| corpus_type& val) { |
| auto& domain = self.GetSubDomain<T, false>(field); |
| val[field->number()] = domain.Init(prng); |
| |
| if (auto* oneof = field->containing_oneof()) { |
| // Clear the other parts of the oneof. They are unnecessary to |
| // have and mutating them would have no effect. |
| for (int i = 0; i < oneof->field_count(); ++i) { |
| if (i != field->index_in_oneof()) { |
| val.erase(oneof->field(i)->number()); |
| } |
| } |
| } |
| } |
| |
| struct InitializeVisitor { |
| absl::BitGenRef prng; |
| ProtobufDomainUntypedImpl& self; |
| corpus_type& val; |
| |
| template <typename T> |
| void VisitSingular(const FieldDescriptor* field) { |
| InitializeFieldValue<T>(prng, self, field, val); |
| } |
| |
| template <typename T> |
| void VisitRepeated(const FieldDescriptor* field) { |
| auto& domain = self.GetSubDomain<T, true>(field); |
| val[field->number()] = domain.Init(prng); |
| } |
| }; |
| |
| template <typename PRNG> |
| corpus_type Init(PRNG& prng) { |
| FUZZTEST_INTERNAL_CHECK( |
| are_fields_customized_ || !IsNonTerminatingRecursive(), |
| "Cannot set recursive fields by default."); |
| corpus_type val; |
| |
| // TODO(b/241124202): Use a valid proto with minimum size. |
| const auto* descriptor = prototype_->GetDescriptor(); |
| for (int i = 0; i < descriptor->field_count(); ++i) { |
| const auto* field = descriptor->field(i); |
| // We avoid initializing non-required recursive fields by default (if they |
| // are not explicitly customized). Otherwise, the initialization may never |
| // terminate. If a proto has only non-required recursive fields, the |
| // initialization will be deterministic, which violates the assumption on |
| // domain Init. However, such cases should be extremely rare and breaking |
| // the assumption would not have severe consequences. |
| if (!field->is_required() && !are_fields_customized_ && |
| IsFieldRecursive(field)) { |
| continue; |
| } |
| VisitProtobufField(field, InitializeVisitor{prng, *this, val}); |
| } |
| return val; |
| } |
| |
| struct MutateVisitor { |
| absl::BitGenRef prng; |
| bool only_shrink; |
| ProtobufDomainUntypedImpl& self; |
| corpus_type& val; |
| |
| template <typename T> |
| void VisitSingular(const FieldDescriptor* field) { |
| auto& domain = self.GetSubDomain<T, false>(field); |
| auto it = val.find(field->number()); |
| const bool is_present = it != val.end(); |
| |
| if (is_present) { |
| // Mutate the element |
| domain.Mutate(it->second, prng, only_shrink); |
| return; |
| } |
| |
| // Add the element |
| if (!only_shrink) { |
| InitializeFieldValue<T>(prng, self, field, val); |
| } |
| } |
| |
| template <typename T> |
| void VisitRepeated(const FieldDescriptor* field) { |
| auto& domain = self.GetSubDomain<T, true>(field); |
| auto it = val.find(field->number()); |
| const bool is_present = it != val.end(); |
| |
| if (!is_present) { |
| if (!only_shrink) { |
| val[field->number()] = domain.Init(prng); |
| } |
| } else if (field->is_map()) { |
| // field of synthetic messages of the form: |
| // |
| // message { |
| // optional key_type key = 1; |
| // optional value_type value = 2; |
| // } |
| // |
| // The generic code is not doing dedup and it would only happen some |
| // time after GetValue when the map field is synchronized between |
| // reflection and codegen. Let's do it eagerly to drop dead entries so |
| // that we don't keep mutating them later. |
| // |
| // TODO(b/231212420): Improve mutation to not add duplicate keys on the |
| // first place. The current hack is very inefficient. |
| // Switch the inner domain for maps to use flat_hash_map instead. |
| corpus_type corpus_copy; |
| auto& copy = corpus_copy[field->number()] = it->second; |
| domain.Mutate(copy, prng, only_shrink); |
| auto v = self.GetValue(corpus_copy); |
| // We need to roundtrip through serialization to really dedup. The |
| // reflection API alone doesn't cut it. |
| v->ParsePartialFromString(v->SerializeAsString()); |
| if (v->GetReflection()->FieldSize(*v, field) == |
| domain.GetValue(copy).size()) { |
| // The number of entries is the same, so accept the change. |
| it->second = std::move(copy); |
| } |
| } else { |
| domain.Mutate(it->second, prng, only_shrink); |
| } |
| } |
| }; |
| |
| uint64_t CountNumberOfFields(corpus_type& val) { |
| uint64_t total_weight = 0; |
| auto* descriptor = prototype_->GetDescriptor(); |
| if (descriptor->field_count() == 0) return total_weight; |
| |
| for (int i = 0; i < descriptor->field_count(); ++i) { |
| FieldDescriptor* field = descriptor->field(i); |
| ++total_weight; |
| |
| if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
| auto val_it = val.find(field->number()); |
| if (val_it == val.end()) continue; |
| if (field->is_repeated()) { |
| total_weight += |
| GetSubDomain<ProtoMessageTag, true>(field).CountNumberOfFields( |
| val_it->second); |
| } else { |
| total_weight += |
| GetSubDomain<ProtoMessageTag, false>(field).CountNumberOfFields( |
| val_it->second); |
| } |
| } |
| } |
| return total_weight; |
| } |
| |
| template <typename PRNG> |
| uint64_t MutateSelectedField(corpus_type& val, PRNG& prng, bool only_shrink, |
| uint64_t selected_field_index) { |
| uint64_t field_counter = 0; |
| auto* descriptor = prototype_->GetDescriptor(); |
| if (descriptor->field_count() == 0) return field_counter; |
| |
| for (int i = 0; i < descriptor->field_count(); ++i) { |
| FieldDescriptor* field = descriptor->field(i); |
| ++field_counter; |
| if (field_counter == selected_field_index) { |
| VisitProtobufField(field, MutateVisitor{prng, only_shrink, *this, val}); |
| return field_counter; |
| } |
| |
| if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
| auto val_it = val.find(field->number()); |
| if (val_it == val.end()) continue; |
| if (field->is_repeated()) { |
| field_counter += |
| GetSubDomain<ProtoMessageTag, true>(field).MutateSelectedField( |
| val_it->second, prng, only_shrink, |
| selected_field_index - field_counter); |
| } else { |
| field_counter += |
| GetSubDomain<ProtoMessageTag, false>(field).MutateSelectedField( |
| val_it->second, prng, only_shrink, |
| selected_field_index - field_counter); |
| } |
| } |
| if (field_counter >= selected_field_index) return field_counter; |
| } |
| return field_counter; |
| } |
| |
| template <typename PRNG> |
| void Mutate(corpus_type& val, PRNG& prng, bool only_shrink) { |
| auto* descriptor = prototype_->GetDescriptor(); |
| if (descriptor->field_count() == 0) return; |
| // TODO(JunyangShao): Maybe make CountNumberOfFields static. |
| uint64_t total_weight = CountNumberOfFields(val); |
| uint64_t selected_weight = absl::Uniform(absl::IntervalClosedClosed, prng, |
| uint64_t{1}, total_weight); |
| MutateSelectedField(val, prng, only_shrink, selected_weight); |
| } |
| |
| struct GetValueVisitor { |
| Message& message; |
| const ProtobufDomainUntypedImpl& self; |
| const GenericDomainCorpusType& data; |
| |
| template <typename T> |
| void VisitSingular(const FieldDescriptor* field) { |
| auto& domain = self.GetSubDomain<T, false>(field); |
| auto value = domain.GetValue(data); |
| if (!value.has_value()) { |
| message.GetReflection()->ClearField(&message, field); |
| return; |
| } |
| if constexpr (std::is_same_v<T, ProtoMessageTag>) { |
| message.GetReflection()->SetAllocatedMessage(&message, value->release(), |
| field); |
| } else if constexpr (std::is_same_v<T, ProtoEnumTag>) { |
| message.GetReflection()->SetEnumValue(&message, field, *value); |
| } else { |
| ProtocolBufferAccess<Message, T>(&message, field).SetField(*value); |
| } |
| } |
| |
| template <typename T> |
| void VisitRepeated(const FieldDescriptor* field) { |
| auto& domain = self.GetSubDomain<T, true>(field); |
| if constexpr (std::is_same_v<T, ProtoMessageTag>) { |
| for (auto& v : domain.GetValue(data)) { |
| message.GetReflection()->AddAllocatedMessage(&message, field, |
| v.release()); |
| } |
| } else if constexpr (std::is_same_v<T, ProtoEnumTag>) { |
| for (const auto& v : domain.GetValue(data)) { |
| message.GetReflection()->AddEnumValue(&message, field, v); |
| } |
| } else { |
| for (const auto& v : domain.GetValue(data)) { |
| ProtocolBufferAccess<Message, T>(&message, field).AddRepeatedField(v); |
| } |
| } |
| } |
| }; |
| |
| value_type GetValue(const corpus_type& value) const { |
| value_type out(prototype_->New()); |
| |
| for (auto& [number, data] : value) { |
| auto* field = GetProtobufField(prototype_, number); |
| VisitProtobufField(field, GetValueVisitor{*out, *this, data}); |
| } |
| |
| return out; |
| } |
| |
| std::optional<corpus_type> FromValue(const value_type& value) const { |
| return FromValue(*value); |
| } |
| |
| struct FromValueVisitor { |
| const Message& message; |
| corpus_type& out; |
| const ProtobufDomainUntypedImpl& self; |
| |
| // TODO(sbenzaquen): We might want to try avoid these copies. On the other hand, |
| // FromValue is not called much so it might be ok. |
| template <typename T> |
| bool VisitSingular(const FieldDescriptor* field) { |
| auto& domain = self.GetSubDomain<T, false>(field); |
| typename std::decay_t<decltype(domain)>::value_type inner_value; |
| auto* reflection = message.GetReflection(); |
| if constexpr (std::is_same_v<T, ProtoMessageTag>) { |
| const auto& child = reflection->GetMessage(message, field); |
| inner_value = std::unique_ptr<Message>(child.New()); |
| (*inner_value)->CopyFrom(child); |
| } else if constexpr (std::is_same_v<T, ProtoEnumTag>) { |
| inner_value = reflection->GetEnum(message, field)->number(); |
| } else { |
| inner_value = |
| ProtocolBufferAccess<Message, T>::GetField(message, field); |
| } |
| auto inner = domain.FromValue(inner_value); |
| if (!inner) return false; |
| out[field->number()] = *std::move(inner); |
| return true; |
| } |
| |
| template <typename T> |
| bool VisitRepeated(const FieldDescriptor* field) { |
| auto& domain = self.GetSubDomain<T, true>(field); |
| typename std::decay_t<decltype(domain)>::value_type inner_value; |
| auto* reflection = message.GetReflection(); |
| const int size = reflection->FieldSize(message, field); |
| for (int i = 0; i < size; ++i) { |
| if constexpr (std::is_same_v<T, ProtoMessageTag>) { |
| const auto& child = reflection->GetRepeatedMessage(message, field, i); |
| auto* copy = child.New(); |
| copy->CopyFrom(child); |
| inner_value.emplace_back(copy); |
| } else if constexpr (std::is_same_v<T, ProtoEnumTag>) { |
| inner_value.push_back( |
| reflection->GetRepeatedEnum(message, field, i)->number()); |
| } else { |
| inner_value.push_back( |
| ProtocolBufferAccess<Message, T>::GetRepeatedField(message, field, |
| i)); |
| } |
| } |
| |
| auto inner = domain.FromValue(inner_value); |
| if (!inner) return false; |
| out[field->number()] = *std::move(inner); |
| return true; |
| } |
| }; |
| |
| std::optional<corpus_type> FromValue(const Message& value) const { |
| corpus_type ret; |
| auto* reflection = value.GetReflection(); |
| std::vector<const FieldDescriptor*> fields; |
| reflection->ListFields(value, &fields); |
| for (auto field : fields) { |
| if (!VisitProtobufField(field, FromValueVisitor{value, ret, *this})) |
| return std::nullopt; |
| } |
| return ret; |
| } |
| |
| struct ParseVisitor { |
| const ProtobufDomainUntypedImpl& self; |
| const IRObject& obj; |
| std::optional<GenericDomainCorpusType>& out; |
| |
| template <typename T> |
| void VisitSingular(const FieldDescriptor* field) { |
| out = self.GetSubDomain<T, false>(field).ParseCorpus(obj); |
| } |
| |
| template <typename T> |
| void VisitRepeated(const FieldDescriptor* field) { |
| out = self.GetSubDomain<T, true>(field).ParseCorpus(obj); |
| } |
| }; |
| |
| std::optional<corpus_type> ParseCorpus(const IRObject& obj) const { |
| corpus_type out; |
| auto subs = obj.Subs(); |
| if (!subs) return std::nullopt; |
| for (const auto& sub : *subs) { |
| auto pair_subs = sub.Subs(); |
| if (!pair_subs || pair_subs->size() != 2) return std::nullopt; |
| auto number = (*pair_subs)[0].GetScalar<int>(); |
| if (!number) return std::nullopt; |
| auto* field = GetProtobufField(prototype_, *number); |
| if (!field) return std::nullopt; |
| |
| std::optional<GenericDomainCorpusType> inner_parsed; |
| VisitProtobufField(field, |
| ParseVisitor{*this, (*pair_subs)[1], inner_parsed}); |
| if (!inner_parsed) return std::nullopt; |
| out[*number] = *std::move(inner_parsed); |
| } |
| |
| return out; |
| } |
| |
| struct SerializeVisitor { |
| const ProtobufDomainUntypedImpl& self; |
| const GenericDomainCorpusType& corpus_value; |
| IRObject& out; |
| |
| template <typename T> |
| void VisitSingular(const FieldDescriptor* field) { |
| out = self.GetSubDomain<T, false>(field).SerializeCorpus(corpus_value); |
| } |
| template <typename T> |
| void VisitRepeated(const FieldDescriptor* field) { |
| out = self.GetSubDomain<T, true>(field).SerializeCorpus(corpus_value); |
| } |
| }; |
| |
| IRObject SerializeCorpus(const corpus_type& v) const { |
| IRObject out; |
| auto& subs = out.MutableSubs(); |
| for (auto& [number, inner] : v) { |
| auto* field = GetProtobufField(prototype_, number); |
| FUZZTEST_INTERNAL_CHECK(field, "Field not found by number: ", number); |
| IRObject& pair = subs.emplace_back(); |
| auto& pair_subs = pair.MutableSubs(); |
| pair_subs.emplace_back(number); |
| VisitProtobufField( |
| field, SerializeVisitor{*this, inner, pair_subs.emplace_back()}); |
| } |
| return out; |
| } |
| |
| auto GetPrinter() const { return ProtobufPrinter{}; } |
| |
| template <typename Inner> |
| struct WithFieldVisitor { |
| Inner domain; |
| ProtobufDomainUntypedImpl& self; |
| |
| template <typename T, typename DomainT, bool is_repeated> |
| void ApplyDomain(const FieldDescriptor* field) { |
| if constexpr (!std::is_constructible_v<DomainT, Inner>) { |
| FUZZTEST_INTERNAL_CHECK_PRECONDITION( |
| (std::is_constructible_v<DomainT, Inner>), |
| "Input domain does not match field `", field->full_name(), |
| "` type."); |
| } else { |
| if constexpr (std::is_same_v<T, ProtoMessageTag>) { |
| // Verify that the type matches. |
| absl::BitGen gen; |
| auto inner_domain = domain.Inner(); |
| auto obj = inner_domain.GetValue(inner_domain.Init(gen)); |
| auto* descriptor = obj->GetDescriptor(); |
| FUZZTEST_INTERNAL_CHECK_PRECONDITION( |
| descriptor->full_name() == field->message_type()->full_name(), |
| "Input domain does not match the expected message type. The " |
| "domain produced a message of type `", |
| descriptor->full_name(), |
| "` but the field needs a message of type `", |
| field->message_type()->full_name(), "`."); |
| } |
| auto field_filter = |
| [full_name = field->full_name()](const FieldDescriptor* field) { |
| return field->full_name() == full_name; |
| }; |
| if constexpr (is_repeated) { |
| self.GetPolicy().SetMinRepeatedFieldsSize( |
| field_filter, domain.min_size(), /*is_recursive=*/false); |
| self.GetPolicy().SetMaxRepeatedFieldsSize( |
| field_filter, domain.max_size(), /*is_recursive=*/false); |
| } else { |
| if (field->is_required()) { |
| FUZZTEST_INTERNAL_CHECK_PRECONDITION( |
| domain.policy() == OptionalPolicy::kWithoutNull, |
| "required field '", field->full_name(), |
| "' cannot have null values."); |
| } else { |
| self.GetPolicy().SetOptionalPolicy(field_filter, domain.policy(), |
| /*is_recursive=*/false); |
| } |
| } |
| |
| #define FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(Camel, cpp, TAG) \ |
| if constexpr (std::is_same_v<T, TAG>) { \ |
| self.GetPolicy().SetDefaultDomainFor##Camel##s( \ |
| field_filter, domain.Inner(), /*is_recursive=*/false); \ |
| self.GetPolicy().SetDomainTransformerFor##Camel##s( \ |
| field_filter, Identity<cpp>(), /*is_recursive=*/false); \ |
| } |
| FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(Bool, bool, bool) |
| FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(Int32, int32_t, int32_t) |
| FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(UInt32, uint32_t, uint32_t) |
| FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(Int64, int64_t, int64_t) |
| FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(UInt64, uint64_t, uint64_t) |
| FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(Float, float, float) |
| FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(Double, double, double) |
| FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(String, std::string, |
| std::string) |
| FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD(Enum, int, ProtoEnumTag) |
| FUZZTEST_INTERNAL_SET_BASE_DOMAIN_FOR_FIELD( |
| Protobuf, std::unique_ptr<Message>, ProtoMessageTag) |
| } |
| } |
| |
| template <typename T> |
| void VisitSingular(const FieldDescriptor* field) { |
| using DomainT = decltype(self.GetDefaultDomainForField<T, false>(field)); |
| ApplyDomain<T, DomainT, /*is_repeated=*/false>(field); |
| } |
| |
| template <typename T> |
| void VisitRepeated(const FieldDescriptor* field) { |
| using DomainT = decltype(self.GetDefaultDomainForField<T, true>(field)); |
| ApplyDomain<T, DomainT, /*is_repeated=*/true>(field); |
| } |
| }; |
| |
| template <typename Inner> |
| void WithField(std::string_view field_name, Inner&& domain) { |
| auto* field = |
| prototype_->GetDescriptor()->FindFieldByName(std::string(field_name)); |
| FUZZTEST_INTERNAL_CHECK_PRECONDITION(field != nullptr, |
| "Invalid field name '", |
| std::string(field_name), "'."); |
| FUZZTEST_INTERNAL_CHECK_PRECONDITION( |
| !field->containing_oneof(), |
| "Customizing the domain for oneof fields is not supported yet."); |
| VisitProtobufField( |
| field, WithFieldVisitor<Inner&&>{std::forward<Inner>(domain), *this}); |
| are_fields_customized_ = true; |
| } |
| |
| void SetPolicy(ProtoPolicy<Message> policy) { policy_ = policy; } |
| |
| ProtoPolicy<Message>& GetPolicy() { return policy_; } |
| |
| template <typename T> |
| auto GetFieldTypeDefaultDomain(absl::string_view field_name) const { |
| auto* field = |
| prototype_->GetDescriptor()->FindFieldByName(std::string(field_name)); |
| FUZZTEST_INTERNAL_CHECK_PRECONDITION(field != nullptr, |
| "Invalid field name '", |
| std::string(field_name), "'."); |
| return GetBaseDomainForFieldType<T>(field, /*use_policy=*/true); |
| } |
| |
| private: |
| // Get the existing domain for `field`, if exists. |
| // Otherwise, create the appropriate `Arbitrary<>` domain for the field and |
| // return it. |
| template <typename T, bool is_repeated> |
| auto& GetSubDomain(const FieldDescriptor* field) const { |
| // Do the operation under a lock to prevent race conditions in `const` |
| // methods. |
| absl::MutexLock l(&mutex_); |
| auto it = domains_.find(field->number()); |
| if (it == domains_.end()) { |
| it = domains_ |
| .try_emplace(field->number(), std::in_place, |
| GetDomainForField<T, is_repeated>(field)) |
| .first; |
| } |
| using DomainT = decltype(GetDefaultDomainForField<T, is_repeated>(field)); |
| return it->second.template GetAs<DomainT>(); |
| } |
| |
| // Simple wrapper that converts a Domain<T> into a Domain<vector<T>>. |
| template <typename T> |
| static Domain<std::vector<T>> ModifyDomainForRepeatedFieldRule( |
| const Domain<T>& d, std::optional<int64_t> min_size, |
| std::optional<int64_t> max_size) { |
| auto domain = ContainerOfImpl<std::vector<T>, Domain<T>>(d); |
| if (min_size.has_value()) { |
| domain.WithMinSize(*min_size); |
| } |
| if (max_size.has_value()) { |
| domain.WithMaxSize(*max_size); |
| } |
| return domain; |
| } |
| |
| template <typename T> |
| static Domain<std::optional<T>> ModifyDomainForOptionalFieldRule( |
| const Domain<T>& d, OptionalPolicy optional_policy) { |
| auto result = OptionalOfImpl<std::optional<T>, Domain<T>>(d); |
| if (optional_policy == OptionalPolicy::kWithoutNull) { |
| result.SetWithoutNull(); |
| } else if (optional_policy == OptionalPolicy::kAlwaysNull) { |
| result.SetAlwaysNull(); |
| } |
| return result; |
| } |
| |
| template <typename T> |
| static Domain<std::optional<T>> ModifyDomainForRequiredFieldRule( |
| const Domain<T>& d) { |
| return OptionalOfImpl<std::optional<T>, Domain<T>>(d).SetWithoutNull(); |
| } |
| |
| // Returns the default "base domain" for a `field` solely based on its type |
| // (i.e., int32/string), but ignoring the field rule (i.e., the |
| // repeated/optional/required specifiers). |
| template <typename T> |
| auto GetBaseDefaultDomainForFieldType(const FieldDescriptor* field) const { |
| if constexpr (std::is_same_v<T, std::string>) { |
| if (field->type() == FieldDescriptor::TYPE_STRING) { |
| // Can only use UTF-8. For now, simplify as just ASCII. |
| return Domain<T>(ContainerOfImpl<std::string, InRangeImpl<char>>( |
| InRangeImpl<char>(char{0}, char{127}))); |
| } |
| } |
| |
| if constexpr (std::is_same_v<T, ProtoEnumTag>) { |
| // For enums, build the list of valid labels. |
| auto* e = field->enum_type(); |
| std::vector<int> values; |
| for (int i = 0; i < e->value_count(); ++i) { |
| values.push_back(e->value(i)->number()); |
| } |
| // Delay instantiation. The Domain class is not fully defined at this |
| // point yet, and neither is ElementOfImpl. |
| using LazyInt = MakeDependentType<int, T>; |
| return Domain<LazyInt>(ElementOfImpl<LazyInt>(std::move(values))); |
| } else if constexpr (std::is_same_v<T, ProtoMessageTag>) { |
| auto result = ProtobufDomainUntypedImpl( |
| prototype_->GetReflection()->GetMessageFactory()->GetPrototype( |
| field->message_type())); |
| result.SetPolicy(policy_); |
| return Domain<std::unique_ptr<Message>>(result); |
| } else { |
| return Domain<T>(ArbitraryImpl<T>()); |
| } |
| } |
| |
| template <typename T> |
| auto GetBaseDomainForFieldType(const FieldDescriptor* field, |
| bool use_policy) const { |
| auto default_domain = GetBaseDefaultDomainForFieldType<T>(field); |
| if (!use_policy) { |
| return default_domain; |
| } |
| #define FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Camel, type) \ |
| if constexpr (std::is_same_v<T, type>) { \ |
| auto default_domain_in_policy = \ |
| policy_.GetDefaultDomainFor##Camel##s(field); \ |
| if (default_domain_in_policy.has_value()) { \ |
| default_domain = std::move(*default_domain_in_policy); \ |
| } \ |
| auto transformer_in_policy = \ |
| policy_.GetDomainTransformerFor##Camel##s(field); \ |
| if (transformer_in_policy.has_value()) { \ |
| default_domain = (*transformer_in_policy)(std::move(default_domain)); \ |
| } \ |
| return default_domain; \ |
| } |
| |
| FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Bool, bool) |
| FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Int32, int32_t) |
| FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(UInt32, uint32_t) |
| FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Int64, int64_t) |
| FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(UInt64, uint64_t) |
| FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Float, float) |
| FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Double, double) |
| FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(String, std::string) |
| FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Enum, ProtoEnumTag) |
| FUZZTEST_INTERNAL_RETURN_BASE_DOMAIN_IF_PROVIDED(Protobuf, ProtoMessageTag) |
| |
| return GetBaseDefaultDomainForFieldType<T>(field); |
| } |
| |
| template <typename T, bool is_repeated> |
| auto GetDomainForField(const FieldDescriptor* field, |
| bool use_policy = true) const { |
| auto domain = GetBaseDomainForFieldType<T>(field, use_policy); |
| if constexpr (is_repeated) { |
| return ModifyDomainForRepeatedFieldRule( |
| std::move(domain), |
| use_policy ? policy_.GetMinRepeatedFieldSize(field) : std::nullopt, |
| use_policy ? policy_.GetMaxRepeatedFieldSize(field) : std::nullopt); |
| } else if (field->is_required()) { |
| return ModifyDomainForRequiredFieldRule(std::move(domain)); |
| } else { |
| return ModifyDomainForOptionalFieldRule( |
| std::move(domain), use_policy ? policy_.GetOptionalPolicy(field) |
| : OptionalPolicy::kWithNull); |
| } |
| } |
| |
| template <typename T, bool is_repeated> |
| auto GetDefaultDomainForField(const FieldDescriptor* field) const { |
| return GetDomainForField<T, is_repeated>(field, /*use_policy=*/false); |
| } |
| |
| bool IsNonTerminatingRecursive() { |
| absl::flat_hash_set<decltype(prototype_->GetDescriptor())> parents; |
| return IsProtoRecursive(prototype_->GetDescriptor(), parents, |
| policy_.GetChildrenPolicy(), |
| /*consider_non_terminating_recursions=*/true); |
| } |
| |
| bool IsFieldRecursive(const FieldDescriptor* field) { |
| if (!field->message_type()) return false; |
| absl::flat_hash_set<decltype(field->message_type())> parents; |
| return IsProtoRecursive(field->message_type(), parents, |
| policy_.GetChildrenPolicy(), |
| /*consider_non_terminating_recursions=*/false); |
| } |
| |
| template <typename Descriptor> |
| static bool IsProtoRecursive(const Descriptor* descriptor, |
| absl::flat_hash_set<const Descriptor*>& parents, |
| const ProtoPolicy<Message>& policy, |
| bool consider_non_terminating_recursions) { |
| if (parents.contains(descriptor)) return true; |
| parents.insert(descriptor); |
| for (int i = 0; i < descriptor->field_count(); ++i) { |
| const auto* field = descriptor->field(i); |
| const auto* child = field->message_type(); |
| if (!child) continue; |
| if (consider_non_terminating_recursions) { |
| const bool should_be_set = |
| field->is_required() || |
| (field->is_optional() && |
| policy.GetOptionalPolicy(field) == OptionalPolicy::kWithoutNull) || |
| (field->is_repeated() && |
| policy.GetMinRepeatedFieldSize(field).has_value() && |
| *policy.GetMinRepeatedFieldSize(field) > 0); |
| if (!should_be_set) continue; |
| } else { |
| const bool can_be_set = |
| field->is_required() || |
| (field->is_optional() && |
| policy.GetOptionalPolicy(field) != OptionalPolicy::kAlwaysNull) || |
| (field->is_repeated() && |
| (!policy.GetMaxRepeatedFieldSize(field).has_value() || |
| *policy.GetMaxRepeatedFieldSize(field) > 0)); |
| if (!can_be_set) continue; |
| } |
| if (IsProtoRecursive(child, parents, policy, |
| consider_non_terminating_recursions)) { |
| return true; |
| } |
| } |
| parents.erase(descriptor); |
| return false; |
| } |
| |
| const Message* prototype_; |
| |
| mutable absl::Mutex mutex_; |
| mutable absl::flat_hash_map<int, PolymorphicValue<>> domains_ |
| ABSL_GUARDED_BY(mutex_); |
| |
| ProtoPolicy<Message> policy_; |
| bool are_fields_customized_; |
| }; |
| |
| // Domain for `T` where `T` is a Protobuf message type. |
| // It is a small wrapper around `ProtobufDomainUntypedImpl` to make its API more |
| // convenient. |
| template <typename T> |
| class ProtobufDomainImpl : public DomainBase<ProtobufDomainImpl<T>> { |
| using Inner = ProtobufDomainUntypedImpl<typename T::Message>; |
| |
| public: |
| using value_type = T; |
| using corpus_type = typename Inner::corpus_type; |
| using FieldDescriptor = ProtobufFieldDescriptor<typename T::Message>; |
| static constexpr bool has_custom_corpus_type = true; |
| |
| template <typename PRNG> |
| corpus_type Init(PRNG& prng) { |
| return inner_.Init(prng); |
| } |
| |
| uint64_t CountNumberOfFields(corpus_type& val) { |
| return inner_.CountNumberOfFields(val); |
| } |
| |
| uint64_t MutateNumberOfProtoFields(corpus_type& val) { |
| return inner_.MutateNumberOfProtoFields(val); |
| } |
| |
| template <typename PRNG> |
| void Mutate(corpus_type& val, PRNG& prng, bool only_shrink) { |
| inner_.Mutate(val, prng, only_shrink); |
| } |
| |
| value_type GetValue(const corpus_type& v) const { |
| auto inner_v = inner_.GetValue(v); |
| return std::move(static_cast<T&>(*inner_v)); |
| } |
| |
| std::optional<corpus_type> FromValue(const value_type& value) const { |
| return inner_.FromValue(value); |
| } |
| |
| auto GetPrinter() const { return ProtobufPrinter{}; } |
| |
| std::optional<corpus_type> ParseCorpus(const IRObject& obj) const { |
| return inner_.ParseCorpus(obj); |
| } |
| |
| IRObject SerializeCorpus(const corpus_type& v) const { |
| return inner_.SerializeCorpus(v); |
| } |
| |
| // Provide a conversion to the type that WithMessageField wants. |
| // Makes it easier on the user. |
| operator Domain<std::unique_ptr<typename T::Message>>() const { |
| return inner_; |
| } |
| |
| ProtobufDomainImpl&& Self() && { return std::move(*this); } |
| |
| ProtobufDomainImpl&& WithOptionalFieldsAlwaysSet( |
| std::function<bool(const FieldDescriptor*)> filter = |
| IncludeAll<FieldDescriptor>()) && { |
| inner_.GetPolicy().SetOptionalPolicy(std::move(filter), |
| OptionalPolicy::kWithoutNull); |
| return std::move(*this); |
| } |
| |
| ProtobufDomainImpl&& WithOptionalFieldsUnset( |
| std::function<bool(const FieldDescriptor*)> filter = |
| IncludeAll<FieldDescriptor>()) && { |
| inner_.GetPolicy().SetOptionalPolicy(std::move(filter), |
| OptionalPolicy::kAlwaysNull); |
| return std::move(*this); |
| } |
| |
| ProtobufDomainImpl&& WithRepeatedFieldsMinSize(int64_t min_size) && { |
| inner_.GetPolicy().SetMinRepeatedFieldsSize(IncludeAll<FieldDescriptor>(), |
| min_size); |
| return std::move(*this); |
| } |
| |
| ProtobufDomainImpl&& WithRepeatedFieldsMinSize( |
| std::function<bool(const FieldDescriptor*)> filter, int64_t min_size) && { |
| inner_.GetPolicy().SetMinRepeatedFieldsSize(std::move(filter), min_size); |
| return std::move(*this); |
| } |
| |
| ProtobufDomainImpl&& WithRepeatedFieldsMaxSize(int64_t max_size) && { |
| inner_.GetPolicy().SetMaxRepeatedFieldsSize(IncludeAll<FieldDescriptor>(), |
| max_size); |
| return std::move(*this); |
| } |
| |
| ProtobufDomainImpl&& WithRepeatedFieldsMaxSize( |
| std::function<bool(const FieldDescriptor*)> filter, int64_t max_size) && { |
| inner_.GetPolicy().SetMaxRepeatedFieldsSize(std::move(filter), max_size); |
| return std::move(*this); |
| } |
| |
| #define FUZZTEST_INTERNAL_WITH_FIELD(Camel, cpp, TAG) \ |
| using Camel##type = MakeDependentType<cpp, T>; \ |
| ProtobufDomainImpl&& With##Camel##Field(std::string_view field, \ |
| Domain<Camel##type> domain)&& { \ |
| inner_.WithField( \ |
| field, OptionalOfImpl<std::optional<Camel##type>, decltype(domain)>( \ |
| std::move(domain))); \ |
| return std::move(*this); \ |
| } \ |
| ProtobufDomainImpl&& With##Camel##FieldUnset(std::string_view field)&& { \ |
| auto default_domain = \ |
| inner_.template GetFieldTypeDefaultDomain<TAG>(field); \ |
| inner_.WithField( \ |
| field, \ |
| OptionalOfImpl<std::optional<Camel##type>, decltype(default_domain)>( \ |
| std::move(default_domain)) \ |
| .SetAlwaysNull()); \ |
| return std::move(*this); \ |
| } \ |
| ProtobufDomainImpl&& With##Camel##FieldAlwaysSet( \ |
| std::string_view field, Domain<Camel##type> domain)&& { \ |
| inner_.WithField( \ |
| field, OptionalOfImpl<std::optional<Camel##type>, decltype(domain)>( \ |
| std::move(domain)) \ |
| .SetWithoutNull()); \ |
| return std::move(*this); \ |
| } \ |
| ProtobufDomainImpl&& With##Camel##FieldAlwaysSet(std::string_view field)&& { \ |
| return std::move(*this).With##Camel##FieldAlwaysSet( \ |
| field, inner_.template GetFieldTypeDefaultDomain<TAG>(field)); \ |
| } \ |
| template <template <typename, typename> typename DomainImpl, typename Inner> \ |
| ProtobufDomainImpl&& WithRepeated##Camel##Field( \ |
| std::string_view field, \ |
| DomainImpl<MakeDependentType<std::vector<cpp>, T>, Inner> domain)&& { \ |
| inner_.WithField(field, std::move(domain)); \ |
| return std::move(*this); \ |
| } \ |
| ProtobufDomainImpl&& With##Camel##Fields(Domain<Camel##type> domain)&& { \ |
| inner_.GetPolicy().SetDefaultDomainFor##Camel##s( \ |
| IncludeAll<FieldDescriptor>(), std::move(domain)); \ |
| return std::move(*this); \ |
| } \ |
| ProtobufDomainImpl&& With##Camel##Fields( \ |
| std::function<bool(const FieldDescriptor*)>&& filter, \ |
| Domain<Camel##type> domain)&& { \ |
| inner_.GetPolicy().SetDefaultDomainFor##Camel##s(std::move(filter), \ |
| std::move(domain)); \ |
| return std::move(*this); \ |
| } \ |
| ProtobufDomainImpl&& With##Camel##FieldsTransformed( \ |
| std::function<Domain<Camel##type>(Domain<Camel##type>)>&& \ |
| transformer)&& { \ |
| inner_.GetPolicy().SetDomainTransformerFor##Camel##s( \ |
| IncludeAll<FieldDescriptor>(), std::move(transformer)); \ |
| return std::move(*this); \ |
| } \ |
| ProtobufDomainImpl&& With##Camel##FieldsTransformed( \ |
| std::function<bool(const FieldDescriptor*)>&& filter, \ |
| std::function<Domain<Camel##type>(Domain<Camel##type>)>&& \ |
| transformer)&& { \ |
| inner_.GetPolicy().SetDomainTransformerFor##Camel##s( \ |
| std::move(filter), std::move(transformer)); \ |
| return std::move(*this); \ |
| } |
| |
| FUZZTEST_INTERNAL_WITH_FIELD(Bool, bool, bool) |
| FUZZTEST_INTERNAL_WITH_FIELD(Int32, int32_t, int32_t) |
| FUZZTEST_INTERNAL_WITH_FIELD(UInt32, uint32_t, uint32_t) |
| FUZZTEST_INTERNAL_WITH_FIELD(Int64, int64_t, int64_t) |
| FUZZTEST_INTERNAL_WITH_FIELD(UInt64, uint64_t, uint64_t) |
| FUZZTEST_INTERNAL_WITH_FIELD(Float, float, float) |
| FUZZTEST_INTERNAL_WITH_FIELD(Double, double, double) |
| FUZZTEST_INTERNAL_WITH_FIELD(String, std::string, std::string) |
| FUZZTEST_INTERNAL_WITH_FIELD(Enum, int, ProtoEnumTag) |
| FUZZTEST_INTERNAL_WITH_FIELD(Protobuf, std::unique_ptr<typename T::Message>, |
| ProtoMessageTag) |
| |
| template <typename ProtoDomain> |
| ProtobufDomainImpl&& WithRepeatedProtobufField( |
| std::string_view field, |
| SequenceContainerOfImpl<std::vector<typename ProtoDomain::value_type>, |
| ProtoDomain> |
| domain) && { |
| auto cast_domain = SequenceContainerOfImpl< |
| std::vector<std::unique_ptr<typename T::Message>>, |
| Domain<std::unique_ptr<typename T::Message>>>(domain.Inner()); |
| cast_domain.CopyConstraintsFrom(domain); |
| inner_.WithField(field, cast_domain); |
| return std::move(*this); |
| } |
| |
| #undef FUZZTEST_INTERNAL_WITH_FIELD |
| |
| private: |
| Inner inner_{&T::default_instance()}; |
| }; |
| |
| template <typename T> |
| class ArbitraryImpl<T, std::enable_if_t<is_protocol_buffer_v<T>>> |
| : public ProtobufDomainImpl<T> {}; |
| |
| template <typename T> |
| class ArbitraryImpl<T, std::enable_if_t<is_protocol_buffer_enum_v<T>>> |
| : public DomainBase<ArbitraryImpl<T>> { |
| public: |
| using value_type = T; |
| |
| template <typename PRNG> |
| value_type Init(PRNG& prng) { |
| const int index = absl::Uniform(prng, 0, descriptor()->value_count()); |
| return static_cast<T>(descriptor()->value(index)->number()); |
| } |
| |
| template <typename PRNG> |
| void Mutate(value_type& val, PRNG& prng, bool only_shrink) { |
| if (only_shrink) { |
| std::vector<int> numbers; |
| for (int i = 0; i < descriptor()->value_count(); ++i) { |
| if (int n = descriptor()->value(i)->number(); n < val) |
| numbers.push_back(n); |
| } |
| if (numbers.empty()) return; |
| size_t idx = absl::Uniform<size_t>(prng, 0, numbers.size()); |
| val = static_cast<T>(numbers[idx]); |
| return; |
| } else if (descriptor()->value_count() == 1) { |
| return; |
| } |
| |
| // Make sure Mutate really mutates. |
| const T prev = val; |
| do { |
| val = Init(prng); |
| } while (val == prev); |
| } |
| |
| auto GetPrinter() const { |
| return ProtobufEnumPrinter<decltype(descriptor())>{descriptor()}; |
| } |
| |
| private: |
| auto descriptor() const { |
| static auto const descriptor_ = google::protobuf::GetEnumDescriptor<T>(); |
| return descriptor_; |
| } |
| }; |
| |
| } // namespace fuzztest::internal |
| |
| #endif // FUZZTEST_FUZZTEST_INTERNAL_PROTOBUF_DOMAIN_H_ |