blob: 12068d2d4aa91f1826f0700f0d1c851d1b6ea25a [file] [log] [blame]
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FUZZTEST_INTERNAL_DOMAIN_TESTING_H_
#define FUZZTEST_INTERNAL_DOMAIN_TESTING_H_
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <optional>
#include <ostream>
#include <set>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "google/protobuf/repeated_field.h"
#include "google/protobuf/util/field_comparator.h"
#include "google/protobuf/util/message_differencer.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/container/flat_hash_set.h"
#include "absl/hash/hash.h"
#include "absl/random/random.h"
#include "absl/strings/str_cat.h"
#include "./fuzztest/internal/meta.h"
#include "./fuzztest/internal/serialization.h"
#include "./fuzztest/internal/test_protobuf.pb.h"
#include "./fuzztest/internal/type_support.h"
namespace fuzztest {
// Tests whether arg is in the range [a, b].
//
// gMock matcher: matches when both `a <= arg` and `arg <= b` hold, so only
// operator<= is required between the argument and the bounds. Because every
// comparison involving NaN is false, a NaN `arg` (or bound) never matches.
// The description reads e.g. "is in the closed range [1, 5]" (or "isn't ..."
// when negated).
MATCHER_P2(IsInClosedRange, a, b,
           absl::StrCat(negation ? "isn't" : "is", " in the closed range [",
                        testing::PrintToString(a), ", ",
                        testing::PrintToString(b), "]")) {
  return a <= arg && arg <= b;
}
// Make sure floating hash/eq handle NaN.
//
// Hash functor used by `Set`, written to agree with `Eq`:
//  - smart pointers (anything supporting `p && p.get()`) hash their pointee,
//    and every null pointer hashes to 0;
//  - values for which `std::isnan(*std::optional(v))` is well-formed (this
//    includes integral and floating-point types and std::optional of them)
//    hash an empty optional, or any NaN, to 0 so that all NaNs share a hash;
//  - unordered containers (anything with `hash_function()`) are hashed through
//    a sorted std::set copy, so the result does not depend on iteration order;
//  - everything else defers to absl::Hash.
struct Hash {
  template <typename T>
  size_t operator()(const T& v) const {
    if constexpr (internal::Requires<T>(
                      [](auto ptr) -> decltype(ptr && ptr.get()) {})) {
      // Smart pointers: hash what they point to.
      if (!v) return 0;
      return absl::HashOf(*v);
    } else if constexpr (internal::Requires<T>(
                             [](auto x) -> decltype(std::isnan(
                                            *std::optional(x))) {})) {
      const auto maybe_value = std::optional(v);
      if (maybe_value && !std::isnan(*maybe_value)) {
        return absl::Hash<T>{}(*maybe_value);
      }
      // Empty optional or NaN.
      return 0;
    } else if constexpr (internal::Requires<T>(
                             [](auto c) -> decltype(c.hash_function()) {})) {
      // Unordered containers: hash an ordered copy of the elements.
      const std::set<typename T::value_type> ordered(v.begin(), v.end());
      return (*this)(ordered);
    } else {
      return absl::Hash<T>{}(v);
    }
  }
};
struct Eq {
template <typename T>
bool operator()(const T& a, const T& b) const {
if constexpr (internal::is_protocol_buffer_v<T>) {
google::protobuf::util::MessageDifferencer differencer;
google::protobuf::util::DefaultFieldComparator cmp;
cmp.set_treat_nan_as_equal(true);
differencer.set_field_comparator(&cmp);
return differencer.Compare(a, b);
} else if constexpr (internal::Requires<T>(
[](auto v) -> decltype(v && v.get()) {})) {
// Smart pointers.
return a ? b && *a == *b : !b;
} else if constexpr (internal::Requires<T>(
[](auto v) -> decltype(std::isnan(
*std::optional(v))) {})) {
auto oa = std::optional(a), ob = std::optional(b);
return a == b || (oa && ob && std::isnan(*oa) && std::isnan(*ob));
} else {
return a == b;
}
}
};
// absl::flat_hash_set that hashes with `Hash` and compares with `Eq`, so that
// e.g. NaN floats and smart pointers are deduplicated by value.
template <typename Element>
using Set = absl::flat_hash_set<Element, Hash, Eq>;
// The Value class keeps the corpus and value types together throughout tests to
// simplify their access and mutation.
//
// `user_value` is always derived from `corpus_value` via `Domain::GetValue`,
// both at construction and after every `Mutate`, so the two stay in sync.
// Equality and hashing between Values look only at `user_value`, which lets
// tests collect Values in absl hash containers.
template <typename Domain>
struct Value {
  using T = typename Domain::value_type;
  // Must stay declared before `user_value`: both constructors initialize
  // `user_value` from the already-initialized `corpus_value`.
  internal::corpus_type_t<Domain> corpus_value;
  T user_value;
  // Draws a fresh corpus value with `Domain::Init`.
  Value(Domain& domain, absl::BitGen& bitgen)
      : corpus_value(domain.Init(bitgen)),
        user_value(domain.GetValue(corpus_value)) {}
  // If the value_type is not copy constructible we have to copy the corpus and
  // regenerate the value.
  Value(const Value& other, Domain& domain)
      : corpus_value(other.corpus_value),
        user_value(domain.GetValue(corpus_value)) {}
  // Mutates `corpus_value` in place (forwarding `only_shrink` unchanged to
  // `Domain::Mutate`) and then regenerates `user_value` from it.
  void Mutate(Domain& domain, absl::BitGen& bitgen, bool only_shrink) {
    domain.Mutate(corpus_value, bitgen, only_shrink);
    user_value = domain.GetValue(corpus_value);
  }
  // Make the Value hashable/comparable to put them in sets/maps.
  // The hash must agree with operator== below:
  //  - for types where `std::isnan(*std::optional(v))` is well-formed, a NaN
  //    is replaced by 0 before hashing, so all NaNs (which operator== treats
  //    as equal) hash alike;
  //  - unordered containers (detected via `hash_function()`) are hashed as a
  //    sorted std::set, so the hash does not depend on iteration order;
  //  - anything else is hashed directly.
  template <typename H>
  friend H AbslHashValue(H state, const Value& self) {
    const auto& v = self.user_value;
    if constexpr (internal::Requires<T>(
                      [](auto v) -> decltype(std::isnan(*std::optional(v))) {
                      })) {
      auto o = std::optional(v);
      if (o && std::isnan(*o)) o = 0;
      return H::combine(std::move(state), o);
    } else if constexpr (internal::Requires<T>(
                             [](auto v) -> decltype(v.hash_function()) {})) {
      return H::combine(std::move(state),
                        std::set<typename T::value_type>(v.begin(), v.end()));
    } else {
      return H::combine(std::move(state), v);
    }
  }
  // Compares `user_value`s; for NaN-capable types two NaNs are also equal,
  // which keeps operator== consistent with AbslHashValue above.
  friend bool operator==(const Value& a, const Value& b) {
    if constexpr (internal::Requires<T>(
                      [](auto v) -> decltype(std::isnan(*std::optional(v))) {
                      })) {
      auto oa = std::optional(a.user_value), ob = std::optional(b.user_value);
      return a.user_value == b.user_value ||
             (oa && ob && std::isnan(*oa) && std::isnan(*ob));
    } else {
      return a.user_value == b.user_value;
    }
  }
  friend bool operator!=(const Value& a, const Value& b) { return !(a == b); }
  // Comparisons against a plain value_type forward directly to T's operators
  // (without the NaN special case used between two Values).
  friend bool operator==(const Value& a, const T& b) {
    return a.user_value == b;
  }
  friend bool operator!=(const Value& a, const T& b) {
    return a.user_value != b;
  }
  friend bool operator<(const Value& a, const T& b) { return a.user_value < b; }
  friend bool operator>(const Value& a, const T& b) { return a.user_value > b; }
  friend bool operator<=(const Value& a, const T& b) {
    return a.user_value <= b;
  }
  friend bool operator>=(const Value& a, const T& b) {
    return a.user_value >= b;
  }
  // Streams `user_value` via testing::PrintToString, so gtest failure messages
  // show the user-facing value rather than the corpus representation.
  friend std::ostream& operator<<(std::ostream& s, const Value& v) {
    return s << testing::PrintToString(v.user_value);
  }

