Add union and vector of union support to the flatbuffers domain
PiperOrigin-RevId: 745521423
diff --git a/MODULE.bazel b/MODULE.bazel
index 9de1d24..043422f 100644
--- a/MODULE.bazel
+++ b/MODULE.bazel
@@ -42,6 +42,10 @@
name = "platforms",
version = "0.0.10",
)
+# FlatBuffers: provides the runtime library and build rules used by the
+# flatbuffers domain targets in //fuzztest.
+bazel_dep(
+    name = "flatbuffers",
+    version = "25.2.10",
+)
# GoogleTest is not a dev dependency, because it's needed when FuzzTest is used
# with GoogleTest integration (e.g., googletest_adaptor). Note that the FuzzTest
# framework can be used without GoogleTest integration as well.
@@ -55,8 +59,6 @@
name = "protobuf",
version = "30.2",
)
-# TODO(lszekeres): Make this a dev dependency, as the protobuf library is only
-# required for testing.
bazel_dep(
name = "rules_proto",
version = "7.1.0",
diff --git a/domain_tests/BUILD b/domain_tests/BUILD
index 436a2c9..4e207e3 100644
--- a/domain_tests/BUILD
+++ b/domain_tests/BUILD
@@ -34,6 +34,21 @@
)
cc_test(
+ # Unit tests for the FlatBuffers domains (tables, vectors, unions). They run
+ # against the C++ code generated from the test schema
+ # (//fuzztest:test_flatbuffers_cc_fbs).
+ name = "arbitrary_domains_flatbuffers_test",
+ srcs = ["arbitrary_domains_flatbuffers_test.cc"],
+ deps = [
+ ":domain_testing",
+ "@abseil-cpp//absl/random",
+ "@com_google_fuzztest//fuzztest:domain",
+ "@com_google_fuzztest//fuzztest:flatbuffers",
+ "@com_google_fuzztest//fuzztest:meta",
+ "@com_google_fuzztest//fuzztest:test_flatbuffers_cc_fbs",
+ "@flatbuffers//:runtime_cc",
+ "@googletest//:gtest_main",
+ ],
+)
+
+cc_test(
name = "arbitrary_domains_protobuf_test",
srcs = ["arbitrary_domains_protobuf_test.cc"],
deps = [
diff --git a/domain_tests/arbitrary_domains_flatbuffers_test.cc b/domain_tests/arbitrary_domains_flatbuffers_test.cc
new file mode 100644
index 0000000..aca9ea5
--- /dev/null
+++ b/domain_tests/arbitrary_domains_flatbuffers_test.cc
@@ -0,0 +1,526 @@
+#include <cstdint>
+#include <optional>
+#include <string_view>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/random/random.h"
+#include "flatbuffers/base.h"
+#include "flatbuffers/buffer.h"
+#include "flatbuffers/flatbuffer_builder.h"
+#include "flatbuffers/string.h"
+#include "flatbuffers/vector.h"
+#include "./fuzztest/domain.h"
+#include "./domain_tests/domain_testing.h"
+#include "./fuzztest/flatbuffers.h"
+#include "./fuzztest/internal/meta.h"
+#include "./fuzztest/test_flatbuffers_generated.h"
+
+namespace fuzztest {
+namespace {
+
+using ::fuzztest::internal::NestedTestFbsTable;
+using ::fuzztest::internal::OptionalRequiredTestFbsTable;
+using ::fuzztest::internal::SimpleTestFbsTable;
+using ::fuzztest::internal::TestFbsEnum;
+using ::fuzztest::internal::UnionTestFbsTable;
+using ::fuzztest::internal::VectorsTestFbsTable;
+using ::testing::Contains;
+using ::testing::IsTrue;
+using ::testing::ResultOf;
+
+// Compile-time check of the table-detection trait: a generated FlatBuffers
+// table type must be recognized, an unrelated type must not.
+TEST(FlatbuffersMetaTest, IsFlatbuffersTable) {
+  constexpr bool kGeneratedTableDetected =
+      internal::is_flatbuffers_table_v<SimpleTestFbsTable>;
+  constexpr bool kIntDetected = internal::is_flatbuffers_table_v<int>;
+  static_assert(kGeneratedTableDetected);
+  static_assert(!kIntDetected);
+}
+
+// Round-trips a hand-built SimpleTestFbsTable through
+// FromValue -> SerializeCorpus -> ParseCorpus -> GetValue and checks that
+// every field survives.
+TEST(FlatbuffersTableDomainImplTest, SimpleTestFbsTableValueRoundTrip) {
+  auto domain = Arbitrary<SimpleTestFbsTable>();
+
+  flatbuffers::FlatBufferBuilder fbb;
+  auto table_offset = internal::CreateSimpleTestFbsTableDirect(
+      fbb, true, 1.0, "foo bar baz", internal::TestFbsEnum_Second);
+  fbb.Finish(table_offset);
+  auto table = flatbuffers::GetRoot<SimpleTestFbsTable>(fbb.GetBufferPointer());
+
+  auto corpus = domain.FromValue(table);
+  ASSERT_TRUE(corpus.has_value());
+  ASSERT_OK(domain.ValidateCorpusValue(*corpus));
+
+  auto ir = domain.SerializeCorpus(*corpus);
+
+  auto new_corpus = domain.ParseCorpus(ir);
+  ASSERT_TRUE(new_corpus.has_value());
+  ASSERT_OK(domain.ValidateCorpusValue(*new_corpus));
+
+  auto new_table = domain.GetValue(*new_corpus);
+  // Assert (not expect) on pointers before dereferencing them, so that a
+  // regression shows up as a test failure instead of a crash.
+  ASSERT_NE(new_table, nullptr);
+  EXPECT_EQ(new_table->b(), true);
+  EXPECT_EQ(new_table->f(), 1.0);
+  ASSERT_NE(new_table->str(), nullptr);
+  EXPECT_EQ(new_table->str()->str(), "foo bar baz");
+  EXPECT_EQ(new_table->e(), internal::TestFbsEnum_Second);
+}
+
+// Registers one table as a seed and checks that at least one of 1000 freshly
+// initialized values is field-for-field equal to it.
+TEST(FlatbuffersTableDomainImplTest, InitGeneratesSeeds) {
+  auto domain = Arbitrary<SimpleTestFbsTable>();
+
+  flatbuffers::FlatBufferBuilder fbb;
+  auto table_offset = internal::CreateSimpleTestFbsTableDirect(
+      fbb, true, 1.0, "foo bar baz", internal::TestFbsEnum_Second);
+  fbb.Finish(table_offset);
+  auto table = flatbuffers::GetRoot<SimpleTestFbsTable>(fbb.GetBufferPointer());
+
+  domain.WithSeeds({table});
+
+  constexpr int kNumValues = 1000;
+  absl::BitGen bitgen;
+  std::vector<Value<decltype(domain)>> values;
+  values.reserve(kNumValues);
+  for (int n = 0; n < kNumValues; ++n) {
+    Value generated(domain, bitgen);
+    values.push_back(std::move(generated));
+  }
+
+  // True when `candidate` carries the same field values as the seed table.
+  // Strings are equal when both are absent or both are present with the same
+  // contents.
+  const auto equals_seed = [table](const auto& candidate) {
+    const auto& got = candidate.user_value;
+    bool strings_match;
+    if (got->str() != nullptr && table->str() != nullptr) {
+      strings_match = got->str()->str() == table->str()->str();
+    } else {
+      strings_match = got->str() == nullptr && table->str() == nullptr;
+    }
+    return (got->b() == table->b() && got->f() == table->f() &&
+            got->e() == table->e() && strings_match);
+  };
+
+  EXPECT_THAT(values, Contains(ResultOf(equals_seed, IsTrue())));
+}
+
+// Checks that repeated mutation of a SimpleTestFbsTable value eventually
+// produces at least two distinct values for each of its fields.
+TEST(FlatbuffersTableDomainImplTest, EventuallyMutatesAllTableFields) {
+ auto domain = Arbitrary<SimpleTestFbsTable>();
+
+ absl::BitGen bitgen;
+ Value val(domain, bitgen);
+
+ // `get` extracts one field from the current user value. The loop records the
+ // field and then mutates `val`, stopping once two distinct values were seen
+ // or the iteration budget is used up; the latter is reported as a failure.
+ // `val` is shared, so each call continues from the previous call's state.
+ const auto verify_field_changes = [&](std::string_view name, auto get) {
+ Set<decltype(get(val.user_value))> values;
+
+ int iterations = 10'000;
+ while (--iterations > 0 && values.size() < 2) {
+ values.insert(get(val.user_value));
+ val.Mutate(domain, bitgen, {}, false);
+ }
+ // `iterations` only reaches 0 when the budget ran out.
+ EXPECT_GT(iterations, 0)
+ << "Field: " << name << " -- " << testing::PrintToString(values);
+ };
+
+ verify_field_changes("b", [](auto v) { return v->b(); });
+ verify_field_changes("f", [](auto v) { return v->f(); });
+ // An absent string is recorded as the empty string.
+ verify_field_changes("str",
+ [](auto v) { return v->str() ? v->str()->str() : ""; });
+ verify_field_changes("e", [](auto v) { return v->e(); });
+}
+
+// Checks that mutation can leave the optional fields of
+// OptionalRequiredTestFbsTable unset.
+TEST(FlatbuffersTableDomainImplTest, OptionalFieldsEventuallyBecomeEmpty) {
+  auto domain = Arbitrary<OptionalRequiredTestFbsTable>();
+
+  absl::BitGen bitgen;
+  Value val(domain, bitgen);
+
+  // Mutates `val` up to 10'000 times, stopping at the first mutation after
+  // which `is_set` reports the field as absent.
+  const auto expect_eventually_unset = [&](std::string_view field_name,
+                                           auto is_set) {
+    for (int attempts_left = 10'000; attempts_left > 0; --attempts_left) {
+      val.Mutate(domain, bitgen, {}, false);
+      if (!is_set(val.user_value)) {
+        break;
+      }
+    }
+    EXPECT_FALSE(is_set(val.user_value))
+        << "Field never became unset: " << field_name;
+  };
+
+  expect_eventually_unset(
+      "opt_scalar", [](auto table) { return table->opt_scalar().has_value(); });
+  expect_eventually_unset(
+      "opt_str", [](auto table) { return table->opt_str() != nullptr; });
+}
+
+// Checks that the defaulted scalar and the required string of
+// OptionalRequiredTestFbsTable stay set across many mutations.
+TEST(FlatbuffersTableDomainImplTest, DefaultAndRequiredFieldsAlwaysSet) {
+  auto domain = Arbitrary<OptionalRequiredTestFbsTable>();
+
+  absl::BitGen bitgen;
+  Value val(domain, bitgen);
+
+  // Mutates `val` up to 10'000 times and stops early as soon as `is_set`
+  // reports the field as absent, which the final expectation then flags.
+  const auto expect_always_set = [&](std::string_view field_name, auto is_set) {
+    for (int attempts_left = 10'000; attempts_left > 0; --attempts_left) {
+      val.Mutate(domain, bitgen, {}, false);
+      if (!is_set(val.user_value)) {
+        break;
+      }
+    }
+    EXPECT_TRUE(is_set(val.user_value)) << "Field is not set: " << field_name;
+  };
+
+  // The predicate for def_scalar is unconditionally true, so this call only
+  // exercises mutation.
+  expect_always_set("def_scalar", [](auto table) { return true; });
+  expect_always_set("req_str",
+                    [](auto table) { return table->req_str() != nullptr; });
+}
+
+// Round-trips a NestedTestFbsTable (a table holding a SimpleTestFbsTable)
+// through FromValue -> SerializeCorpus -> ParseCorpus -> GetValue and checks
+// the nested fields of the value rebuilt from the *parsed* corpus.
+TEST(FlatbuffersTableDomainImplTest, NestedTableValueRoundTrip) {
+  auto domain = Arbitrary<NestedTestFbsTable>();
+
+  flatbuffers::FlatBufferBuilder fbb;
+  auto child_offset = internal::CreateSimpleTestFbsTableDirect(
+      fbb, true, 1.0, "foo bar baz", internal::TestFbsEnum_Second);
+  auto parent_offset = internal::CreateNestedTestFbsTable(fbb, child_offset);
+  fbb.Finish(parent_offset);
+  auto table = flatbuffers::GetRoot<NestedTestFbsTable>(fbb.GetBufferPointer());
+
+  auto parent_corpus = domain.FromValue(table);
+  ASSERT_TRUE(parent_corpus.has_value());
+  ASSERT_OK(domain.ValidateCorpusValue(*parent_corpus));
+
+  auto ir = domain.SerializeCorpus(*parent_corpus);
+
+  auto new_corpus = domain.ParseCorpus(ir);
+  ASSERT_TRUE(new_corpus.has_value());
+  ASSERT_OK(domain.ValidateCorpusValue(*new_corpus));
+
+  // Build the value from the parsed corpus; using `parent_corpus` here would
+  // bypass the serialization round trip this test is about.
+  auto new_table = domain.GetValue(*new_corpus);
+  ASSERT_NE(new_table, nullptr);
+  ASSERT_NE(new_table->t(), nullptr);
+  EXPECT_EQ(new_table->t()->b(), true);
+  EXPECT_EQ(new_table->t()->f(), 1.0);
+  ASSERT_NE(new_table->t()->str(), nullptr);
+  EXPECT_EQ(new_table->t()->str()->str(), "foo bar baz");
+  EXPECT_EQ(new_table->t()->e(), internal::TestFbsEnum_Second);
+}
+
+// Checks that repeated mutation of a NestedTestFbsTable eventually produces at
+// least two distinct values for every field of the nested table `t`.
+TEST(FlatbuffersTableDomainImplTest, EventuallyMutatesAllNestedTableFields) {
+ auto domain = Arbitrary<NestedTestFbsTable>();
+ absl::BitGen bitgen;
+ Value val(domain, bitgen);
+
+ // `get` returns an optional: nullopt when the nested table (or the field) is
+ // absent. Only present values are recorded. The loop stops once two distinct
+ // values were seen or the iteration budget is exhausted; exhaustion fails.
+ const auto verify_field_changes = [&](std::string_view name, auto get) {
+ Set<typename decltype(get(val.user_value))::value_type> values;
+
+ int iterations = 10'000;
+ while (--iterations > 0 && values.size() < 2) {
+ auto value = get(val.user_value);
+ if (value.has_value()) {
+ values.insert(*value);
+ }
+ val.Mutate(domain, bitgen, {}, false);
+ }
+ EXPECT_GT(iterations, 0)
+ << "Field: " << name << " -- " << testing::PrintToString(values);
+ };
+
+ verify_field_changes("t.b", [](auto v) {
+ return v->t() ? std::make_optional(v->t()->b()) : std::nullopt;
+ });
+ verify_field_changes("t.f", [](auto v) {
+ return v->t() ? std::make_optional(v->t()->f()) : std::nullopt;
+ });
+ // Both the nested table and its string may be absent.
+ verify_field_changes("t.str", [](auto v) {
+ return v->t() ? v->t()->str() ? std::make_optional(v->t()->str()->str())
+ : std::nullopt
+ : std::nullopt;
+ });
+ verify_field_changes("t.e", [](auto v) {
+ return v->t() ? std::make_optional(v->t()->e()) : std::nullopt;
+ });
+}
+
+// Round-trips a VectorsTestFbsTable (one vector field per supported element
+// type) through FromValue -> SerializeCorpus -> ParseCorpus -> GetValue and
+// checks the contents of every vector.
+TEST(FlatbuffersTableDomainImplTest, VectorsSerializeAndDeserialize) {
+  auto domain = Arbitrary<VectorsTestFbsTable>();
+
+  flatbuffers::FlatBufferBuilder fbb;
+  std::vector<flatbuffers::Offset<flatbuffers::String>> str_offsets;
+  for (const auto& str : {"foo", "bar", "baz"}) {
+    str_offsets.push_back(fbb.CreateString(str));
+  }
+  std::vector<flatbuffers::Offset<SimpleTestFbsTable>> table_offsets;
+  for (const auto& str : {"foo", "bar", "baz"}) {
+    table_offsets.push_back(internal::CreateSimpleTestFbsTableDirect(
+        fbb, true, 1.0, str, internal::TestFbsEnum_Second));
+  }
+  std::vector<uint8_t> b{true, false};
+  std::vector<int8_t> i8{1, 2, 3};
+  std::vector<int16_t> i16{1, 2, 3};
+  std::vector<int32_t> i32{1, 2, 3};
+  std::vector<int64_t> i64{1, 2, 3};
+  std::vector<uint8_t> u8{1, 2, 3};
+  std::vector<uint16_t> u16{1, 2, 3};
+  std::vector<uint32_t> u32{1, 2, 3};
+  std::vector<uint64_t> u64{1, 2, 3};
+  std::vector<float> f{1, 2, 3};
+  std::vector<double> d{1, 2, 3};
+  std::vector<std::underlying_type_t<TestFbsEnum>> e{
+      TestFbsEnum::TestFbsEnum_First, TestFbsEnum::TestFbsEnum_Second,
+      TestFbsEnum::TestFbsEnum_Third};
+  auto table_offset = internal::CreateVectorsTestFbsTableDirect(
+      fbb, &b, &i8, &i16, &i32, &i64, &u8, &u16, &u32, &u64, &f, &d,
+      &str_offsets, &e, &table_offsets);
+  fbb.Finish(table_offset);
+  auto table =
+      flatbuffers::GetRoot<VectorsTestFbsTable>(fbb.GetBufferPointer());
+
+  auto corpus = domain.FromValue(table);
+  ASSERT_TRUE(corpus.has_value());
+  auto ir = domain.SerializeCorpus(*corpus);
+
+  auto new_corpus = domain.ParseCorpus(ir);
+  ASSERT_TRUE(new_corpus.has_value());
+  ASSERT_OK(domain.ValidateCorpusValue(*new_corpus));
+
+  auto new_table = domain.GetValue(*new_corpus);
+  ASSERT_NE(new_table, nullptr);
+
+  // Checks that a numeric vector field is present and holds exactly {1, 2, 3}.
+  // Presence and size are asserted so that Get() is never called out of range;
+  // a failed assertion only leaves this helper, the remaining fields are still
+  // checked.
+  const auto expect_one_two_three = [](const char* field, const auto* vec) {
+    SCOPED_TRACE(field);
+    ASSERT_NE(vec, nullptr);
+    ASSERT_EQ(vec->size(), 3);
+    EXPECT_EQ(vec->Get(0), 1);
+    EXPECT_EQ(vec->Get(1), 2);
+    EXPECT_EQ(vec->Get(2), 3);
+  };
+
+  ASSERT_NE(new_table->b(), nullptr);
+  ASSERT_EQ(new_table->b()->size(), 2);
+  EXPECT_EQ(new_table->b()->Get(0), true);
+  EXPECT_EQ(new_table->b()->Get(1), false);
+
+  expect_one_two_three("i8", new_table->i8());
+  expect_one_two_three("i16", new_table->i16());
+  expect_one_two_three("i32", new_table->i32());
+  expect_one_two_three("i64", new_table->i64());
+  expect_one_two_three("u8", new_table->u8());
+  expect_one_two_three("u16", new_table->u16());
+  expect_one_two_three("u32", new_table->u32());
+  expect_one_two_three("u64", new_table->u64());
+  expect_one_two_three("f", new_table->f());
+  expect_one_two_three("d", new_table->d());
+
+  ASSERT_NE(new_table->e(), nullptr);
+  ASSERT_EQ(new_table->e()->size(), 3);
+  EXPECT_EQ(new_table->e()->Get(0), internal::TestFbsEnum_First);
+  EXPECT_EQ(new_table->e()->Get(1), internal::TestFbsEnum_Second);
+  EXPECT_EQ(new_table->e()->Get(2), internal::TestFbsEnum_Third);
+
+  // The string vector and the table vector were both built from the same
+  // three strings, in this order.
+  const char* const kExpectedStrings[] = {"foo", "bar", "baz"};
+  ASSERT_NE(new_table->str(), nullptr);
+  ASSERT_EQ(new_table->str()->size(), 3);
+  ASSERT_NE(new_table->t(), nullptr);
+  ASSERT_EQ(new_table->t()->size(), 3);
+  for (flatbuffers::uoffset_t i = 0; i < 3; ++i) {
+    SCOPED_TRACE(i);
+    ASSERT_NE(new_table->str()->Get(i), nullptr);
+    EXPECT_EQ(new_table->str()->Get(i)->str(), kExpectedStrings[i]);
+
+    const auto* child = new_table->t()->Get(i);
+    ASSERT_NE(child, nullptr);
+    EXPECT_EQ(child->b(), true);
+    EXPECT_EQ(child->f(), 1.0);
+    ASSERT_NE(child->str(), nullptr);
+    EXPECT_EQ(child->str()->str(), kExpectedStrings[i]);
+    EXPECT_EQ(child->e(), internal::TestFbsEnum_Second);
+  }
+}
+
+// Checks that repeated mutation of a VectorsTestFbsTable eventually changes
+// every vector field.
+//
+// NOTE(review): the recorded values are the vector pointers returned by the
+// accessors, so "changed" here means "a different non-null pointer was
+// observed", not "different contents" -- confirm this is the intended check.
+TEST(FlatbuffersTableDomainImplTest, EventuallyMutatesAllVectorFields) {
+  auto domain = Arbitrary<VectorsTestFbsTable>();
+
+  absl::BitGen bitgen;
+  Value val(domain, bitgen);
+
+  // `get` returns an optional that is empty while the field is absent. Only
+  // present values are recorded. The loop stops once two distinct values were
+  // seen or the iteration budget is exhausted; exhaustion is a failure.
+  const auto verify_field_changes = [&](std::string_view name, auto get) {
+    Set<typename decltype(get(val.user_value))::value_type> values;
+
+    int iterations = 10'000;
+    while (--iterations > 0 && values.size() < 2) {
+      auto value = get(val.user_value);
+      if (value.has_value()) {
+        values.insert(*value);
+      }
+      val.Mutate(domain, bitgen, {}, false);
+    }
+    EXPECT_GT(iterations, 0)
+        << "Field: " << name << " -- " << testing::PrintToString(values);
+  };
+
+  // Each getter guards on the same field it returns. The labels name the
+  // top-level vector fields of VectorsTestFbsTable.
+  verify_field_changes("b", [](auto v) {
+    return v && v->b() ? std::make_optional(v->b()) : std::nullopt;
+  });
+  verify_field_changes("i8", [](auto v) {
+    return v && v->i8() ? std::make_optional(v->i8()) : std::nullopt;
+  });
+  verify_field_changes("i16", [](auto v) {
+    return v && v->i16() ? std::make_optional(v->i16()) : std::nullopt;
+  });
+  verify_field_changes("i32", [](auto v) {
+    return v && v->i32() ? std::make_optional(v->i32()) : std::nullopt;
+  });
+  verify_field_changes("i64", [](auto v) {
+    return v && v->i64() ? std::make_optional(v->i64()) : std::nullopt;
+  });
+  verify_field_changes("u8", [](auto v) {
+    return v && v->u8() ? std::make_optional(v->u8()) : std::nullopt;
+  });
+  verify_field_changes("u16", [](auto v) {
+    return v && v->u16() ? std::make_optional(v->u16()) : std::nullopt;
+  });
+  verify_field_changes("u32", [](auto v) {
+    return v && v->u32() ? std::make_optional(v->u32()) : std::nullopt;
+  });
+  verify_field_changes("u64", [](auto v) {
+    return v && v->u64() ? std::make_optional(v->u64()) : std::nullopt;
+  });
+  verify_field_changes("f", [](auto v) {
+    return v && v->f() ? std::make_optional(v->f()) : std::nullopt;
+  });
+  verify_field_changes("d", [](auto v) {
+    return v && v->d() ? std::make_optional(v->d()) : std::nullopt;
+  });
+  verify_field_changes("e", [](auto v) {
+    return v && v->e() ? std::make_optional(v->e()) : std::nullopt;
+  });
+  verify_field_changes("str", [](auto v) {
+    return v && v->str() ? std::make_optional(v->str()) : std::nullopt;
+  });
+  verify_field_changes("t", [](auto v) {
+    return v && v->t() ? std::make_optional(v->t()) : std::nullopt;
+  });
+}
+
+// Round-trips a UnionTestFbsTable holding a single union (`u`) and a vector of
+// unions (`u_vec`) through FromValue -> SerializeCorpus -> ParseCorpus ->
+// GetValue, then checks both the union type tags and the payload tables.
+TEST(FlatbuffersTableDomainImplTest, UnionFieldsSerializeAndDeserialize) {
+  flatbuffers::FlatBufferBuilder fbb;
+  auto child_table_offset = internal::CreateSimpleTestFbsTableDirect(
+      fbb, true, 1.0, "foo bar baz", internal::TestFbsEnum_Second);
+  auto vec_child_one = internal::CreateSimpleTestFbsTableDirect(
+      fbb, true, 1.0, "foo bar baz", internal::TestFbsEnum_Second);
+  auto vec_child_two = internal::CreateOptionalRequiredTestFbsTableDirect(
+      fbb, true, std::nullopt, "foo bar baz");
+  auto vec_types =
+      fbb.CreateVector<uint8_t>({internal::Union_SimpleTestFbsTable,
+                                 internal::Union_OptionalRequiredTestFbsTable});
+  auto vec_values =
+      fbb.CreateVector({vec_child_one.Union(), vec_child_two.Union()});
+  auto parent_offset = internal::CreateUnionTestFbsTable(
+      fbb, internal::Union_SimpleTestFbsTable, child_table_offset.Union(),
+      vec_types, vec_values);
+  fbb.Finish(parent_offset);
+  auto table = flatbuffers::GetRoot<UnionTestFbsTable>(fbb.GetBufferPointer());
+
+  auto domain = Arbitrary<UnionTestFbsTable>();
+  auto corpus = domain.FromValue(table);
+  ASSERT_TRUE(corpus.has_value());
+  auto ir = domain.SerializeCorpus(*corpus);
+  auto new_corpus = domain.ParseCorpus(ir);
+  ASSERT_TRUE(new_corpus.has_value());
+  ASSERT_OK(domain.ValidateCorpusValue(*new_corpus));
+  auto new_table = domain.GetValue(*new_corpus);
+  ASSERT_NE(new_table, nullptr);
+
+  // Single union field.
+  EXPECT_EQ(new_table->u_type(), internal::Union_SimpleTestFbsTable);
+  ASSERT_NE(new_table->u(), nullptr);
+  ASSERT_NE(new_table->u_as_SimpleTestFbsTable(), nullptr);
+  EXPECT_EQ(new_table->u_as_SimpleTestFbsTable()->b(), true);
+  EXPECT_EQ(new_table->u_as_SimpleTestFbsTable()->f(), 1.0);
+  ASSERT_NE(new_table->u_as_SimpleTestFbsTable()->str(), nullptr);
+  EXPECT_EQ(new_table->u_as_SimpleTestFbsTable()->str()->str(), "foo bar baz");
+  EXPECT_EQ(new_table->u_as_SimpleTestFbsTable()->e(),
+            internal::TestFbsEnum_Second);
+
+  // Vector of unions: verify the type tags first, since the casts below are
+  // only meaningful when each element has the expected type.
+  ASSERT_NE(new_table->u_vec_type(), nullptr);
+  ASSERT_EQ(new_table->u_vec_type()->size(), 2);
+  EXPECT_EQ(new_table->u_vec_type()->Get(0),
+            internal::Union_SimpleTestFbsTable);
+  EXPECT_EQ(new_table->u_vec_type()->Get(1),
+            internal::Union_OptionalRequiredTestFbsTable);
+
+  ASSERT_NE(new_table->u_vec(), nullptr);
+  ASSERT_EQ(new_table->u_vec()->size(), 2);
+  auto u_vec_one =
+      static_cast<const SimpleTestFbsTable*>(new_table->u_vec()->Get(0));
+  ASSERT_NE(u_vec_one, nullptr);
+  EXPECT_EQ(u_vec_one->b(), true);
+  EXPECT_EQ(u_vec_one->f(), 1.0);
+  ASSERT_NE(u_vec_one->str(), nullptr);
+  EXPECT_EQ(u_vec_one->str()->str(), "foo bar baz");
+  EXPECT_EQ(u_vec_one->e(), internal::TestFbsEnum_Second);
+
+  auto u_vec_two = static_cast<const OptionalRequiredTestFbsTable*>(
+      new_table->u_vec()->Get(1));
+  ASSERT_NE(u_vec_two, nullptr);
+  EXPECT_EQ(u_vec_two->def_scalar(), true);
+  EXPECT_EQ(u_vec_two->opt_scalar(), std::nullopt);
+  ASSERT_NE(u_vec_two->req_str(), nullptr);
+  EXPECT_EQ(u_vec_two->req_str()->str(), "foo bar baz");
+  EXPECT_EQ(u_vec_two->opt_str(), nullptr);
+}
+
+// Checks that repeated mutation of a UnionTestFbsTable eventually changes the
+// union type tag, the union payload, and the vector-of-unions fields.
+//
+// NOTE(review): except for `u_type`, the recorded values are pointers, so
+// "changed" means a different pointer (including null) was observed --
+// confirm this is the intended check.
+TEST(FlatbuffersTableDomainImplTest, UnionFieldsEventuallyMutate) {
+  auto domain = Arbitrary<UnionTestFbsTable>();
+
+  absl::BitGen bitgen;
+  Value val(domain, bitgen);
+
+  // Records `get(val.user_value)` and mutates `val` until two distinct values
+  // were seen or the iteration budget is exhausted; exhaustion is a failure.
+  const auto verify_field_changes = [&](std::string_view name, auto get) {
+    Set<decltype(get(val.user_value))> values;
+
+    int iterations = 10'000;
+    while (--iterations > 0 && values.size() < 2) {
+      auto value = get(val.user_value);
+      values.insert(value);
+      val.Mutate(domain, bitgen, {}, false);
+    }
+    EXPECT_GT(iterations, 0)
+        << "Field: " << name << " -- " << testing::PrintToString(values);
+  };
+
+  verify_field_changes("u_type", [](auto v) { return v->u_type(); });
+  verify_field_changes("u_as_SimpleTestFbsTable",
+                       [](auto v) { return v->u_as_SimpleTestFbsTable(); });
+  verify_field_changes("u_as_OptionalRequiredTestFbsTable", [](auto v) {
+    return v->u_as_OptionalRequiredTestFbsTable();
+  });
+  verify_field_changes("u_vec_type", [](auto v) { return v->u_vec_type(); });
+  verify_field_changes("u_vec", [](auto v) { return v->u_vec(); });
+  // For the element getters, both the value vector and the type vector must
+  // be present and non-empty before element 0 is read from either of them.
+  verify_field_changes("u_vec[0].as_SimpleTestFbsTable", [](auto v) {
+    return v->u_vec() && v->u_vec()->size() > 0 && v->u_vec_type() &&
+                   v->u_vec_type()->size() > 0 &&
+                   v->u_vec_type()->Get(0) == internal::Union_SimpleTestFbsTable
+               ? static_cast<const SimpleTestFbsTable*>(v->u_vec()->Get(0))
+               : nullptr;
+  });
+  verify_field_changes("u_vec[0].as_OptionalRequiredTestFbsTable", [](auto v) {
+    return v->u_vec() && v->u_vec()->size() > 0 && v->u_vec_type() &&
+                   v->u_vec_type()->size() > 0 &&
+                   v->u_vec_type()->Get(0) ==
+                       internal::Union_OptionalRequiredTestFbsTable
+               ? static_cast<const OptionalRequiredTestFbsTable*>(
+                     v->u_vec()->Get(0))
+               : nullptr;
+  });
+}
+
+} // namespace
+} // namespace fuzztest
diff --git a/fuzztest/BUILD b/fuzztest/BUILD
index b604180..f89a96c 100644
--- a/fuzztest/BUILD
+++ b/fuzztest/BUILD
@@ -15,6 +15,7 @@
# FuzzTest: a coverage-guided fuzzing / property-based testing framework.
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
+load("@flatbuffers//:build_defs.bzl", "flatbuffer_library_public")
load("@rules_proto//proto:defs.bzl", "proto_library")
package(default_visibility = ["//visibility:public"])
@@ -345,6 +346,7 @@
":serialization",
":status",
":type_support",
+ "@abseil-cpp//absl/algorithm:container",
"@abseil-cpp//absl/base:core_headers",
"@abseil-cpp//absl/base:no_destructor",
"@abseil-cpp//absl/container:flat_hash_map",
@@ -423,6 +425,38 @@
)
cc_library(
+ # FlatBuffers domain support. `flatbuffers.h` is the public header; it
+ # re-exports the internal implementation header, which is therefore listed
+ # in `srcs` together with its .cc file.
+ name = "flatbuffers",
+ srcs = [
+ "internal/domains/flatbuffers_domain_impl.cc",
+ "internal/domains/flatbuffers_domain_impl.h",
+ ],
+ hdrs = ["flatbuffers.h"],
+ deps = [
+ ":any",
+ ":domain_core",
+ ":logging",
+ ":meta",
+ ":serialization",
+ ":status",
+ ":type_support",
+ "@abseil-cpp//absl/algorithm:container",
+ "@abseil-cpp//absl/base:core_headers",
+ "@abseil-cpp//absl/base:nullability",
+ "@abseil-cpp//absl/container:flat_hash_map",
+ "@abseil-cpp//absl/container:flat_hash_set",
+ "@abseil-cpp//absl/random",
+ "@abseil-cpp//absl/random:bit_gen_ref",
+ "@abseil-cpp//absl/random:distributions",
+ "@abseil-cpp//absl/status",
+ "@abseil-cpp//absl/status:statusor",
+ "@abseil-cpp//absl/strings",
+ "@abseil-cpp//absl/strings:str_format",
+ "@abseil-cpp//absl/synchronization",
+ "@flatbuffers//:runtime_cc",
+ ],
+)
+
+cc_library(
name = "fixture_driver",
srcs = ["internal/fixture_driver.cc"],
hdrs = ["internal/fixture_driver.h"],
@@ -799,6 +833,28 @@
deps = [":test_protobuf"],
)
+# Runs flatc on the test schema and produces two C++ headers: the regular
+# generated API and a `_bfbs_generated.h` header.
+flatbuffer_library_public(
+ name = "test_flatbuffers_fbs",
+ srcs = ["internal/test_flatbuffers.fbs"],
+ outs = [
+ "test_flatbuffers_bfbs_generated.h",
+ "test_flatbuffers_generated.h",
+ ],
+ # --bfbs-gen-embed: embed the binary schema in the generated code (this is
+ # what yields the `_bfbs_generated.h` output).
+ # --gen-name-strings: generate type-name functions for the C++ types.
+ flatc_args = [
+ "--bfbs-gen-embed",
+ "--gen-name-strings",
+ ],
+ # "-c" selects C++ output.
+ language_flag = "-c",
+)
+
+# C++ library wrapping the headers generated by :test_flatbuffers_fbs so that
+# tests can depend on them.
+cc_library(
+ name = "test_flatbuffers_cc_fbs",
+ srcs = [":test_flatbuffers_fbs"],
+ hdrs = [":test_flatbuffers_fbs"],
+ # NOTE(review): presumably disabled because the generated headers are not
+ # expected to pass standalone header parsing -- confirm.
+ features = ["-parse_headers"],
+ deps = ["@flatbuffers//:runtime_cc"],
+)
+
cc_library(
name = "type_support",
srcs = ["internal/type_support.cc"],
diff --git a/fuzztest/flatbuffers.h b/fuzztest/flatbuffers.h
new file mode 100644
index 0000000..b70ed36
--- /dev/null
+++ b/fuzztest/flatbuffers.h
@@ -0,0 +1,19 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FUZZTEST_FUZZTEST_FLATBUFFERS_H_
+#define FUZZTEST_FUZZTEST_FLATBUFFERS_H_
+
+#include "./fuzztest/internal/domains/flatbuffers_domain_impl.h" // IWYU pragma: export
+#endif // FUZZTEST_FUZZTEST_FLATBUFFERS_H_
diff --git a/fuzztest/internal/domains/flatbuffers_domain_impl.cc b/fuzztest/internal/domains/flatbuffers_domain_impl.cc
new file mode 100644
index 0000000..8a7950f
--- /dev/null
+++ b/fuzztest/internal/domains/flatbuffers_domain_impl.cc
@@ -0,0 +1,299 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "./fuzztest/internal/domains/flatbuffers_domain_impl.h"
+
+#include <cstdint>
+#include <optional>
+#include <type_traits>
+#include <utility>
+
+#include "absl/random/bit_gen_ref.h"
+#include "absl/random/distributions.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "absl/synchronization/mutex.h"
+#include "flatbuffers/base.h"
+#include "flatbuffers/flatbuffer_builder.h"
+#include "flatbuffers/reflection_generated.h"
+#include "flatbuffers/table.h"
+#include "./fuzztest/domain_core.h"
+#include "./fuzztest/internal/any.h"
+#include "./fuzztest/internal/domains/domain_base.h"
+#include "./fuzztest/internal/domains/domain_type_erasure.h"
+#include "./fuzztest/internal/logging.h"
+#include "./fuzztest/internal/meta.h"
+#include "./fuzztest/internal/serialization.h"
+
+namespace fuzztest {
+namespace internal {
+
+// Produces an initial corpus value for the union.
+//
+// A seed is returned when MaybeGetRandomSeed() yields one. Otherwise a union
+// member is picked uniformly from the union definition. Its type is stored in
+// `first`. Unless the member is NONE (value 0) or has no table domain, a
+// freshly initialized table corpus is stored in `second`. On every early exit
+// the remaining parts of the result keep their default-constructed state.
+FlatbuffersUnionDomainImpl::corpus_type FlatbuffersUnionDomainImpl::Init(
+    absl::BitGenRef prng) {
+  if (auto seed = this->MaybeGetRandomSeed(prng)) {
+    return *seed;
+  }
+
+  corpus_type result;
+  const auto* members = union_def_->values();
+  const auto member_index = absl::Uniform(prng, 0ul, members->size());
+  const auto* member = members->Get(member_index);
+  if (member == nullptr) {
+    return result;
+  }
+
+  auto type_corpus = type_domain_.FromValue(member->value());
+  if (!type_corpus.has_value()) {
+    return result;
+  }
+  result.first = *type_corpus;
+
+  if (member->value() != 0 /* NONE */) {
+    if (auto* table_domain = GetTableDomain(*member); table_domain != nullptr) {
+      auto table_corpus = table_domain->Init(prng);
+      result.second =
+          GenericDomainCorpusType(std::in_place_type<decltype(table_corpus)>,
+                                  std::move(table_corpus));
+    }
+  }
+  return result;
+}
+
+// Mutates the corpus value in place.
+//
+// One "slot" is drawn uniformly from [0, CountNumberOfFields(val)):
+//  * slot 0 mutates the union type and replaces the payload: it is reset to a
+//    null placeholder and, when the new type is not NONE and has a table
+//    domain, re-initialized from that domain;
+//  * any other slot is forwarded to the current member's table domain as
+//    field index `slot - 1`.
+void FlatbuffersUnionDomainImpl::Mutate(
+    corpus_type& val, absl::BitGenRef prng,
+    const domain_implementor::MutationMetadata& metadata, bool only_shrink) {
+  auto total_weight = CountNumberOfFields(val);
+  auto selected_weight = absl::Uniform(prng, 0ul, total_weight);
+  if (selected_weight == 0) {
+    type_domain_.Mutate(val.first, prng, metadata, only_shrink);
+    val.second = GenericDomainCorpusType(std::in_place_type<void*>, nullptr);
+    auto type_value = type_domain_.GetValue(val.first);
+    if (type_value == 0) {
+      return;
+    }
+    auto type_enumval = union_def_->values()->LookupByKey(type_value);
+    if (type_enumval == nullptr) {
+      return;
+    }
+    auto domain = GetTableDomain(*type_enumval);
+    if (domain == nullptr) {
+      return;
+    }
+    auto inner_val = domain->Init(prng);
+    val.second = GenericDomainCorpusType(
+        std::in_place_type<decltype(inner_val)>, std::move(inner_val));
+  } else {
+    auto type_value = type_domain_.GetValue(val.first);
+    auto type_enumval = union_def_->values()->LookupByKey(type_value);
+    if (type_enumval == nullptr) {
+      return;
+    }
+    auto domain = GetTableDomain(*type_enumval);
+    if (domain == nullptr) {
+      return;
+    }
+    // Bind by reference: the mutation must land in the corpus stored inside
+    // `val`. Taking a copy here would silently discard it.
+    auto& inner_val = val.second.template GetAs<
+        corpus_type_t<std::remove_pointer_t<decltype(domain)>>>();
+    domain->MutateSelectedField(inner_val, prng, metadata, only_shrink,
+                                selected_weight - 1);
+  }
+}
+
+// Returns the number of mutable slots in `val`: one for the union type, plus
+// the field count reported by the member's table domain when the type is not
+// NONE, is known to the union definition, and has a table domain.
+uint64_t FlatbuffersUnionDomainImpl::CountNumberOfFields(corpus_type& val) {
+  uint64_t count = 1;
+  auto type_value = type_domain_.GetValue(val.first);
+  if (type_value == 0 /* NONE */) {
+    return count;
+  }
+  auto type_enumval = union_def_->values()->LookupByKey(type_value);
+  if (type_enumval == nullptr) {
+    return count;
+  }
+  auto domain = GetTableDomain(*type_enumval);
+  if (domain != nullptr) {
+    // Reference, not copy: counting must not duplicate the whole payload.
+    auto& inner_val = val.second.template GetAs<
+        corpus_type_t<std::remove_pointer_t<decltype(domain)>>>();
+    count += domain->CountNumberOfFields(inner_val);
+  }
+  return count;
+}
+
+// Validates the payload of a union corpus value with the member's table
+// domain. Values whose type is NONE (0), is not found in the union
+// definition, or has no table domain are reported as OK without inspecting
+// the payload.
+absl::Status FlatbuffersUnionDomainImpl::ValidateCorpusValue(
+    const corpus_type& corpus_value) const {
+  auto type_value = type_domain_.GetValue(corpus_value.first);
+  if (type_value == 0) {
+    return absl::OkStatus();
+  }
+  auto type_enumval = union_def_->values()->LookupByKey(type_value);
+  if (type_enumval == nullptr) {
+    return absl::OkStatus();
+  }
+  auto domain = GetTableDomain(*type_enumval);
+  if (domain == nullptr) {
+    return absl::OkStatus();
+  }
+  // Validate in place; no need to copy the payload.
+  const auto& inner_corpus_value = corpus_value.second.template GetAs<
+      corpus_type_t<std::remove_pointer_t<decltype(domain)>>>();
+  return domain->ValidateCorpusValue(inner_corpus_value);
+}
+
+// Converts a user value (union type, payload pointer) into a corpus value.
+//
+// Returns std::nullopt when the type cannot be represented by the type
+// domain, is not part of the union definition, or when the member's table
+// domain cannot convert the payload. Members without a table domain yield a
+// corpus value that carries the type only.
+std::optional<FlatbuffersUnionDomainImpl::corpus_type>
+FlatbuffersUnionDomainImpl::FromValue(const value_type& value) const {
+  auto type_value = type_domain_.FromValue(value.first);
+  if (!type_value.has_value()) {
+    return std::nullopt;
+  }
+  auto type_enumval = union_def_->values()->LookupByKey(value.first);
+  if (type_enumval == nullptr) {
+    return std::nullopt;
+  }
+
+  corpus_type out;
+  out.first = *type_value;
+
+  auto domain = GetTableDomain(*type_enumval);
+  if (domain == nullptr) {
+    return out;
+  }
+  auto inner_value =
+      domain->FromValue(static_cast<const flatbuffers::Table*>(value.second));
+  if (!inner_value.has_value()) {
+    // Never hand out a typed corpus value whose payload is missing: the other
+    // methods read the payload whenever a table domain exists for the type.
+    return std::nullopt;
+  }
+  out.second = GenericDomainCorpusType(
+      std::in_place_type<typename decltype(inner_value)::value_type>,
+      std::move(*inner_value));
+  return out;
+}
+
+// Converts the IRObject to a corpus value.
+//
+// Accepted layouts mirror what SerializeCorpus() emits:
+//  * [type, payload] when the member has a table domain;
+//  * [type] when it has none (NONE or an unsupported member type).
+// Returns std::nullopt for any other shape, for an unknown type, or when the
+// payload cannot be parsed by the member's table domain.
+std::optional<FlatbuffersUnionDomainImpl::corpus_type>
+FlatbuffersUnionDomainImpl::ParseCorpus(const IRObject& obj) const {
+  auto subs = obj.Subs();
+  if (!subs) {
+    return std::nullopt;
+  }
+  if (subs->size() != 1 && subs->size() != 2) {
+    return std::nullopt;
+  }
+
+  auto type_corpus = type_domain_.ParseCorpus((*subs)[0]);
+  if (!type_corpus.has_value()) {
+    return std::nullopt;
+  }
+  corpus_type out;
+  out.first = *type_corpus;
+  auto type_value = type_domain_.GetValue(out.first);
+  auto type_enumval = union_def_->values()->LookupByKey(type_value);
+  if (type_enumval == nullptr) {
+    return std::nullopt;
+  }
+
+  auto domain = GetTableDomain(*type_enumval);
+  if (domain == nullptr) {
+    // No payload is serialized for such members, so none may be present.
+    if (subs->size() != 1) {
+      return std::nullopt;
+    }
+    return out;
+  }
+  if (subs->size() != 2) {
+    return std::nullopt;
+  }
+
+  auto inner_corpus = domain->ParseCorpus((*subs)[1]);
+  if (!inner_corpus.has_value()) {
+    // A typed value without its payload would be unusable downstream.
+    return std::nullopt;
+  }
+  out.second = GenericDomainCorpusType(
+      std::in_place_type<typename decltype(inner_corpus)::value_type>,
+      std::move(*inner_corpus));
+  return out;
+}
+
+// Serializes a union corpus value into an IRObject with sub-objects.
+//
+// The first sub-object is always the serialized type. A second one, the
+// serialized payload, is appended only when the type is part of the union
+// definition and has a table domain.
+IRObject FlatbuffersUnionDomainImpl::SerializeCorpus(
+    const corpus_type& value) const {
+  IRObject result;
+  auto& subs = result.MutableSubs();
+  subs.reserve(2);
+  subs.push_back(type_domain_.SerializeCorpus(value.first));
+
+  const auto* member =
+      union_def_->values()->LookupByKey(type_domain_.GetValue(value.first));
+  auto* table_domain = member != nullptr ? GetTableDomain(*member) : nullptr;
+  if (table_domain != nullptr) {
+    using TableCorpus =
+        corpus_type_t<std::remove_pointer_t<decltype(table_domain)>>;
+    subs.push_back(table_domain->SerializeCorpus(
+        value.second.template GetAs<TableCorpus>()));
+  }
+  return result;
+}
+
+// Writes the union payload into `builder` through the member's table domain
+// and returns the resulting offset. Returns std::nullopt when the type is not
+// part of the union definition or has no table domain.
+std::optional<flatbuffers::uoffset_t> FlatbuffersUnionDomainImpl::BuildValue(
+    const corpus_type& value, flatbuffers::FlatBufferBuilder& builder) const {
+  const auto* member =
+      union_def_->values()->LookupByKey(type_domain_.GetValue(value.first));
+  if (member != nullptr) {
+    if (auto* table_domain = GetTableDomain(*member); table_domain != nullptr) {
+      using TableCorpus =
+          corpus_type_t<std::remove_pointer_t<decltype(table_domain)>>;
+      return table_domain->BuildTable(
+          value.second.template GetAs<TableCorpus>(), builder);
+    }
+  }
+  return std::nullopt;
+}
+
+// Returns the table domain for a union member, creating and caching it in
+// `domains_` on first use (keyed by the member's enum value). `mutex_` is
+// held for the whole lookup/insert.
+//
+// Returns nullptr for members whose union type is BaseType::None and for
+// struct members, which are not supported yet. Any other base type than
+// BaseType::Obj trips an internal check.
+//
+// NOTE(review): the result points into `domains_` and is used by callers
+// after the lock is released; whether it stays valid across later insertions
+// depends on the container declared in the header -- confirm it guarantees
+// pointer stability.
+FlatbuffersTableUntypedDomainImpl* FlatbuffersUnionDomainImpl::GetTableDomain(
+    const reflection::EnumVal& enum_value) const {
+  absl::MutexLock l(&mutex_);
+  if (auto cached = domains_.find(enum_value.value());
+      cached != domains_.end()) {
+    return &cached->second;
+  }
+
+  const auto base_type = enum_value.union_type()->base_type();
+  if (base_type == reflection::BaseType::None) {
+    return nullptr;
+  }
+  FUZZTEST_INTERNAL_CHECK(base_type == reflection::BaseType::Obj,
+                          "EnumVal union type is not a BaseType::Obj");
+
+  auto object = schema_->objects()->Get(enum_value.union_type()->index());
+  if (object->is_struct()) {
+    // TODO(b/405939014): Support structs.
+    return nullptr;
+  }
+
+  auto inserted =
+      domains_
+          .emplace(enum_value.value(),
+                   FlatbuffersTableUntypedDomainImpl{schema_, object})
+          .first;
+  return &inserted->second;
+}
+
+// Prints a union corpus value as `<MemberName>(payload)`.
+//
+// Nothing is printed when the type is not part of the union definition.
+// Members without a table domain are printed as
+// `<MemberName>(UNSUPPORTED_UNION_TYPE)`. The closing parenthesis is emitted
+// on every path that printed the opening one, so the output stays balanced.
+void FlatbuffersUnionDomainImpl::Printer::PrintCorpusValue(
+    const corpus_type& value, domain_implementor::RawSink out,
+    domain_implementor::PrintMode mode) const {
+  auto type_value = self.type_domain_.GetValue(value.first);
+  auto type_enumval = self.union_def_->values()->LookupByKey(type_value);
+  if (type_enumval == nullptr) {
+    return;
+  }
+  absl::Format(out, "<%s>(", type_enumval->name()->str());
+  auto domain = self.GetTableDomain(*type_enumval);
+  if (domain == nullptr) {
+    absl::Format(out, "UNSUPPORTED_UNION_TYPE");
+  } else {
+    // Print in place; no need to copy the payload.
+    const auto& inner_corpus_value = value.second.template GetAs<
+        corpus_type_t<std::remove_pointer_t<decltype(domain)>>>();
+    domain_implementor::PrintValue(*domain, inner_corpus_value, out, mode);
+  }
+  absl::Format(out, ")");
+}
+} // namespace internal
+} // namespace fuzztest
diff --git a/fuzztest/internal/domains/flatbuffers_domain_impl.h b/fuzztest/internal/domains/flatbuffers_domain_impl.h
new file mode 100644
index 0000000..5dffa5d
--- /dev/null
+++ b/fuzztest/internal/domains/flatbuffers_domain_impl.h
@@ -0,0 +1,1535 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_FLATBUFFERS_DOMAIN_IMPL_H_
+#define FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_FLATBUFFERS_DOMAIN_IMPL_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <list>
+#include <optional>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/base/nullability.h"
+#include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/random/bit_gen_ref.h"
+#include "absl/random/distributions.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/synchronization/mutex.h"
+#include "flatbuffers/base.h"
+#include "flatbuffers/flatbuffer_builder.h"
+#include "flatbuffers/reflection.h"
+#include "flatbuffers/reflection_generated.h"
+#include "flatbuffers/string.h"
+#include "flatbuffers/table.h"
+#include "flatbuffers/vector.h"
+#include "flatbuffers/verifier.h"
+#include "./fuzztest/domain_core.h"
+#include "./fuzztest/internal/any.h"
+#include "./fuzztest/internal/domains/arbitrary_impl.h"
+#include "./fuzztest/internal/domains/domain_base.h"
+#include "./fuzztest/internal/domains/domain_type_erasure.h"
+#include "./fuzztest/internal/domains/element_of_impl.h"
+#include "./fuzztest/internal/logging.h"
+#include "./fuzztest/internal/meta.h"
+#include "./fuzztest/internal/serialization.h"
+#include "./fuzztest/internal/status.h"
+
+namespace fuzztest::internal {
+
+// Tag type standing for a flatbuffers enum whose values are stored as the
+// integer type `Underlying`. Only usable with non-bool integral types.
+template <typename Underlying,
+          typename = std::enable_if_t<std::conjunction_v<
+              std::is_integral<Underlying>,
+              std::negation<std::is_same<Underlying, bool>>>>>
+struct FlatbuffersEnumTag {
+  using type = Underlying;
+};
+
+// Type trait that is true exactly for instantiations of `FlatbuffersEnumTag`.
+template <typename T>
+struct is_flatbuffers_enum_tag : std::bool_constant<false> {};
+
+template <typename Underlying>
+struct is_flatbuffers_enum_tag<FlatbuffersEnumTag<Underlying>>
+    : std::bool_constant<true> {};
+
+// Shorthand for `is_flatbuffers_enum_tag<T>::value`.
+template <typename T>
+inline constexpr bool is_flatbuffers_enum_tag_v =
+    is_flatbuffers_enum_tag<T>::value;
+
+// Tag type standing for a flatbuffers vector field whose elements are
+// described by `T` (a C++ scalar type, `std::string`, or another tag type).
+template <typename T>
+struct FlatbuffersVectorTag {
+  using value_type = T;
+};
+
+// Type trait that is true exactly for instantiations of
+// `FlatbuffersVectorTag`.
+template <typename T>
+struct is_flatbuffers_vector_tag : std::bool_constant<false> {};
+
+template <typename T>
+struct is_flatbuffers_vector_tag<FlatbuffersVectorTag<T>>
+    : std::bool_constant<true> {};
+
+// Shorthand for `is_flatbuffers_vector_tag<T>::value`.
+template <typename T>
+inline constexpr bool is_flatbuffers_vector_tag_v =
+    is_flatbuffers_vector_tag<T>::value;
+
+// Tag types used as template arguments of `Visitor::Visit<...>` to identify
+// array, table, struct and union fields. They are only declared here and are
+// used purely as compile-time markers.
+struct FlatbuffersArrayTag;
+struct FlatbuffersTableTag;
+struct FlatbuffersStructTag;
+struct FlatbuffersUnionTag;
+
+// Visits a vector field whose elements are stored as the integer type `Int`.
+// A non-negative type index selects the enum element tag; otherwise the plain
+// integer element tag is used.
+template <typename Int, typename Visitor>
+void VisitFlatbufferIntegerVectorElementField(const reflection::Field* field,
+                                              Visitor& visitor) {
+  if (field->type()->index() >= 0) {
+    visitor.template Visit<FlatbuffersVectorTag<FlatbuffersEnumTag<Int>>>(
+        field);
+  } else {
+    visitor.template Visit<FlatbuffersVectorTag<Int>>(field);
+  }
+}
+
+// Dynamic to static dispatch visitor pattern for flatbuffers vector elements.
+//
+// Calls `visitor.Visit<FlatbuffersVectorTag<Element>>(field)` where `Element`
+// is chosen from the reflected element type of `field`. `UType` elements are
+// skipped because union types are visited together with their union values.
+// Any other element type (e.g. vector of vectors or vector of arrays) fails a
+// FUZZTEST_INTERNAL_CHECK.
+template <typename Visitor>
+auto VisitFlatbufferVectorElementField(const reflection::Schema* schema,
+                                       const reflection::Field* field,
+                                       Visitor visitor) {
+  switch (field->type()->element()) {
+    case reflection::BaseType::Bool:
+      visitor.template Visit<FlatbuffersVectorTag<bool>>(field);
+      break;
+    case reflection::BaseType::Byte:
+      VisitFlatbufferIntegerVectorElementField<int8_t>(field, visitor);
+      break;
+    case reflection::BaseType::Short:
+      VisitFlatbufferIntegerVectorElementField<int16_t>(field, visitor);
+      break;
+    case reflection::BaseType::Int:
+      VisitFlatbufferIntegerVectorElementField<int32_t>(field, visitor);
+      break;
+    case reflection::BaseType::Long:
+      VisitFlatbufferIntegerVectorElementField<int64_t>(field, visitor);
+      break;
+    case reflection::BaseType::UByte:
+      VisitFlatbufferIntegerVectorElementField<uint8_t>(field, visitor);
+      break;
+    case reflection::BaseType::UShort:
+      VisitFlatbufferIntegerVectorElementField<uint16_t>(field, visitor);
+      break;
+    case reflection::BaseType::UInt:
+      VisitFlatbufferIntegerVectorElementField<uint32_t>(field, visitor);
+      break;
+    case reflection::BaseType::ULong:
+      VisitFlatbufferIntegerVectorElementField<uint64_t>(field, visitor);
+      break;
+    case reflection::BaseType::Float:
+      visitor.template Visit<FlatbuffersVectorTag<float>>(field);
+      break;
+    case reflection::BaseType::Double:
+      visitor.template Visit<FlatbuffersVectorTag<double>>(field);
+      break;
+    case reflection::BaseType::String:
+      visitor.template Visit<FlatbuffersVectorTag<std::string>>(field);
+      break;
+    case reflection::BaseType::Obj: {
+      const auto* element_object =
+          schema->objects()->Get(field->type()->index());
+      if (element_object->is_struct()) {
+        visitor.template Visit<FlatbuffersVectorTag<FlatbuffersStructTag>>(
+            field);
+      } else {
+        visitor.template Visit<FlatbuffersVectorTag<FlatbuffersTableTag>>(
+            field);
+      }
+      break;
+    }
+    case reflection::BaseType::Union:
+      visitor.template Visit<FlatbuffersVectorTag<FlatbuffersUnionTag>>(field);
+      break;
+    case reflection::BaseType::UType:
+      // Noop: Union types are visited at the same time as their corresponding
+      // union values.
+      break;
+    default:  // Vector of vectors and vector of arrays are not supported.
+      FUZZTEST_INTERNAL_CHECK(false, "Unsupported vector base type");
+  }
+}
+
+// Dynamic to static dispatch visitor pattern.
+//
+// Calls `visitor.Visit<Tag>(field)` with the static tag matching the reflected
+// base type of `field`:
+//   - bool, floating point and string fields use their C++ type;
+//   - integer fields use their C++ integer type, or `FlatbuffersEnumTag<Int>`
+//     when the field's type index is non-negative;
+//   - vector fields are forwarded to `VisitFlatbufferVectorElementField`;
+//   - array, struct, table and union fields use the corresponding tag type.
+// `UType` fields are skipped. Any other base type fails a
+// FUZZTEST_INTERNAL_CHECK.
+template <typename Visitor>
+auto VisitFlatbufferField(const reflection::Schema* absl_nonnull schema,
+                          const reflection::Field* absl_nonnull field,
+                          Visitor visitor) {
+  auto field_index = field->type()->index();
+  switch (field->type()->base_type()) {
+    case reflection::BaseType::Bool:
+      visitor.template Visit<bool>(field);
+      break;
+    case reflection::BaseType::Byte:
+      if (field_index >= 0) {
+        visitor.template Visit<FlatbuffersEnumTag<int8_t>>(field);
+      } else {
+        visitor.template Visit<int8_t>(field);
+      }
+      break;
+    case reflection::BaseType::Short:
+      if (field_index >= 0) {
+        visitor.template Visit<FlatbuffersEnumTag<int16_t>>(field);
+      } else {
+        visitor.template Visit<int16_t>(field);
+      }
+      break;
+    case reflection::BaseType::Int:
+      if (field_index >= 0) {
+        visitor.template Visit<FlatbuffersEnumTag<int32_t>>(field);
+      } else {
+        visitor.template Visit<int32_t>(field);
+      }
+      break;
+    case reflection::BaseType::Long:
+      if (field_index >= 0) {
+        visitor.template Visit<FlatbuffersEnumTag<int64_t>>(field);
+      } else {
+        visitor.template Visit<int64_t>(field);
+      }
+      break;
+    case reflection::BaseType::UByte:
+      if (field_index >= 0) {
+        visitor.template Visit<FlatbuffersEnumTag<uint8_t>>(field);
+      } else {
+        visitor.template Visit<uint8_t>(field);
+      }
+      break;
+    case reflection::BaseType::UShort:
+      if (field_index >= 0) {
+        visitor.template Visit<FlatbuffersEnumTag<uint16_t>>(field);
+      } else {
+        visitor.template Visit<uint16_t>(field);
+      }
+      break;
+    case reflection::BaseType::UInt:
+      if (field_index >= 0) {
+        visitor.template Visit<FlatbuffersEnumTag<uint32_t>>(field);
+      } else {
+        visitor.template Visit<uint32_t>(field);
+      }
+      break;
+    case reflection::BaseType::ULong:
+      if (field_index >= 0) {
+        visitor.template Visit<FlatbuffersEnumTag<uint64_t>>(field);
+      } else {
+        visitor.template Visit<uint64_t>(field);
+      }
+      break;
+    case reflection::BaseType::Float:
+      visitor.template Visit<float>(field);
+      break;
+    case reflection::BaseType::Double:
+      visitor.template Visit<double>(field);
+      break;
+    case reflection::BaseType::String:
+      visitor.template Visit<std::string>(field);
+      break;
+    // The vector cases need no braces of their own; keeping them brace-free
+    // leaves every following case label directly inside the switch.
+    case reflection::BaseType::Vector:
+    case reflection::BaseType::Vector64:
+      VisitFlatbufferVectorElementField<Visitor>(schema, field, visitor);
+      break;
+    case reflection::BaseType::Array:
+      visitor.template Visit<FlatbuffersArrayTag>(field);
+      break;
+    case reflection::BaseType::Obj: {
+      auto sub_object = schema->objects()->Get(field->type()->index());
+      if (sub_object->is_struct()) {
+        visitor.template Visit<FlatbuffersStructTag>(field);
+      } else {
+        visitor.template Visit<FlatbuffersTableTag>(field);
+      }
+      break;
+    }
+    case reflection::BaseType::Union:
+      visitor.template Visit<FlatbuffersUnionTag>(field);
+      break;
+    case reflection::BaseType::UType:
+      // Noop: Union types are visited at the same time as their corresponding
+      // union values.
+      break;
+    default:
+      FUZZTEST_INTERNAL_CHECK(false, "Unsupported base type");
+  }
+}
+
+// Flatbuffers enum domain implementation.
+//
+// An `ElementOfImpl` over the values listed in the reflected enum definition.
+// In human-readable mode its printer shows the enumerator name instead of the
+// raw corpus value.
+template <typename Underlying>
+class FlatbuffersEnumDomainImpl : public ElementOfImpl<Underlying> {
+ public:
+  using typename ElementOfImpl<Underlying>::DomainBase::corpus_type;
+  using typename ElementOfImpl<Underlying>::DomainBase::value_type;
+
+  // `enum_def` is stored as a raw pointer and is dereferenced when printing.
+  explicit FlatbuffersEnumDomainImpl(const reflection::Enum* enum_def)
+      : ElementOfImpl<Underlying>(GetEnumValues(enum_def)),
+        enum_def_(enum_def) {}
+
+  auto GetPrinter() const { return Printer{*this}; }
+
+ private:
+  const reflection::Enum* enum_def_;
+
+  // Collects the numeric value of every enumerator of `enum_def`, in the
+  // order in which the definition lists them.
+  static std::vector<value_type> GetEnumValues(
+      const reflection::Enum* enum_def) {
+    std::vector<value_type> values;
+    values.reserve(enum_def->values()->size());
+    for (const auto* value : *enum_def->values()) {
+      values.push_back(value->value());
+    }
+    return values;
+  }
+
+  struct Printer {
+    const FlatbuffersEnumDomainImpl& self;
+    void PrintCorpusValue(const corpus_type& value,
+                          domain_implementor::RawSink out,
+                          domain_implementor::PrintMode mode) const {
+      if (mode == domain_implementor::PrintMode::kHumanReadable) {
+        auto user_value = self.GetValue(value);
+        const auto* enum_value =
+            self.enum_def_->values()->LookupByKey(user_value);
+        if (enum_value != nullptr) {
+          absl::Format(out, "%s", enum_value->name()->str());
+        } else {
+          // The lookup found no enumerator for this value; print the number
+          // rather than dereferencing a null pointer.
+          absl::Format(out, "%d", user_value);
+        }
+      } else {
+        absl::Format(out, "%d", value);
+      }
+    }
+  };
+};
+
+// Enum domain over the type (discriminant) of a union. Its underlying integer
+// type is whatever `reflection::EnumVal::value()` returns.
+using FlatbuffersUnionTypeDomainImpl = FlatbuffersEnumDomainImpl<
+    decltype(std::declval<reflection::EnumVal&>().value())>;
+
+// Forward declaration; the class is defined later in this header.
+class FlatbuffersTableUntypedDomainImpl;
+
+// Flatbuffers union domain implementation.
+//
+// The corpus value is a pair of the corpus value of the union type (an enum
+// domain over the union definition) and the type-erased corpus value of the
+// selected member. The user value type is a pair of the union type value and
+// an untyped pointer.
+//
+// Table domains for the union members are created on demand by
+// `GetTableDomain` and cached in `domains_`, which is guarded by `mutex_`.
+//
+// NOTE(review): `GetTableDomain` returns pointers to elements of `domains_`,
+// an `absl::flat_hash_map`. Confirm that no insertion can happen while such a
+// pointer is still in use, since that container may relocate its elements
+// when it grows.
+class FlatbuffersUnionDomainImpl
+    : public domain_implementor::DomainBase<
+          FlatbuffersUnionDomainImpl,
+          std::pair<typename FlatbuffersUnionTypeDomainImpl::value_type,
+                    const void*>,
+          std::pair<typename FlatbuffersUnionTypeDomainImpl::corpus_type,
+                    GenericDomainCorpusType>> {
+ public:
+  using typename FlatbuffersUnionDomainImpl::DomainBase::corpus_type;
+  using typename FlatbuffersUnionDomainImpl::DomainBase::value_type;
+
+  FlatbuffersUnionDomainImpl(const reflection::Schema* schema,
+                             const reflection::Enum* union_def)
+      : schema_(schema), union_def_(union_def), type_domain_(union_def) {}
+
+  // The copy/move operations lock the source's mutex while reading its cache.
+  FlatbuffersUnionDomainImpl(const FlatbuffersUnionDomainImpl& other)
+      : schema_(other.schema_),
+        union_def_(other.union_def_),
+        type_domain_(other.type_domain_) {
+    absl::MutexLock l(&other.mutex_);
+    domains_ = other.domains_;
+  }
+
+  FlatbuffersUnionDomainImpl(FlatbuffersUnionDomainImpl&& other)
+      : schema_(other.schema_),
+        union_def_(other.union_def_),
+        type_domain_(std::move(other.type_domain_)) {
+    absl::MutexLock l(&other.mutex_);
+    domains_ = std::move(other.domains_);
+  }
+
+  FlatbuffersUnionDomainImpl& operator=(
+      const FlatbuffersUnionDomainImpl& other) {
+    // Self-assignment leaves the object untouched.
+    if (this == &other) {
+      return *this;
+    }
+    schema_ = other.schema_;
+    union_def_ = other.union_def_;
+    type_domain_ = other.type_domain_;
+    absl::MutexLock l(&other.mutex_);
+    domains_ = other.domains_;
+    return *this;
+  }
+
+  FlatbuffersUnionDomainImpl& operator=(FlatbuffersUnionDomainImpl&& other) {
+    // Without this guard a self-move would move `type_domain_` and `domains_`
+    // from themselves.
+    if (this == &other) {
+      return *this;
+    }
+    schema_ = other.schema_;
+    union_def_ = other.union_def_;
+    type_domain_ = std::move(other.type_domain_);
+    absl::MutexLock l(&other.mutex_);
+    domains_ = std::move(other.domains_);
+    return *this;
+  }
+
+  // Initializes the corpus value.
+  corpus_type Init(absl::BitGenRef prng);
+
+  // Mutates the corpus value.
+  void Mutate(corpus_type& val, absl::BitGenRef prng,
+              const domain_implementor::MutationMetadata& metadata,
+              bool only_shrink);
+
+  uint64_t CountNumberOfFields(corpus_type& val);
+
+  auto GetPrinter() const { return Printer{*this}; }
+
+  absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const;
+
+  // UNSUPPORTED: Flatbuffers unions user values are not supported.
+  value_type GetValue(const corpus_type& value) const {
+    FUZZTEST_INTERNAL_CHECK(false, "GetValue is not supported for unions.");
+  }
+
+  // Returns the user value of the union type stored in `value`.
+  auto GetType(const corpus_type& value) const {
+    return type_domain_.GetValue(value.first);
+  }
+
+  std::optional<flatbuffers::uoffset_t> BuildValue(
+      const corpus_type& value, flatbuffers::FlatBufferBuilder& builder) const;
+
+  std::optional<corpus_type> FromValue(const value_type& value) const;
+
+  // Converts the IRObject to a corpus value.
+  std::optional<corpus_type> ParseCorpus(const IRObject& obj) const;
+
+  // Converts the corpus value to an IRObject.
+  IRObject SerializeCorpus(const corpus_type& value) const;
+
+ private:
+  const reflection::Schema* schema_;
+  const reflection::Enum* union_def_;
+  FlatbuffersEnumDomainImpl<typename value_type::first_type> type_domain_;
+  mutable absl::Mutex mutex_;
+  // Cache of member table domains, keyed by the union type value.
+  mutable absl::flat_hash_map<typename value_type::first_type,
+                              FlatbuffersTableUntypedDomainImpl>
+      domains_ ABSL_GUARDED_BY(mutex_);
+
+  FlatbuffersTableUntypedDomainImpl* GetTableDomain(
+      const reflection::EnumVal& enum_value) const;
+
+  struct Printer {
+    const FlatbuffersUnionDomainImpl& self;
+
+    void PrintCorpusValue(const corpus_type& value,
+                          domain_implementor::RawSink out,
+                          domain_implementor::PrintMode mode) const;
+  };
+};
+
+// Domain implementation for flatbuffers untyped tables.
+// The corpus type is a map from field ids to field corpus values; each field
+// value is stored type-erased as a `GenericDomainCorpusType`. The user value
+// type is a non-null pointer to a `flatbuffers::Table`.
+class FlatbuffersTableUntypedDomainImpl
+ : public domain_implementor::DomainBase<
+ FlatbuffersTableUntypedDomainImpl,
+ const flatbuffers::Table* absl_nonnull,
+ absl::flat_hash_map<
+ decltype(static_cast<reflection::Field*>(nullptr)->id()),
+ GenericDomainCorpusType>> {
+ public:
+ using typename FlatbuffersTableUntypedDomainImpl::DomainBase::corpus_type;
+ using typename FlatbuffersTableUntypedDomainImpl::DomainBase::value_type;
+
+  // Creates a domain for the table described by `table_object` within
+  // `schema`. Both pointers are stored as-is.
+  explicit FlatbuffersTableUntypedDomainImpl(
+      const reflection::Schema* schema, const reflection::Object* table_object)
+      : schema_(schema), table_object_(table_object) {}
+
+  // The copy/move operations lock the source's mutex while reading its cache
+  // of field domains.
+  FlatbuffersTableUntypedDomainImpl(
+      const FlatbuffersTableUntypedDomainImpl& other)
+      : schema_(other.schema_), table_object_(other.table_object_) {
+    absl::MutexLock l(&other.mutex_);
+    domains_ = other.domains_;
+  }
+
+  FlatbuffersTableUntypedDomainImpl& operator=(
+      const FlatbuffersTableUntypedDomainImpl& other) {
+    // Self-assignment leaves the object untouched.
+    if (this == &other) {
+      return *this;
+    }
+    schema_ = other.schema_;
+    table_object_ = other.table_object_;
+    absl::MutexLock l(&other.mutex_);
+    domains_ = other.domains_;
+    return *this;
+  }
+
+  FlatbuffersTableUntypedDomainImpl(FlatbuffersTableUntypedDomainImpl&& other)
+      : schema_(other.schema_), table_object_(other.table_object_) {
+    absl::MutexLock l(&other.mutex_);
+    domains_ = std::move(other.domains_);
+  }
+
+  FlatbuffersTableUntypedDomainImpl& operator=(
+      FlatbuffersTableUntypedDomainImpl&& other) {
+    // Without this guard a self-move would move `domains_` from itself.
+    if (this == &other) {
+      return *this;
+    }
+    schema_ = other.schema_;
+    table_object_ = other.table_object_;
+    absl::MutexLock l(&other.mutex_);
+    domains_ = std::move(other.domains_);
+    return *this;
+  }
+
+  // Initializes the corpus value.
+  // Returns the seed when `MaybeGetRandomSeed` provides one; otherwise visits
+  // every field of the table with an `InitializeVisitor` that fills a fresh
+  // corpus value.
+  corpus_type Init(absl::BitGenRef prng) {
+    auto seed = this->MaybeGetRandomSeed(prng);
+    if (seed) {
+      return *seed;
+    }
+    corpus_type corpus;
+    const auto& fields = *table_object_->fields();
+    for (const auto* field : fields) {
+      VisitFlatbufferField(schema_, field,
+                           InitializeVisitor{*this, prng, corpus});
+    }
+    return corpus;
+  }
+
+  // Mutates the corpus value.
+  // Picks a position uniformly at random in `[0, CountNumberOfFields(val))`
+  // and forwards it to `MutateSelectedField`. Does nothing when the count is
+  // zero.
+  void Mutate(corpus_type& val, absl::BitGenRef prng,
+              const domain_implementor::MutationMetadata& metadata,
+              bool only_shrink) {
+    const uint64_t total_weight = CountNumberOfFields(val);
+    if (total_weight == 0) {
+      // Nothing can be selected; this also keeps `total_weight - 1` below
+      // from wrapping around.
+      return;
+    }
+    const uint64_t selected_weight = absl::Uniform<uint64_t>(
+        absl::IntervalClosedClosed, prng, 0, total_weight - 1);
+
+    MutateSelectedField(val, prng, metadata, only_shrink, selected_weight);
+  }
+
+  // Returns the domain for the given vector field.
+  // The result is an optional-of-vector domain over the element domain chosen
+  // by `Element`; null is excluded when the field is not optional. Vectors of
+  // structs currently get a placeholder `std::optional<bool>` domain.
+  template <typename Element>
+  auto GetDomainForVectorField(const reflection::Field* field) const {
+    // Wraps an element domain into the optional vector domain shared by all
+    // supported element kinds.
+    [[maybe_unused]] const auto make_optional_vector =
+        [field](auto element_domain) {
+          auto optional_vector = OptionalOf(
+              VectorOf(std::move(element_domain))
+                  .WithMaxSize(
+                      std::numeric_limits<flatbuffers::uoffset_t>::max()));
+          if (!field->optional()) {
+            optional_vector.SetWithoutNull();
+          }
+          return optional_vector;
+        };
+
+    if constexpr (is_flatbuffers_enum_tag_v<Element>) {
+      const auto* enum_def = schema_->enums()->Get(field->type()->index());
+      auto inner = make_optional_vector(
+          FlatbuffersEnumDomainImpl<typename Element::type>(enum_def));
+      return Domain<value_type_t<decltype(inner)>>{inner};
+    } else if constexpr (std::is_same_v<Element, FlatbuffersTableTag>) {
+      const auto* table_def = schema_->objects()->Get(field->type()->index());
+      auto inner = make_optional_vector(
+          FlatbuffersTableUntypedDomainImpl{schema_, table_def});
+      return Domain<std::optional<std::vector<const flatbuffers::Table*>>>{
+          inner};
+    } else if constexpr (std::is_same_v<Element, FlatbuffersStructTag>) {
+      // TODO(b/399123660): implement this.
+      return Domain<std::optional<bool>>(OptionalOf(ArbitraryImpl<bool>()));
+    } else if constexpr (std::is_same_v<Element, FlatbuffersUnionTag>) {
+      const auto* union_def = schema_->enums()->Get(field->type()->index());
+      auto inner =
+          make_optional_vector(FlatbuffersUnionDomainImpl{schema_, union_def});
+      return Domain<value_type_t<decltype(inner)>>{inner};
+    } else {
+      auto inner = make_optional_vector(ArbitraryImpl<Element>());
+      return Domain<std::optional<std::vector<Element>>>{inner};
+    }
+  }
+
+  // Returns the domain for the given field.
+  // Every branch yields a type-erased `Domain` over an optional value. Arrays
+  // and structs currently get a placeholder `std::optional<bool>` domain, and
+  // vector tags are delegated to `GetDomainForVectorField`.
+  // NOTE(review): the union branch is the only non-placeholder branch that
+  // does not exclude null for non-optional fields; confirm this is intended.
+  template <typename T>
+  auto GetDomainForField(const reflection::Field* field) const {
+    // Excludes the null value from `optional_domain` unless the field is
+    // optional.
+    [[maybe_unused]] const auto exclude_null_unless_optional =
+        [field](auto& optional_domain) {
+          if (!field->optional()) {
+            optional_domain.SetWithoutNull();
+          }
+        };
+
+    if constexpr (std::is_same_v<T, FlatbuffersArrayTag>) {
+      // TODO(b/399123660): Implement this.
+      return Domain<std::optional<bool>>(OptionalOf(ArbitraryImpl<bool>()));
+    } else if constexpr (is_flatbuffers_enum_tag_v<T>) {
+      const auto* enum_def = schema_->enums()->Get(field->type()->index());
+      auto enum_domain =
+          OptionalOf(FlatbuffersEnumDomainImpl<typename T::type>(enum_def));
+      exclude_null_unless_optional(enum_domain);
+      return Domain<value_type_t<decltype(enum_domain)>>{enum_domain};
+    } else if constexpr (std::is_same_v<T, FlatbuffersTableTag>) {
+      const auto* table_def = schema_->objects()->Get(field->type()->index());
+      auto table_domain =
+          OptionalOf(FlatbuffersTableUntypedDomainImpl{schema_, table_def});
+      exclude_null_unless_optional(table_domain);
+      return Domain<std::optional<const flatbuffers::Table*>>{table_domain};
+    } else if constexpr (std::is_same_v<T, FlatbuffersStructTag>) {
+      // TODO(b/399123660): Implement this.
+      return Domain<std::optional<bool>>(OptionalOf(ArbitraryImpl<bool>()));
+    } else if constexpr (std::is_same_v<T, FlatbuffersUnionTag>) {
+      const auto* union_def = schema_->enums()->Get(field->type()->index());
+      auto union_domain =
+          OptionalOf(FlatbuffersUnionDomainImpl{schema_, union_def});
+      return Domain<value_type_t<decltype(union_domain)>>{union_domain};
+    } else if constexpr (is_flatbuffers_vector_tag_v<T>) {
+      return GetDomainForVectorField<typename T::value_type>(field);
+    } else {
+      auto scalar_domain = OptionalOf(ArbitraryImpl<T>());
+      exclude_null_unless_optional(scalar_domain);
+      return Domain<std::optional<T>>{scalar_domain};
+    }
+  }
+
+  // Returns the domain for the given field.
+  // The domain is cached by field id: the first call for a field builds it
+  // with `GetDomainForField<T>`, later calls return the same instance.
+  template <typename T>
+  auto& GetSubDomain(const reflection::Field* field) const {
+    using DomainT = decltype(GetDomainForField<T>(field));
+    // Hold the lock for the whole lookup-or-insert so that concurrent calls
+    // to `const` methods do not race on the cache.
+    absl::MutexLock lock(&mutex_);
+    const auto field_id = field->id();
+    auto cached = domains_.find(field_id);
+    if (cached == domains_.end()) {
+      auto inserted = domains_.try_emplace(
+          field_id, std::in_place_type<DomainT>, GetDomainForField<T>(field));
+      cached = inserted.first;
+    }
+    return cached->second.template GetAs<DomainT>();
+  }
+
+  // Counts the number of fields that can be mutated.
+  // Returns the number of fields in the flattened tree for supported field
+  // types:
+  //   - scalar and string fields count 1;
+  //   - table fields count 1 plus the count reported by their sub-domain;
+  //   - vector fields count 1, plus 1 for scalar/string elements, or 1 plus
+  //     the sub-domain's count for table and union elements;
+  //   - union fields count 1 plus the count reported by their sub-domain.
+  // Struct fields and vectors of structs add nothing beyond the vector itself.
+  uint64_t CountNumberOfFields(corpus_type& val) {
+    uint64_t total_weight = 0;
+    for (const auto* field : *table_object_->fields()) {
+      reflection::BaseType base_type = field->type()->base_type();
+      if (IsScalarType(base_type) ||
+          base_type == reflection::BaseType::String) {
+        ++total_weight;
+      } else if (base_type == reflection::BaseType::Obj) {
+        auto sub_object = schema_->objects()->Get(field->type()->index());
+        // TODO(b/405939014): Support structs.
+        if (!sub_object->is_struct()) {
+          ++total_weight;
+          auto& sub_domain = GetSubDomain<FlatbuffersTableTag>(field);
+          total_weight += sub_domain.CountNumberOfFields(val[field->id()]);
+        }
+      } else if (base_type == reflection::BaseType::Vector ||
+                 base_type == reflection::BaseType::Vector64) {
+        ++total_weight;
+        auto elem_type = field->type()->element();
+        if (IsScalarType(elem_type) ||
+            elem_type == reflection::BaseType::String) {
+          ++total_weight;
+        } else if (elem_type == reflection::BaseType::Obj) {
+          auto sub_object = schema_->objects()->Get(field->type()->index());
+          if (!sub_object->is_struct()) {
+            ++total_weight;
+            // Bind by reference: `GetSubDomain` returns the cached instance
+            // and there is no reason to copy it just to count.
+            auto& sub_domain =
+                GetSubDomain<FlatbuffersVectorTag<FlatbuffersTableTag>>(field);
+            total_weight += sub_domain.CountNumberOfFields(val[field->id()]);
+          }
+        } else if (elem_type == reflection::BaseType::Union) {
+          ++total_weight;
+          auto& sub_domain =
+              GetSubDomain<FlatbuffersVectorTag<FlatbuffersUnionTag>>(field);
+          total_weight += sub_domain.CountNumberOfFields(val[field->id()]);
+        }
+      } else if (base_type == reflection::BaseType::Union) {
+        ++total_weight;
+        auto& sub_domain = GetSubDomain<FlatbuffersUnionTag>(field);
+        total_weight += sub_domain.CountNumberOfFields(val[field->id()]);
+      }
+    }
+    return total_weight;
+  }
+
+  // Mutates the selected field.
+  // The selected field index is based on the flattened tree.
+  //
+  // Walks the table's fields in order, advancing `field_counter` by one per
+  // field. The field at which the counter reaches `selected_field_index + 1`
+  // is mutated through a `MutateVisitor`. For table, vector-of-table,
+  // vector-of-union and union fields that are not selected themselves, the
+  // search continues in the field's sub-domain with the remaining index, and
+  // the value returned from there is added to the counter.
+  //
+  // Returns the value of `field_counter` when the walk stops.
+  //
+  // NOTE(review): the counter advances for every field, whereas
+  // `CountNumberOfFields` skips struct/array fields and counts two positions
+  // for vectors of scalars or strings; confirm that the two numbering schemes
+  // are meant to line up.
+  uint64_t MutateSelectedField(
+      corpus_type& val, absl::BitGenRef prng,
+      const domain_implementor::MutationMetadata& metadata, bool only_shrink,
+      uint64_t selected_field_index) {
+    uint64_t field_counter = 0;
+    for (const auto* field : *table_object_->fields()) {
+      ++field_counter;
+
+      // This field is the selected one: mutate it and stop.
+      if (field_counter == selected_field_index + 1) {
+        VisitFlatbufferField(
+            schema_, field,
+            MutateVisitor{*this, prng, metadata, only_shrink, val});
+        return field_counter;
+      }
+
+      auto base_type = field->type()->base_type();
+      // Nested table: continue the search inside it with the remaining index.
+      if (base_type == reflection::BaseType::Obj) {
+        auto sub_object = schema_->objects()->Get(field->type()->index());
+        if (!sub_object->is_struct()) {
+          field_counter +=
+              GetSubDomain<FlatbuffersTableTag>(field).MutateSelectedField(
+                  val[field->id()], prng, metadata, only_shrink,
+                  selected_field_index - field_counter);
+        }
+      }
+
+      // Vectors of tables or unions: same descent through the vector domain.
+      if (base_type == reflection::BaseType::Vector ||
+          base_type == reflection::BaseType::Vector64) {
+        auto elem_type = field->type()->element();
+        if (elem_type == reflection::BaseType::Obj) {
+          auto sub_object = schema_->objects()->Get(field->type()->index());
+          if (!sub_object->is_struct()) {
+            field_counter +=
+                GetSubDomain<FlatbuffersVectorTag<FlatbuffersTableTag>>(field)
+                    .MutateSelectedField(val[field->id()], prng, metadata,
+                                         only_shrink,
+                                         selected_field_index - field_counter);
+          }
+        } else if (elem_type == reflection::BaseType::Union) {
+          field_counter +=
+              GetSubDomain<FlatbuffersVectorTag<FlatbuffersUnionTag>>(field)
+                  .MutateSelectedField(val[field->id()], prng, metadata,
+                                       only_shrink,
+                                       selected_field_index - field_counter);
+        }
+      }
+
+      // Union: same descent through the union domain.
+      if (base_type == reflection::BaseType::Union) {
+        field_counter +=
+            GetSubDomain<FlatbuffersUnionTag>(field).MutateSelectedField(
+                val[field->id()], prng, metadata, only_shrink,
+                selected_field_index - field_counter);
+      }
+
+      // The counter has moved past the selected index, so stop walking.
+      if (field_counter > selected_field_index) {
+        return field_counter;
+      }
+    }
+    return field_counter;
+  }
+
+  // Returns a `Printer` that refers to this domain.
+  auto GetPrinter() const { return Printer{*this}; }
+
+  // Validates every entry of `corpus_value` whose id belongs to a field of
+  // this table, using a `ValidateVisitor`. Returns the first non-OK status,
+  // or OK when all known entries validate. Entries with unknown ids are
+  // ignored.
+  absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const {
+    for (const auto& entry : corpus_value) {
+      absl::Nullable<const reflection::Field*> field =
+          GetFieldById(entry.first);
+      if (field != nullptr) {
+        absl::Status status;
+        VisitFlatbufferField(schema_, field,
+                             ValidateVisitor{*this, entry.second, status});
+        if (!status.ok()) {
+          return status;
+        }
+      }
+    }
+    return absl::OkStatus();
+  }
+
+  // Always returns nullptr: the untyped domain does not support GetValue,
+  // because building a nested table would require the corpus value of the
+  // top level table.
+  value_type GetValue(const corpus_type& value) const {
+    return nullptr;
+  }
+
+  // Converts the table pointer to a corpus value.
+  // Returns std::nullopt for a null table; otherwise visits every field of
+  // the table with a `FromValueVisitor` that fills the resulting corpus value.
+  std::optional<corpus_type> FromValue(const value_type& value) const {
+    if (value == nullptr) {
+      return std::nullopt;
+    }
+    corpus_type corpus;
+    const auto& fields = *table_object_->fields();
+    for (const auto* field : fields) {
+      VisitFlatbufferField(schema_, field,
+                           FromValueVisitor{*this, value, corpus});
+    }
+    return corpus;
+  }
+
+  // Converts the IRObject to a corpus value.
+  // `obj` must hold a list of two-element entries `[field id, field value]`.
+  // Returns std::nullopt when the shape is wrong, when an id does not belong
+  // to a field of this table, or when a field value fails to parse.
+  std::optional<corpus_type> ParseCorpus(const IRObject& obj) const {
+    auto entries = obj.Subs();
+    if (!entries) {
+      return std::nullopt;
+    }
+    corpus_type parsed;
+    parsed.reserve(entries->size());
+    for (const auto& entry : *entries) {
+      auto id_and_value = entry.Subs();
+      if (!id_and_value || id_and_value->size() != 2) {
+        return std::nullopt;
+      }
+      auto field_id =
+          (*id_and_value)[0].GetScalar<typename corpus_type::key_type>();
+      if (!field_id.has_value()) {
+        return std::nullopt;
+      }
+      absl::Nullable<const reflection::Field*> field =
+          GetFieldById(field_id.value());
+      if (field == nullptr) {
+        return std::nullopt;
+      }
+      std::optional<GenericDomainCorpusType> field_corpus;
+      VisitFlatbufferField(
+          schema_, field,
+          ParseVisitor{*this, (*id_and_value)[1], field_corpus});
+      if (!field_corpus) {
+        return std::nullopt;
+      }
+      parsed[field_id.value()] = *std::move(field_corpus);
+    }
+    return parsed;
+  }
+
+  // Converts the corpus value to an IRObject.
+  // Emits one two-element entry `[field id, serialized field value]` per
+  // corpus entry whose id belongs to a field of this table; entries with
+  // unknown ids are skipped.
+  IRObject SerializeCorpus(const corpus_type& value) const {
+    IRObject serialized;
+    auto& entries = serialized.MutableSubs();
+    entries.reserve(value.size());
+    for (const auto& corpus_entry : value) {
+      absl::Nullable<const reflection::Field*> field =
+          GetFieldById(corpus_entry.first);
+      if (field != nullptr) {
+        auto& id_and_value = entries.emplace_back().MutableSubs();
+        id_and_value.reserve(2);
+        id_and_value.emplace_back(field->id());
+        IRObject& value_slot = id_and_value.emplace_back();
+        VisitFlatbufferField(
+            schema_, field,
+            SerializeVisitor{*this, corpus_entry.second, value_slot});
+      }
+    }
+    return serialized;
+  }
+
+  // Writes the table described by `value` into `builder` and returns the
+  // result of `builder.EndTable`.
+  // Runs two passes over the corpus entries whose id belongs to a field of
+  // this table: the first visits them with `TableFieldBuilderVisitor` before
+  // the table is started, the second visits them with `TableBuilderVisitor`
+  // between `StartTable` and `EndTable`. Both passes share one offsets map.
+  uint32_t BuildTable(const corpus_type& value,
+                      flatbuffers::FlatBufferBuilder& builder) const {
+    absl::flat_hash_map<typename corpus_type::key_type, flatbuffers::uoffset_t>
+        field_offsets;
+
+    // Add all the fields to the builder.
+    for (const auto& entry : value) {
+      absl::Nullable<const reflection::Field*> field =
+          GetFieldById(entry.first);
+      if (field != nullptr) {
+        VisitFlatbufferField(schema_, field,
+                             TableFieldBuilderVisitor{*this, builder,
+                                                      field_offsets,
+                                                      entry.second});
+      }
+    }
+
+    // Build the table with the out of line fields offsets and inline fields.
+    const uint32_t table_start = builder.StartTable();
+    for (const auto& entry : value) {
+      absl::Nullable<const reflection::Field*> field =
+          GetFieldById(entry.first);
+      if (field != nullptr) {
+        VisitFlatbufferField(
+            schema_, field,
+            TableBuilderVisitor{*this, builder, field_offsets, entry.second});
+      }
+    }
+    return builder.EndTable(table_start);
+  }
+
+ private:
+ const reflection::Schema* absl_nonnull schema_;
+ const reflection::Object* absl_nonnull table_object_;
+ mutable absl::Mutex mutex_;
+ mutable absl::flat_hash_map<typename corpus_type::key_type, CopyableAny>
+ domains_ ABSL_GUARDED_BY(mutex_);
+
+  // Returns the first field of this table whose id equals `id`, or nullptr
+  // when there is none.
+  absl::Nullable<const reflection::Field*> GetFieldById(
+      typename corpus_type::key_type id) const {
+    for (const auto* field : *table_object_->fields()) {
+      if (field->id() == id) {
+        return field;
+      }
+    }
+    return nullptr;
+  }
+
+  // Returns true when `base_type` is bool, one of the integer types, or one
+  // of the floating point types.
+  bool IsScalarType(reflection::BaseType base_type) const {
+    static constexpr reflection::BaseType kScalarTypes[] = {
+        reflection::BaseType::Bool,   reflection::BaseType::Byte,
+        reflection::BaseType::Short,  reflection::BaseType::Int,
+        reflection::BaseType::Long,   reflection::BaseType::UByte,
+        reflection::BaseType::UShort, reflection::BaseType::UInt,
+        reflection::BaseType::ULong,  reflection::BaseType::Float,
+        reflection::BaseType::Double,
+    };
+    for (const reflection::BaseType scalar_type : kScalarTypes) {
+      if (scalar_type == base_type) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  // Returns true for string and scalar base types.
+  bool IsTypeSupported(reflection::BaseType base_type) const {
+    if (base_type == reflection::BaseType::String) {
+      return true;
+    }
+    return IsScalarType(base_type);
+  }
+
+  // Visitor that serializes one field's corpus value into `out` using the
+  // sub-domain selected by the visited tag `T`.
+  struct SerializeVisitor {
+    const FlatbuffersTableUntypedDomainImpl& self;
+    const GenericDomainCorpusType& corpus_value;
+    IRObject& out;
+
+    template <typename T>
+    void Visit(const reflection::Field* absl_nonnull field) {
+      auto& field_domain = self.GetSubDomain<T>(field);
+      out = field_domain.SerializeCorpus(corpus_value);
+    }
+  };
+
+ struct FromValueVisitor {
+ const FlatbuffersTableUntypedDomainImpl& self;
+ value_type value;
+ corpus_type& out;
+
+    // Reads the value of `field` from the table `value`, converts it to a
+    // corpus value with the field's sub-domain and stores it in `out` under
+    // the field id.
+    //
+    // The user value handed to the sub-domain stays empty when an optional
+    // scalar, a string or a vector is absent from the table, and for tags
+    // without a branch below (arrays, structs). For unions it also stays
+    // empty when the member is a struct, is not a member of the union
+    // definition, or does not refer to an object. Nothing is stored when the
+    // sub-domain's `FromValue` fails, or when the union's `<name>_type`
+    // companion field cannot be found.
+    template <typename T>
+    void Visit(const reflection::Field* absl_nonnull field) const {
+      [[maybe_unused]]
+      reflection::BaseType base_type = field->type()->base_type();
+      auto& domain = self.GetSubDomain<T>(field);
+      value_type_t<std::decay_t<decltype(domain)>> inner_value;
+
+      if constexpr (is_flatbuffers_enum_tag_v<T>) {
+        FUZZTEST_INTERNAL_CHECK(base_type >= reflection::BaseType::Byte &&
+                                    base_type <= reflection::BaseType::ULong,
+                                "Field must be an enum type.");
+        if (field->optional() && !value->CheckField(field->offset())) {
+          inner_value = std::nullopt;
+        } else {
+          inner_value =
+              std::make_optional(value->template GetField<typename T::type>(
+                  field->offset(), field->default_integer()));
+        }
+      } else if constexpr (std::is_integral_v<T>) {
+        FUZZTEST_INTERNAL_CHECK(base_type >= reflection::BaseType::Bool &&
+                                    base_type <= reflection::BaseType::ULong,
+                                "Field must be an integer type.");
+        if (field->optional() && !value->CheckField(field->offset())) {
+          inner_value = std::nullopt;
+        } else {
+          inner_value = std::make_optional(value->template GetField<T>(
+              field->offset(), field->default_integer()));
+        }
+      } else if constexpr (std::is_floating_point_v<T>) {
+        FUZZTEST_INTERNAL_CHECK(base_type >= reflection::BaseType::Float &&
+                                    base_type <= reflection::BaseType::Double,
+                                "Field must be a floating point type.");
+        if (field->optional() && !value->CheckField(field->offset())) {
+          inner_value = std::nullopt;
+        } else {
+          inner_value = std::make_optional(value->template GetField<T>(
+              field->offset(), field->default_real()));
+        }
+      } else if constexpr (std::is_same_v<T, std::string>) {
+        FUZZTEST_INTERNAL_CHECK(base_type == reflection::BaseType::String,
+                                "Field must be a string type.");
+        if (!value->CheckField(field->offset())) {
+          inner_value = std::nullopt;
+        } else {
+          inner_value = std::make_optional(
+              value->template GetPointer<flatbuffers::String*>(field->offset())
+                  ->str());
+        }
+      } else if constexpr (std::is_same_v<T, FlatbuffersTableTag>) {
+        auto sub_object = self.schema_->objects()->Get(field->type()->index());
+        FUZZTEST_INTERNAL_CHECK(
+            base_type == reflection::BaseType::Obj && !sub_object->is_struct(),
+            "Field must be a table type.");
+        inner_value = value->template GetPointer<const flatbuffers::Table*>(
+            field->offset());
+      } else if constexpr (is_flatbuffers_vector_tag_v<T>) {
+        FUZZTEST_INTERNAL_CHECK(base_type == reflection::BaseType::Vector ||
+                                    base_type == reflection::BaseType::Vector64,
+                                "Field must be a vector type.");
+        if (!value->CheckField(field->offset())) {
+          inner_value = std::nullopt;
+        } else {
+          VisitVector<typename T::value_type, std::decay_t<decltype(domain)>>(
+              field, inner_value);
+        }
+      } else if constexpr (std::is_same_v<T, FlatbuffersUnionTag>) {
+        constexpr char kUnionTypeFieldSuffix[] = "_type";
+        auto enumdef = self.schema_->enums()->Get(field->type()->index());
+        auto type_field = self.table_object_->fields()->LookupByKey(
+            (field->name()->str() + kUnionTypeFieldSuffix).c_str());
+        if (type_field == nullptr) {
+          return;
+        }
+        auto union_type =
+            value->template GetField<uint8_t>(type_field->offset(), 0);
+        auto enumval = enumdef->values()->LookupByKey(union_type);
+        // Only look up the member object when the stored type names a member
+        // of the union that refers to an object. Otherwise (unknown type, or
+        // a member whose base type is not `Obj`, such as `None`) there is no
+        // valid object index, and the value is left empty.
+        if (enumval != nullptr && enumval->union_type()->base_type() ==
+                                      reflection::BaseType::Obj) {
+          auto union_object =
+              self.schema_->objects()->Get(enumval->union_type()->index());
+          if (union_object->is_struct()) {
+            // TODO: (b/405939014) support structs in unions.
+          } else {
+            auto union_value = value->template GetPointer<flatbuffers::Table*>(
+                field->offset());
+            inner_value = std::make_pair(union_type, union_value);
+          }
+        }
+      }
+
+      auto inner = domain.FromValue(inner_value);
+      if (inner) {
+        out[field->id()] = *std::move(inner);
+      }
+    };
+
+ template <typename ElementType, typename Domain>
+ void VisitVector(const reflection::Field* field,
+ value_type_t<Domain>& inner_value) const {
+ if constexpr (std::is_integral_v<ElementType> ||
+ std::is_floating_point_v<ElementType>) {
+ auto vec =
+ value->template GetPointer<flatbuffers::Vector<ElementType>*>(
+ field->offset());
+ inner_value = std::make_optional(std::vector<ElementType>());
+ inner_value->reserve(vec->size());
+ for (auto i = 0; i < vec->size(); ++i) {
+ inner_value->push_back(vec->Get(i));
+ }
+ } else if constexpr (is_flatbuffers_enum_tag_v<ElementType>) {
+ using Underlaying = typename ElementType::type;
+ auto vec =
+ value->template GetPointer<flatbuffers::Vector<Underlaying>*>(
+ field->offset());
+ inner_value = std::make_optional(std::vector<Underlaying>());
+ inner_value->reserve(vec->size());
+ for (auto i = 0; i < vec->size(); ++i) {
+ inner_value->push_back(vec->Get(i));
+ }
+ } else if constexpr (std::is_same_v<ElementType, std::string>) {
+ auto vec = value->template GetPointer<
+ flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*>(
+ field->offset());
+ inner_value = std::make_optional(std::vector<std::string>());
+ inner_value->reserve(vec->size());
+ for (auto i = 0; i < vec->size(); ++i) {
+ inner_value->push_back(vec->Get(i)->str());
+ }
+ } else if constexpr (std::is_same_v<ElementType, FlatbuffersTableTag>) {
+ auto vec = value->template GetPointer<
+ flatbuffers::Vector<flatbuffers::Offset<flatbuffers::Table>>*>(
+ field->offset());
+ inner_value =
+ std::make_optional(std::vector<const flatbuffers::Table*>());
+ inner_value->reserve(vec->size());
+ for (auto i = 0; i < vec->size(); ++i) {
+ inner_value->push_back(vec->Get(i));
+ }
+ } else if constexpr (std::is_same_v<ElementType, FlatbuffersUnionTag>) {
+ constexpr char kUnionTypeFieldSuffix[] = "_type";
+ auto type_field = self.table_object_->fields()->LookupByKey(
+ (field->name()->str() + kUnionTypeFieldSuffix).c_str());
+ if (type_field == nullptr) {
+ return;
+ }
+ auto type_vec =
+ value->template GetPointer<flatbuffers::Vector<uint8_t>*>(
+ type_field->offset());
+ auto value_vec = value->template GetPointer<
+ flatbuffers::Vector<flatbuffers::Offset<void>>*>(field->offset());
+ inner_value = std::make_optional(
+ typename std::decay_t<decltype(inner_value)>::value_type{});
+ inner_value->reserve(value_vec->size());
+ for (auto i = 0; i < value_vec->size(); ++i) {
+ inner_value->push_back(
+ std::make_pair(type_vec->Get(i), value_vec->Get(i)));
+ }
+ }
+ }
+ };
+
+  // Visitor that serializes the out-of-line payload of one field (string,
+  // nested table, vector or union) into `builder` and records the resulting
+  // offset in `offsets`, keyed by the field id. Scalar and enum fields are not
+  // handled here. Fields whose corpus value is unset record nothing.
+  struct TableFieldBuilderVisitor {
+    const FlatbuffersTableUntypedDomainImpl& self;
+    flatbuffers::FlatBufferBuilder& builder;
+    absl::flat_hash_map<typename corpus_type::key_type, flatbuffers::uoffset_t>&
+        offsets;
+    const typename corpus_type::value_type::second_type& corpus_value;
+
+    template <typename T>
+    void Visit(const reflection::Field* absl_nonnull field) const {
+      if constexpr (std::is_same_v<T, std::string>) {
+        auto& domain = self.GetSubDomain<T>(field);
+        auto user_value = domain.GetValue(corpus_value);
+        if (user_value.has_value()) {
+          auto offset =
+              builder.CreateString(user_value->data(), user_value->size()).o;
+          offsets.insert({field->id(), offset});
+        }
+      } else if constexpr (std::is_same_v<T, FlatbuffersTableTag>) {
+        FlatbuffersTableUntypedDomainImpl inner_domain(
+            self.schema_, self.schema_->objects()->Get(field->type()->index()));
+        const auto& opt_corpus = corpus_value.template GetAs<
+            std::variant<std::monostate, fuzztest::GenericDomainCorpusType>>();
+        if (std::holds_alternative<fuzztest::GenericDomainCorpusType>(
+                opt_corpus)) {
+          const auto& inner_corpus =
+              std::get<fuzztest::GenericDomainCorpusType>(opt_corpus)
+                  .template GetAs<corpus_type>();
+          auto offset = inner_domain.BuildTable(inner_corpus, builder);
+          offsets.insert({field->id(), offset});
+        }
+      } else if constexpr (is_flatbuffers_vector_tag_v<T>) {
+        VisitVector<typename T::value_type>(field, self.GetSubDomain<T>(field));
+      } else if constexpr (std::is_same_v<T, FlatbuffersUnionTag>) {
+        const reflection::Enum* union_type =
+            self.schema_->enums()->Get(field->type()->index());
+        FlatbuffersUnionDomainImpl inner_domain{self.schema_, union_type};
+        const auto& opt_corpus = corpus_value.template GetAs<
+            std::variant<std::monostate, fuzztest::GenericDomainCorpusType>>();
+        if (std::holds_alternative<fuzztest::GenericDomainCorpusType>(
+                opt_corpus)) {
+          auto inner_corpus =
+              std::get<fuzztest::GenericDomainCorpusType>(opt_corpus)
+                  .template GetAs<corpus_type_t<decltype(inner_domain)>>();
+          auto offset = inner_domain.BuildValue(inner_corpus, builder);
+          if (offset.has_value()) {
+            offsets.insert({field->id(), *offset});
+          }
+        }
+      }
+    }
+
+   private:
+    // Serializes vector field `field` whose elements are described by
+    // `Element`. `domain` is the sub-domain of the whole vector field.
+    template <typename Element, typename Domain>
+    void VisitVector(const reflection::Field* field,
+                     const Domain& domain) const {
+      if constexpr (std::is_integral_v<Element> ||
+                    std::is_floating_point_v<Element> ||
+                    is_flatbuffers_enum_tag_v<Element>) {
+        auto value = domain.GetValue(corpus_value);
+        if (!value) {
+          return;
+        }
+        offsets.insert({field->id(), builder.CreateVector(*value).o});
+      } else if constexpr (std::is_same_v<Element, FlatbuffersTableTag>) {
+        FlatbuffersTableUntypedDomainImpl table_domain(
+            self.schema_, self.schema_->objects()->Get(field->type()->index()));
+        const auto& opt_corpus = corpus_value.template GetAs<
+            std::variant<std::monostate, fuzztest::GenericDomainCorpusType>>();
+        if (std::holds_alternative<std::monostate>(opt_corpus)) {
+          return;
+        }
+        const auto& container_corpus =
+            std::get<fuzztest::GenericDomainCorpusType>(opt_corpus)
+                .template GetAs<std::list<corpus_type>>();
+        std::vector<flatbuffers::Offset<flatbuffers::Table>> vec_offsets;
+        for (const auto& inner_corpus : container_corpus) {
+          auto offset = table_domain.BuildTable(inner_corpus, builder);
+          vec_offsets.push_back(offset);
+        }
+        offsets.insert({field->id(), builder.CreateVector(vec_offsets).o});
+      } else if constexpr (std::is_same_v<Element, std::string>) {
+        auto value = domain.GetValue(corpus_value);
+        if (!value) {
+          return;
+        }
+        std::vector<flatbuffers::Offset<flatbuffers::String>> vec_offsets;
+        for (const auto& str : *value) {
+          auto offset = builder.CreateString(str);
+          vec_offsets.push_back(offset);
+        }
+        offsets.insert({field->id(), builder.CreateVector(vec_offsets).o});
+      } else if constexpr (std::is_same_v<Element, FlatbuffersUnionTag>) {
+        const reflection::Enum* union_type =
+            self.schema_->enums()->Get(field->type()->index());
+        FlatbuffersUnionDomainImpl union_domain{self.schema_, union_type};
+        // The discriminants of union vector `x` live in sibling field `x_type`.
+        constexpr char kUnionTypeFieldSuffix[] = "_type";
+        const reflection::Field* type_field =
+            self.table_object_->fields()->LookupByKey(
+                (field->name()->str() + kUnionTypeFieldSuffix).c_str());
+        if (type_field == nullptr) {
+          // Without the companion type field the discriminants cannot be
+          // stored, so there is nothing valid to build.
+          return;
+        }
+
+        const auto& opt_corpus = corpus_value.template GetAs<
+            std::variant<std::monostate, fuzztest::GenericDomainCorpusType>>();
+        if (std::holds_alternative<std::monostate>(opt_corpus)) {
+          return;
+        }
+        auto container_corpus =
+            std::get<fuzztest::GenericDomainCorpusType>(opt_corpus)
+                .template GetAs<
+                    std::list<corpus_type_t<decltype(union_domain)>>>();
+
+        std::vector<typename value_type_t<
+            std::decay_t<decltype(union_domain)>>::first_type>
+            vec_types;
+        std::vector<flatbuffers::Offset<flatbuffers::Table>> vec_offsets;
+        // Elements that fail to build are dropped from both vectors so that
+        // types and values stay aligned.
+        for (auto& inner_corpus : container_corpus) {
+          auto offset = union_domain.BuildValue(inner_corpus, builder);
+          if (offset.has_value()) {
+            vec_offsets.push_back(*offset);
+            vec_types.push_back(union_domain.GetType(inner_corpus));
+          }
+        }
+        offsets.insert({field->id(), builder.CreateVector(vec_offsets).o});
+        offsets.insert({type_field->id(), builder.CreateVector(vec_types).o});
+      }
+    }
+  };
+
+  // Visitor that writes one field into `builder`. Scalars and enums are pushed
+  // inline and tracked at the field's vtable offset; strings, vectors, tables
+  // and unions are added as references to offsets looked up in `offsets`
+  // (presumably filled by TableFieldBuilderVisitor beforehand; confirm against
+  // BuildTable).
+  struct TableBuilderVisitor {
+    const FlatbuffersTableUntypedDomainImpl& self;
+    flatbuffers::FlatBufferBuilder& builder;
+    absl::flat_hash_map<typename corpus_type::key_type, flatbuffers::uoffset_t>&
+        offsets;
+    const typename corpus_type::value_type::second_type& corpus_value;
+
+    template <typename T>
+    void Visit(const reflection::Field* absl_nonnull field) const {
+      if constexpr (std::is_integral_v<T> || std::is_floating_point_v<T> ||
+                    is_flatbuffers_enum_tag_v<T>) {
+        auto& domain = self.GetSubDomain<T>(field);
+        auto v = domain.GetValue(corpus_value);
+        if (!v) {
+          return;
+        }
+        // Copy the scalar out of the optional so that the bytes pushed below
+        // are those of the value itself, not of the optional wrapper.
+        // NOTE(review): the bytes are pushed in host byte order; confirm this
+        // is acceptable on big-endian targets.
+        const auto scalar = *v;
+        const auto size = flatbuffers::GetTypeSize(field->type()->base_type());
+        builder.Align(size);
+        builder.PushBytes(reinterpret_cast<const uint8_t*>(&scalar), size);
+        builder.TrackField(field->offset(), builder.GetSize());
+      } else if constexpr (std::is_same_v<T, std::string> ||
+                           is_flatbuffers_vector_tag_v<T>) {
+        if constexpr (is_flatbuffers_vector_tag_v<T>) {
+          if constexpr (std::is_same_v<typename T::value_type,
+                                       FlatbuffersUnionTag>) {
+            // A vector of unions also needs its companion `<name>_type` vector.
+            constexpr char kUnionTypeFieldSuffix[] = "_type";
+            const reflection::Field* type_field =
+                self.table_object_->fields()->LookupByKey(
+                    (field->name()->str() + kUnionTypeFieldSuffix).c_str());
+            if (type_field != nullptr) {
+              if (auto it = offsets.find(type_field->id());
+                  it != offsets.end()) {
+                builder.AddOffset(type_field->offset(),
+                                  flatbuffers::Offset<>(it->second));
+              }
+            }
+          }
+        }
+        if (auto it = offsets.find(field->id()); it != offsets.end()) {
+          builder.AddOffset(
+              field->offset(),
+              flatbuffers::Offset<flatbuffers::String>(it->second));
+        }
+      } else if constexpr (std::is_same_v<T, FlatbuffersTableTag>) {
+        if (auto it = offsets.find(field->id()); it != offsets.end()) {
+          builder.AddOffset(
+              field->offset(),
+              flatbuffers::Offset<flatbuffers::Table>(it->second));
+        }
+      } else if constexpr (std::is_same_v<T, FlatbuffersUnionTag>) {
+        const reflection::Enum* union_type =
+            self.schema_->enums()->Get(field->type()->index());
+        FlatbuffersUnionDomainImpl union_domain(self.schema_, union_type);
+        // The discriminant of union field `x` lives in sibling field `x_type`.
+        constexpr char kUnionTypeFieldSuffix[] = "_type";
+        const reflection::Field* type_field =
+            self.table_object_->fields()->LookupByKey(
+                (field->name()->str() + kUnionTypeFieldSuffix).c_str());
+        if (type_field == nullptr) {
+          // A union value cannot be stored without its discriminant field.
+          return;
+        }
+        if (auto it = offsets.find(field->id()); it != offsets.end()) {
+          builder.AddOffset(field->offset(),
+                            flatbuffers::Offset<void>(it->second));
+
+          auto opt_corpus = corpus_value.template GetAs<std::variant<
+              std::monostate, fuzztest::GenericDomainCorpusType>>();
+          if (std::holds_alternative<std::monostate>(opt_corpus)) {
+            return;
+          }
+          auto inner_corpus =
+              std::get<fuzztest::GenericDomainCorpusType>(opt_corpus)
+                  .template GetAs<corpus_type_t<decltype(union_domain)>>();
+          auto type_value = union_domain.GetType(inner_corpus);
+          const auto type_size =
+              flatbuffers::GetTypeSize(type_field->type()->base_type());
+          builder.Align(type_size);
+          builder.PushBytes(reinterpret_cast<const uint8_t*>(&type_value),
+                            type_size);
+          builder.TrackField(type_field->offset(), builder.GetSize());
+        }
+      }
+    }
+  };
+
+  // Visitor that parses `obj` with the sub-domain matching the visited field
+  // and stores the parse result (empty on failure) in `out`.
+  struct ParseVisitor {
+    const FlatbuffersTableUntypedDomainImpl& self;
+    const IRObject& obj;
+    std::optional<GenericDomainCorpusType>& out;
+
+    template <typename T>
+    void Visit(const reflection::Field* absl_nonnull field) {
+      auto& field_domain = self.GetSubDomain<T>(field);
+      out = field_domain.ParseCorpus(obj);
+    }
+  };
+
+  // Visitor that validates `corpus_value` against the sub-domain of the
+  // visited field. The resulting status is written to `out`; a failure is
+  // prefixed with the name of the offending field.
+  struct ValidateVisitor {
+    const FlatbuffersTableUntypedDomainImpl& self;
+    const GenericDomainCorpusType& corpus_value;
+    absl::Status& out;
+
+    template <typename T>
+    void Visit(const reflection::Field* absl_nonnull field) {
+      out = self.GetSubDomain<T>(field).ValidateCorpusValue(corpus_value);
+      if (out.ok()) {
+        return;
+      }
+      const auto context =
+          absl::StrCat("Invalid value for field ", field->name()->str());
+      out = Prefix(out, context);
+    }
+  };
+
+  // Visitor that stores a freshly initialized corpus value for the visited
+  // field in `val`, keyed by the field id.
+  struct InitializeVisitor {
+    FlatbuffersTableUntypedDomainImpl& self;
+    absl::BitGenRef prng;
+    corpus_type& val;
+
+    template <typename T>
+    void Visit(const reflection::Field* absl_nonnull field) {
+      auto& field_domain = self.GetSubDomain<T>(field);
+      const auto field_id = field->id();
+      val[field_id] = field_domain.Init(prng);
+    }
+  };
+
+  // Visitor that mutates the corpus entry of the visited field in place. A
+  // field with no entry in `val` gets a freshly initialized one, unless the
+  // mutation is restricted to shrinking.
+  struct MutateVisitor {
+    FlatbuffersTableUntypedDomainImpl& self;
+    absl::BitGenRef prng;
+    const domain_implementor::MutationMetadata& metadata;
+    bool only_shrink;
+    corpus_type& val;
+
+    template <typename T>
+    void Visit(const reflection::Field* absl_nonnull field) {
+      auto& field_domain = self.GetSubDomain<T>(field);
+      auto entry = val.find(field->id());
+      if (entry == val.end()) {
+        if (!only_shrink) {
+          val[field->id()] = field_domain.Init(prng);
+        }
+        return;
+      }
+      field_domain.Mutate(entry->second, prng, metadata, only_shrink);
+    }
+  };
+
+  // Prints a table corpus value as `{name: value, ...}`. Entries whose id does
+  // not resolve to a field are printed as `<unknown field: ID>`.
+  struct Printer {
+    const FlatbuffersTableUntypedDomainImpl& self;
+
+    void PrintCorpusValue(const corpus_type& value,
+                          domain_implementor::RawSink out,
+                          domain_implementor::PrintMode mode) const {
+      absl::Format(out, "{");
+      bool needs_separator = false;
+      for (const auto& [id, field_corpus] : value) {
+        if (needs_separator) {
+          absl::Format(out, ", ");
+        }
+        needs_separator = true;
+        if (absl::Nullable<const reflection::Field*> field =
+                self.GetFieldById(id);
+            field != nullptr) {
+          VisitFlatbufferField(self.schema_, field,
+                               PrinterVisitor{self, field_corpus, out, mode});
+        } else {
+          absl::Format(out, "<unknown field: %d>", id);
+        }
+      }
+      absl::Format(out, "}");
+    }
+  };
+
+  // Visitor that prints a single field as `name: value`, delegating the value
+  // part to the printer of the field's sub-domain.
+  struct PrinterVisitor {
+    const FlatbuffersTableUntypedDomainImpl& self;
+    const GenericDomainCorpusType& val;
+    domain_implementor::RawSink out;
+    domain_implementor::PrintMode mode;
+
+    template <typename T>
+    void Visit(const reflection::Field* absl_nonnull field) const {
+      auto& field_domain = self.GetSubDomain<T>(field);
+      const auto field_name = field->name()->str();
+      absl::Format(out, "%s: ", field_name);
+      domain_implementor::PrintValue(field_domain, val, out, mode);
+    }
+  };
+};
+
+// Domain implementation for flatbuffers generated table classes.
+//
+// The corpus value is a pair: the untyped corpus of the wrapped
+// `FlatbuffersTableUntypedDomainImpl` and the serialized flatbuffer built from
+// it. The user value is a `const T*` pointing into that serialized buffer, so
+// it is only valid while the corpus value it was obtained from is alive.
+template <typename T>
+class FlatbuffersTableDomainImpl
+    : public domain_implementor::DomainBase<
+          FlatbuffersTableDomainImpl<T>, const T* absl_nonnull,
+          std::pair<corpus_type_t<FlatbuffersTableUntypedDomainImpl>,
+                    std::vector<uint8_t>>> {
+ public:
+  static_assert(
+      Requires<const T*>([](auto) -> decltype(T::BinarySchema::data()) {}),
+      "The flatbuffers generated class must be generated with the "
+      "`--bfbs-gen-embed` flag.");
+  static_assert(
+      Requires<const T*>([](auto) -> decltype(T::GetFullyQualifiedName()) {}),
+      "The flatbuffers generated class must be generated with the "
+      "`--gen-name-strings` flag.");
+
+  using typename FlatbuffersTableDomainImpl::DomainBase::corpus_type;
+  using typename FlatbuffersTableDomainImpl::DomainBase::value_type;
+
+  // Verifies the binary schema embedded in `T` and sets up the untyped domain
+  // for the table named by `T::GetFullyQualifiedName()`.
+  FlatbuffersTableDomainImpl() {
+    flatbuffers::Verifier verifier(T::BinarySchema::data(),
+                                   T::BinarySchema::size());
+    FUZZTEST_INTERNAL_CHECK(reflection::VerifySchemaBuffer(verifier),
+                            "Invalid schema for flatbuffers table.");
+    auto schema = reflection::GetSchema(T::BinarySchema::data());
+    auto table_object =
+        schema->objects()->LookupByKey(T::GetFullyQualifiedName());
+    inner_ = FlatbuffersTableUntypedDomainImpl{schema, table_object};
+  }
+
+  // Copy and move operations transfer only the inner domain; every instance
+  // keeps its own scratch builder, which is left empty.
+  FlatbuffersTableDomainImpl(const FlatbuffersTableDomainImpl& other)
+      : inner_(other.inner_) {
+    builder_.Clear();
+  }
+
+  FlatbuffersTableDomainImpl& operator=(
+      const FlatbuffersTableDomainImpl& other) {
+    if (this == &other) return *this;
+    inner_ = other.inner_;
+    builder_.Clear();
+    return *this;
+  }
+
+  FlatbuffersTableDomainImpl(FlatbuffersTableDomainImpl&& other)
+      : inner_(std::move(other.inner_)) {
+    builder_.Clear();
+  }
+
+  FlatbuffersTableDomainImpl& operator=(FlatbuffersTableDomainImpl&& other) {
+    if (this == &other) return *this;
+    inner_ = std::move(other.inner_);
+    builder_.Clear();
+    return *this;
+  }
+
+  // Returns an initial corpus value: a seed if one is available, otherwise a
+  // freshly initialized table together with its serialized buffer.
+  corpus_type Init(absl::BitGenRef prng) {
+    if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
+    auto val = inner_->Init(prng);
+    auto buffer = BuildBuffer(val);
+    return std::make_pair(std::move(val), std::move(buffer));
+  }
+
+  // Returns the number of fields in the table.
+  uint64_t CountNumberOfFields(corpus_type& val) {
+    return inner_->CountNumberOfFields(val.first);
+  }
+
+  // Mutates the untyped corpus in place and rebuilds the serialized buffer so
+  // that both halves of `val` stay in sync.
+  void Mutate(corpus_type& val, absl::BitGenRef prng,
+              const domain_implementor::MutationMetadata& metadata,
+              bool only_shrink) {
+    inner_->Mutate(val.first, prng, metadata, only_shrink);
+    val.second = BuildBuffer(val.first);
+  }
+
+  // Returns the user value: the root table of the serialized buffer stored in
+  // `value`. The returned pointer points into `value.second`.
+  value_type GetValue(const corpus_type& value) const {
+    return flatbuffers::GetRoot<T>(value.second.data());
+  }
+
+  // Converts a user-provided table into a corpus value, or returns
+  // `std::nullopt` if the inner domain cannot represent it.
+  std::optional<corpus_type> FromValue(const value_type& value) const {
+    // A C-style cast is used on purpose: generated table structs do not have
+    // a public base class (presumably they derive privately from
+    // `flatbuffers::Table`; confirm), which rules out `static_cast`.
+    auto val = inner_->FromValue((const flatbuffers::Table*)value);
+    if (!val.has_value()) return std::nullopt;
+    // Build the buffer before moving the corpus into the result.
+    auto buffer = BuildBuffer(*val);
+    return std::make_optional(
+        std::make_pair(*std::move(val), std::move(buffer)));
+  }
+
+  // Returns the printer for the table.
+  auto GetPrinter() const { return Printer{*inner_}; }
+
+  // Parses a serialized corpus value and rebuilds the serialized buffer for
+  // it. Returns `std::nullopt` if the inner domain fails to parse `obj`.
+  std::optional<corpus_type> ParseCorpus(const IRObject& obj) const {
+    auto val = inner_->ParseCorpus(obj);
+    if (!val.has_value()) return std::nullopt;
+    // Build the buffer before moving the corpus into the result.
+    auto buffer = BuildBuffer(*val);
+    return std::make_optional(
+        std::make_pair(*std::move(val), std::move(buffer)));
+  }
+
+  // Returns the serialized corpus value. Only the untyped corpus is
+  // serialized; the buffer is rebuilt by `ParseCorpus`.
+  IRObject SerializeCorpus(const corpus_type& corpus_value) const {
+    return inner_->SerializeCorpus(corpus_value.first);
+  }
+
+  // Returns the status of the given corpus value.
+  absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const {
+    return inner_->ValidateCorpusValue(corpus_value.first);
+  }
+
+ private:
+  // Set by the default constructor; optional because it cannot be constructed
+  // before the schema has been verified.
+  std::optional<FlatbuffersTableUntypedDomainImpl> inner_;
+  // Scratch builder reused by `BuildBuffer`. It is `mutable` so that const
+  // member functions can build buffers.
+  // NOTE(review): this makes concurrent calls on one instance unsafe; confirm
+  // that domains are not shared across threads.
+  mutable flatbuffers::FlatBufferBuilder builder_;
+
+  // Prints the untyped half of the corpus value using the inner printer.
+  struct Printer {
+    const FlatbuffersTableUntypedDomainImpl& inner;
+
+    void PrintCorpusValue(const corpus_type& value,
+                          domain_implementor::RawSink out,
+                          domain_implementor::PrintMode mode) const {
+      inner.GetPrinter().PrintCorpusValue(value.first, out, mode);
+    }
+  };
+
+  // Serializes `val` into a finished flatbuffer and returns a copy of its
+  // bytes. The scratch builder is cleared afterwards.
+  std::vector<uint8_t> BuildBuffer(
+      const typename corpus_type::first_type& val) const {
+    auto offset = inner_->BuildTable(val, builder_);
+    builder_.Finish(flatbuffers::Offset<flatbuffers::Table>(offset));
+    auto buffer =
+        std::vector<uint8_t>(builder_.GetBufferPointer(),
+                             builder_.GetBufferPointer() + builder_.GetSize());
+    builder_.Clear();
+    return buffer;
+  }
+};
+
+// Specialization of `ArbitraryImpl` selected when `is_flatbuffers_table_v<T>`
+// is true. It adds nothing of its own and inherits all behavior from
+// `FlatbuffersTableDomainImpl<T>`.
+template <typename T>
+class ArbitraryImpl<T, std::enable_if_t<is_flatbuffers_table_v<T>>>
+ : public FlatbuffersTableDomainImpl<T> {};
+} // namespace fuzztest::internal
+#endif // FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_FLATBUFFERS_DOMAIN_IMPL_H_
diff --git a/fuzztest/internal/meta.h b/fuzztest/internal/meta.h
index 4ddada1..83e84f6 100644
--- a/fuzztest/internal/meta.h
+++ b/fuzztest/internal/meta.h
@@ -200,6 +200,24 @@
inline constexpr bool is_protocol_buffer_enum_v =
IsProtocolBufferEnumImpl<T>(true);
+template <typename T, typename = void>
+inline constexpr bool is_flatbuffers_table_v = false;
+
+// Generated flatbuffers table structs do not have a public base class, so we
+// detect them by checking for a few specific properties:
+// - T is a class type (`std::is_class_v<T>`).
+// - T has a nested `Builder` type.
+// - T has a `BinarySchema` type with a static `data()` function (only
+// available when passing `--bfbs-gen-embed` to the flatbuffers compiler).
+// - T has a static function called `GetFullyQualifiedName` (only available
+// when passing `--gen-name-strings` to the flatbuffers compiler).
+template <typename T>
+inline constexpr bool
+ is_flatbuffers_table_v<T, std::enable_if_t<std::is_class_v<T>>> =
+ Requires<const T*>([](auto) -> typename T::Builder {}) &&
+ Requires<const T*>([](auto) -> decltype(T::BinarySchema::data()) {}) &&
+ Requires<const T*>([](auto) -> decltype(T::GetFullyQualifiedName()) {});
+
template <typename T>
inline constexpr bool has_size_v =
Requires<T>([](auto v) -> decltype(v.size()) {});
diff --git a/fuzztest/internal/test_flatbuffers.fbs b/fuzztest/internal/test_flatbuffers.fbs
new file mode 100644
index 0000000..0efb9eb
--- /dev/null
+++ b/fuzztest/internal/test_flatbuffers.fbs
@@ -0,0 +1,54 @@
+namespace fuzztest.internal;
+
+enum TestFbsEnum: byte {  // Enum with `byte` as its underlying type.
+ First,
+ Second,
+ Third
+}
+
+table SimpleTestFbsTable {  // One bool, float, string and enum field each.
+ b: bool;
+ f: float;
+ str: string;
+ e: TestFbsEnum;
+}
+
+table NestedTestFbsTable {
+ t: SimpleTestFbsTable;  // Nested table field.
+}
+
+table OptionalRequiredTestFbsTable {
+ def_scalar: bool = true;  // Scalar with an explicit default value.
+ opt_scalar: bool = null;  // Optional scalar (declared with `= null`).
+ req_str: string (required);  // String marked `required`.
+ opt_str: string;  // String without the `required` attribute.
+}
+
+table VectorsTestFbsTable {  // One vector field per element kind.
+ b: [bool];
+ i8: [byte];
+ i16: [short];
+ i32: [int];
+ i64: [long];
+ u8: [ubyte];
+ u16: [ushort];
+ u32: [uint];
+ u64: [ulong];
+ f: [float];
+ d: [double];
+ str: [string];
+ e: [TestFbsEnum];  // Vector of enum values.
+ t: [SimpleTestFbsTable];  // Vector of tables.
+}
+
+union Union {  // Union whose members are two of the tables declared above.
+ OptionalRequiredTestFbsTable,
+ SimpleTestFbsTable,
+}
+
+table UnionTestFbsTable {
+ u: Union;  // Single union value.
+ u_vec: [Union];  // Vector of union values.
+}
+
+root_type SimpleTestFbsTable;