 private:
  // We don't test the printers here, just that we return one.
  // The printers themselves are tested in type_support_test.cc
  // Naming this type makes instantiating Value<Domain> require that
  // `GetPrinter()` is callable on a const Domain.
  using Printer = decltype(std::declval<const Domain&>().GetPrinter());
};
// Checks that `v` survives both conversions a domain supports:
//  1. user value -> Domain::FromValue -> Domain::GetValue, and
//  2. corpus value -> Domain::SerializeCorpus -> IRObject::FromString ->
//     Domain::ParseCorpus -> Domain::GetValue.
// Each round trip must reproduce a value equal to `v.user_value` under `Eq`
// (so NaNs compare equal and protobufs compare field by field). A fatal
// assertion failure returns from this function, skipping the remaining checks.
// Every failure message includes enough context (the value, and for the
// serialization leg the serialized text) to reproduce the problem.
template <typename Domain>
void VerifyRoundTripThroughConversion(const Value<Domain>& v,
                                      const Domain& domain) {
  {
    auto corpus_value = domain.FromValue(v.user_value);
    ASSERT_TRUE(corpus_value) << v;
    auto new_v = domain.GetValue(*corpus_value);
    EXPECT_TRUE(Eq{}(v.user_value, new_v))
        << "v=" << v << " new_v=" << testing::PrintToString(new_v);
  }
  {
    auto serialized = domain.SerializeCorpus(v.corpus_value).ToString();
    auto parsed = internal::IRObject::FromString(serialized);
    // Without context a parse failure here would be opaque; show the text.
    ASSERT_TRUE(parsed) << serialized << " value = "
                        << testing::PrintToString(v.user_value);
    auto parsed_corpus = domain.ParseCorpus(*parsed);
    ASSERT_TRUE(parsed_corpus)
        << serialized << " value = " << testing::PrintToString(v.user_value);
    auto parsed_value = domain.GetValue(*parsed_corpus);
    EXPECT_TRUE(Eq{}(v.user_value, parsed_value))
        << "v=" << v
        << " parsed_value=" << testing::PrintToString(parsed_value)
        << " serialized=" << serialized;
  }
}
// Applies the round-trip check to every element of `values`, in iteration
// order. Each element is passed to VerifyRoundTripThroughConversion(element,
// domain), i.e. the single-Value overload for Value<Domain> elements.
template <typename Container, typename Domain>
void VerifyRoundTripThroughConversion(const Container& values,
                                      const Domain& domain) {
  using std::begin;
  using std::end;
  const auto last = end(values);
  for (auto it = begin(values); it != last; ++it) {
    VerifyRoundTripThroughConversion(*it, domain);
  }
}
// Deduction guide: `Value(domain, bitgen)` deduces Value<D> from the domain.
template <class D>
Value(D&, absl::BitGen&) -> Value<D>;
// Generates a set of distinct values from `domain`:
//  - `num_seeds` unique values produced by `Domain::Init`, plus
//  - for each seed, a chain of non-shrinking mutations continued until that
//    chain has produced `num_mutations` unique values (the seed itself counts
//    toward that total).
// Every single mutation is expected to change the value; one that does not is
// reported as a non-fatal gtest failure. Negative counts are treated as 0.
//
// NOTE: both loops only stop once enough *unique* values have been seen, so
// this does not terminate if the domain cannot produce that many.
template <typename Domain>
auto GenerateValues(Domain domain, int num_seeds = 10,
                    int num_mutations = 100) {
  // Convert the counts once so the loop conditions compare size_t with
  // size_t (no -Wsign-compare), clamping negative inputs to 0 instead of
  // letting them wrap to a huge unsigned bound that never terminates.
  const size_t target_seeds = static_cast<size_t>(std::max(num_seeds, 0));
  const size_t target_mutations =
      static_cast<size_t>(std::max(num_mutations, 0));
  absl::BitGen bitgen;
  absl::flat_hash_set<Value<Domain>> seeds;
  // Make sure we can make some unique seeds.
  // Randomness might create duplicates so keep going until we got them.
  while (seeds.size() < target_seeds) {
    seeds.insert(Value(domain, bitgen));
  }
  auto values = seeds;
  for (const auto& seed : seeds) {
    auto value = seed;
    absl::flat_hash_set<Value<Domain>> mutations = {value};
    // As above, we repeat until we find enough unique ones.
    while (mutations.size() < target_mutations) {
      const auto previous = value;
      value.Mutate(domain, bitgen, false);
      mutations.insert(value);
      // Make sure that it changed in some way.
      EXPECT_NE(previous, value) << "Value=" << value << " Prev=" << previous;
    }
    values.merge(mutations);
  }
  return values;
}
// Fatally asserts that `pred(value.user_value)` holds for every element of
// `values`; the first element that fails is printed and ends the check.
template <typename Values, typename Pred>
void CheckValues(const Values& values, Pred pred) {
  using std::begin;
  using std::end;
  const auto last = end(values);
  for (auto it = begin(values); it != last; ++it) {
    const auto& value = *it;
    ASSERT_TRUE(pred(value.user_value)) << "Incorrect value: " << value;
  }
}
template <typename Domain>
auto MutateUntilFoundN(Domain domain, size_t n) {
absl::flat_hash_set<Value<Domain>> seen;
absl::BitGen bitgen;
Value val(domain, bitgen);
while (seen.size() < n) {
seen.insert(Value(val, domain));
val.Mutate(domain, bitgen, false);
}
return seen;
}
// For each starting value, repeatedly applies shrinking mutations
// (`only_shrink == true`) until `is_terminal(user_value)` holds, asserting
// that:
//  - every mutation changes the value (fatal ASSERT_NE), and
//  - every non-terminal result is closer to zero than its predecessor,
//    i.e. `is_closer_to_zero(previous, next)` holds on the user values.
// The inner loop has no iteration cap; it relies on repeated shrinking
// eventually reaching a terminal value. The template parameter `T` is not used
// in the body.
template <typename Domain, typename IsTerminal, typename IsCloser,
          typename T = typename Domain::value_type>
void TestShrink(Domain domain, const absl::flat_hash_set<Value<Domain>>& values,
                IsTerminal is_terminal, IsCloser is_closer_to_zero) {
  absl::BitGen bitgen;
  // Deliberate copy: each starting value is mutated in place below.
  for (auto value : values) {
    while (!is_terminal(value.user_value)) {
      auto previous_value = value;
      value.Mutate(domain, bitgen, true);
      ASSERT_NE(value, previous_value);
      // Terminal values are exempt from the closer-to-zero check.
      if (!is_terminal(value.user_value)) {
        ASSERT_TRUE(
            is_closer_to_zero(previous_value.user_value, value.user_value))
            << testing::PrintToString(previous_value.user_value) << " "
            << testing::PrintToString(value.user_value);
      }
    }
  }
}
// Applies exactly one shrinking mutation (`only_shrink == true`) to a copy of
// each element of `values` and fatally asserts `pred(original, mutated)` on
// their user values, stopping at the first element for which it fails.
template <typename Domain, typename Values, typename Pred>
void TestShrink(Domain domain, const Values& values, Pred pred) {
  absl::BitGen bitgen;
  for (const auto& value : values) {
    auto other_value = value;
    other_value.Mutate(domain, bitgen, true);
    ASSERT_TRUE(pred(value.user_value, other_value.user_value))
        << value << " " << other_value;
  }
}
// Returns whether `prev` and `next` hold the same value, where NaN is
// considered the same as NaN.
//  - Arithmetic types: operator==, or both NaN.
//  - Enums, std::optional<int>, std::pair<int, int> and
//    std::pair<const int, int>: operator==.
//  - Anything else is treated as a sized, iterable container: same size and
//    element-wise AreSame in iteration order.
template <typename T>
bool AreSame(const T& prev, const T& next) {
  if constexpr (std::is_arithmetic_v<T>) {
    if (prev == next) return true;
    return std::isnan(prev) && std::isnan(next);
  } else if constexpr (std::is_enum_v<T> ||
                       std::is_same_v<T, std::optional<int>> ||
                       std::is_same_v<T, std::pair<int, int>> ||
                       std::is_same_v<T, std::pair<const int, int>>) {
    return prev == next;
  } else {
    // Container types: equal length, then pairwise comparison.
    return prev.size() == next.size() &&
           std::equal(prev.begin(), prev.end(), next.begin(),
                      [](const auto& lhs, const auto& rhs) {
                        return AreSame(lhs, rhs);
                      });
  }
}
// Returns whether `next` is "closer to zero" than `prev`, the relation that
// shrinking mutations are expected to satisfy.
//  - Integral, floating-point and enum types: `next` lies in (prev, 0] for a
//    negative `prev`, in [0, prev) otherwise, or both are zero. For floating
//    point, any non-finite operand is accepted.
//  - std::optional<int>: the contained value moved towards zero, the optional
//    became empty, or both are empty.
//  - std::pair<int, int>, std::pair<const int, int> and
//    std::pair<const std::string, int>: either component moved towards zero.
//  - Containers: both empty, or `next` is smaller, or (same size) something
//    inside shrunk:
//      * associative containers are matched by key; elements under a shared
//        key must shrink or stay equal (per `Eq`), and at most one key may
//        differ between `prev` and `next`, in which case that element pair
//        must shrink;
//      * sequential containers are compared position by position and need at
//        least one *changed* element that moved towards zero. Unchanged
//        elements are skipped so that an unchanged zero (for which
//        TowardsZero(0, 0) holds) does not make e.g. {0, 5} -> {0, 7} look
//        like a shrink.
template <typename T>
bool TowardsZero(const T& prev, const T& next) {
  if constexpr (std::numeric_limits<T>::is_integer || std::is_arithmetic_v<T> ||
                std::is_enum_v<T>) {
    if constexpr (std::is_floating_point_v<T>) {
      // Ignore non-finites. Those never move towards zero.
      if (!std::isfinite(prev) || !std::isfinite(next)) return true;
    }
    return (prev == next && prev == 0) ||
           (prev < 0 && prev < next && next <= 0) ||
           (prev >= 0 && 0 <= next && next < prev);
  } else if constexpr (std::is_same_v<std::optional<int>, T>) {
    return (prev && next && TowardsZero(*prev, *next)) || (prev && !next) ||
           (!prev && !next);
  } else if constexpr (std::is_same_v<std::pair<int, int>, T> ||
                       std::is_same_v<std::pair<const int, int>, T> ||
                       std::is_same_v<std::pair<const std::string, int>, T>) {
    return TowardsZero(prev.first, next.first) ||
           TowardsZero(prev.second, next.second);
  } else {
    // For container types.
    if (prev.empty() && next.empty()) return true;
    if (prev.size() != next.size()) return prev.size() > next.size();
    // Something inside shrunk.
    if constexpr (internal::Requires<T>([](auto v) ->
                                        typename decltype(v)::key_type {})) {
      const auto get_key = [](auto it) -> decltype(auto) {
        if constexpr (std::is_same_v<typename T::key_type,
                                     typename T::value_type>) {
          return *it;
        } else {
          return it->first;
        }
      };
      // There can be at most one element whose key doesn't match. Compare that
      // one separately after the loop.
      std::optional<typename T::const_iterator> missing_prev, missing_next;
      // For associative containers, do a find.
      // Sequential iteration will not work for unordered ones.
      for (auto it = prev.begin(); it != prev.end(); ++it) {
        auto rhs = next.find(get_key(it));
        if (rhs != next.end()) {
          if (!TowardsZero(*it, *rhs) && !Eq{}(*it, *rhs)) return false;
        } else {
          if (missing_prev.has_value()) return false;
          missing_prev = it;
        }
      }
      for (auto it = next.begin(); it != next.end(); ++it) {
        if (prev.find(get_key(it)) == prev.end()) {
          if (missing_next.has_value()) return false;
          missing_next = it;
        }
      }
      if (missing_prev.has_value() != missing_next.has_value()) return false;
      if (missing_prev.has_value()) {
        return TowardsZero(**missing_prev, **missing_next);
      }
      return true;
    } else {
      for (auto prev_i = prev.begin(), next_i = next.begin();
           prev_i != prev.end(); ++prev_i, ++next_i) {
        // An unchanged element says nothing about shrinking. NaN compares
        // unequal to itself, so NaN elements still reach the lenient
        // non-finite handling in the element-wise check below.
        if (*prev_i == *next_i) continue;
        if (TowardsZero(*prev_i, *next_i)) return true;
      }
      return false;
    }
  }
}
// Enumerates the fields of internal::TestProtobuf that this helper knows
// about, calling the visitors once per field in a fixed order:
//  - `scalar_visitor(name, has, get)` for singular fields, where `has(msg)`
//    reports whether the field is set and `get(msg)` returns its value;
//  - `repeated_visitor(name, get)` for repeated and map fields, where
//    `get(msg)` returns the collection to compare.
// The accessors are generic lambdas taking the message by reference, so a
// visitor can apply them to several messages (e.g. a before/after pair).
template <typename ScalarVisitor, typename RepeatedVisitor>
void VisitTestProtobuf(ScalarVisitor scalar_visitor,
                       RepeatedVisitor repeated_visitor) {
  scalar_visitor(
      "b", [](auto& val) { return val.has_b(); },
      [](auto& val) { return val.b(); });
  scalar_visitor(
      "i32", [](auto& val) { return val.has_i32(); },
      [](auto& val) { return val.i32(); });
  scalar_visitor(
      "u32", [](auto& val) { return val.has_u32(); },
      [](auto& val) { return val.u32(); });
  scalar_visitor(
      "i64", [](auto& val) { return val.has_i64(); },
      [](auto& val) { return val.i64(); });
  scalar_visitor(
      "u64", [](auto& val) { return val.has_u64(); },
      [](auto& val) { return val.u64(); });
  scalar_visitor(
      "f", [](auto& val) { return val.has_f(); },
      [](auto& val) { return val.f(); });
  scalar_visitor(
      "d", [](auto& val) { return val.has_d(); },
      [](auto& val) { return val.d(); });
  scalar_visitor(
      "str", [](auto& val) { return val.has_str(); },
      [](auto& val) { return val.str(); });
  scalar_visitor(
      "oneof_i32", [](auto& val) { return val.has_oneof_i32(); },
      [](auto& val) { return val.oneof_i32(); });
  scalar_visitor(
      "oneof_i64", [](auto& val) { return val.has_oneof_i64(); },
      [](auto& val) { return val.oneof_i64(); });
  // Field of the nested `subproto` message.
  scalar_visitor(
      "subproto_i32",
      [](auto& val) { return val.subproto().has_subproto_i32(); },
      [](auto& val) { return val.subproto().subproto_i32(); });
  scalar_visitor(
      "e", [](auto& val) { return val.has_e(); },
      [](auto& val) { return val.e(); });
  repeated_visitor("rep_b", [](auto& val) { return val.rep_b(); });
  repeated_visitor("rep_i32", [](auto& val) { return val.rep_i32(); });
  repeated_visitor("rep_u32", [](auto& val) { return val.rep_u32(); });
  repeated_visitor("rep_i64", [](auto& val) { return val.rep_i64(); });
  repeated_visitor("rep_u64", [](auto& val) { return val.rep_u64(); });
  repeated_visitor("rep_f", [](auto& val) { return val.rep_f(); });
  repeated_visitor("rep_d", [](auto& val) { return val.rep_d(); });
  repeated_visitor("rep_str", [](auto& val) { return val.rep_str(); });
  // The map is copied into a sorted vector of pairs so that two messages can
  // be compared position by position, independent of map iteration order.
  repeated_visitor("map_field", [](auto& val) {
    std::vector<std::pair<int, int>> v;
    for (const auto& p : val.map_field()) v.push_back(p);
    std::sort(v.begin(), v.end());
    return v;
  });
  // Each submessage is reduced to its optional `subproto_i32` (nullopt when
  // unset).
  repeated_visitor("rep_subproto", [](auto& val) {
    std::vector<std::optional<int>> v;
    for (const auto& s : val.rep_subproto()) {
      v.push_back(s.has_subproto_i32() ? std::optional(s.subproto_i32())
                                       : std::nullopt);
    }
    return v;
  });
  repeated_visitor("rep_e", [](auto& val) { return val.rep_e(); });
}
// Checks that every field enumerated by VisitTestProtobuf moved towards zero,
// or stayed the same, from `prev` to `next`:
//  - a singular field set in both must satisfy TowardsZero or AreSame; a field
//    may become unset, but an unset field must not become set;
//  - a repeated/map field must satisfy TowardsZero or AreSame.
// On failure records a single non-fatal gtest failure that names every
// offending field (comma separated, in visiting order) and returns false.
inline bool TowardsZero(const internal::TestProtobuf& prev,
                        const internal::TestProtobuf& next) {
  // Accumulates all offending field names so one message reports them all.
  std::string failed_fields;
  const auto record_failure = [&failed_fields](std::string_view name) {
    if (!failed_fields.empty()) failed_fields += ", ";
    failed_fields.append(name.data(), name.size());
  };
  const auto verify_towards_zero = [&](std::string_view name, auto has,
                                       auto get) {
    auto pv = has(prev) ? std::optional(get(prev)) : std::nullopt;
    auto nv = has(next) ? std::optional(get(next)) : std::nullopt;
    if (pv && nv && !TowardsZero(*pv, *nv) && !AreSame(*pv, *nv)) {
      // A set field grew (or moved away from zero).
      record_failure(name);
    } else if (!pv && nv) {
      // A shrink must never set a previously unset field.
      record_failure(name);
    }
  };
  const auto verify_repeated_towards_zero = [&](std::string_view name,
                                                auto get) {
    auto pv = get(prev);
    auto nv = get(next);
    if (!TowardsZero(pv, nv) && !AreSame(pv, nv)) {
      record_failure(name);
    }
  };
  VisitTestProtobuf(verify_towards_zero, verify_repeated_towards_zero);
  if (failed_fields.empty()) return true;
  ADD_FAILURE() << "Failed on field: " << failed_fields;
  return false;
}
} // namespace fuzztest
#endif // FUZZTEST_INTERNAL_DOMAIN_TESTING_H_