No public description
PiperOrigin-RevId: 554371233
diff --git a/.github/workflows/bazel_test.yml b/.github/workflows/bazel_test.yml
index eca5cf6..580e546 100644
--- a/.github/workflows/bazel_test.yml
+++ b/.github/workflows/bazel_test.yml
@@ -28,7 +28,7 @@
run_tests:
name: Run tests
runs-on: ubuntu-latest
- timeout-minutes: 30
+ timeout-minutes: 60
strategy:
matrix:
config: ['default', 'fuzztest']
@@ -62,10 +62,10 @@
if: matrix.config == 'default'
run: |
bazel test --build_tests_only --test_output=errors \
- -c ${{ matrix.compilation_mode }} -- //... -//centipede/...
+ -c ${{ matrix.compilation_mode }} --//rpc_fuzzing:enable_rpc_fuzzing -- //... -//centipede/...
- name: Run end-to-end tests with --config=fuzztest
if: matrix.config == 'fuzztest'
run: |
bazel test --build_tests_only --test_output=errors \
- -c ${{ matrix.compilation_mode }} --config=fuzztest //e2e_tests:all
+ -c ${{ matrix.compilation_mode }} --config=fuzztest //e2e_tests:all --//rpc_fuzzing:enable_rpc_fuzzing
diff --git a/.github/workflows/cmake_test.yml b/.github/workflows/cmake_test.yml
index 55bc145..257b8ca 100644
--- a/.github/workflows/cmake_test.yml
+++ b/.github/workflows/cmake_test.yml
@@ -26,7 +26,7 @@
run_tests:
name: Run CMake tests
runs-on: ubuntu-latest
- timeout-minutes: 30
+ timeout-minutes: 60
strategy:
matrix:
mode: ['default', 'fuzzing']
diff --git a/WORKSPACE b/WORKSPACE
index b685cdc..d8a02a2 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -36,9 +36,9 @@
http_archive(
name = "com_google_absl",
- sha256 = "3ea49a7d97421b88a8c48a0de16c16048e17725c7ec0f1d3ea2683a2a75adc21",
- strip_prefix = "abseil-cpp-20230125.0",
- url = "https://github.com/abseil/abseil-cpp/archive/refs/tags/20230125.0.tar.gz",
+ sha256 = "5366d7e7fa7ba0d915014d387b66d0d002c03236448e1ba9ef98122c13b35c36",
+ strip_prefix = "abseil-cpp-20230125.3",
+ url = "https://github.com/abseil/abseil-cpp/archive/refs/tags/20230125.3.tar.gz",
)
http_archive(
@@ -55,6 +55,18 @@
url = "https://www.antlr.org/download/antlr4-cpp-runtime-4.12.0-source.zip",
)
+http_archive(
+ name = "com_github_grpc_grpc",
+ urls = [
+ "https://github.com/grpc/grpc/archive/49dcbce9aba5f5f0959871354df90e38f0c5ed00.tar.gz", # v1.55.3
+ ],
+ strip_prefix = "grpc-49dcbce9aba5f5f0959871354df90e38f0c5ed00",
+)
+load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
+grpc_deps()
+load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps")
+grpc_extra_deps()
+
################################################################################
# Direct dependencies that are only required for running tests
################################################################################
diff --git a/domain_tests/BUILD b/domain_tests/BUILD
index 4e24aa6..28c201f 100644
--- a/domain_tests/BUILD
+++ b/domain_tests/BUILD
@@ -22,7 +22,10 @@
name = "domain_testing",
testonly = 1,
hdrs = ["domain_testing.h"],
- visibility = ["@com_google_fuzztest//fuzztest:__subpackages__"],
+ visibility = [
+ "@com_google_fuzztest//fuzztest:__subpackages__",
+ "@com_google_fuzztest//rpc_fuzzing:__pkg__",
+ ],
deps = [
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/hash",
diff --git a/e2e_tests/BUILD b/e2e_tests/BUILD
index b146dd3..3df931e 100644
--- a/e2e_tests/BUILD
+++ b/e2e_tests/BUILD
@@ -26,6 +26,7 @@
srcs = ["functional_test.cc"],
data = [
"@com_google_fuzztest//e2e_tests/testdata:fuzz_tests_for_functional_testing.stripped",
+ "@com_google_fuzztest//e2e_tests/testdata:fuzz_tests_for_rpc_session_grpc_functional_testing.stripped",
"@com_google_fuzztest//e2e_tests/testdata:fuzz_tests_with_invalid_seeds.stripped",
],
shard_count = 50,
diff --git a/e2e_tests/functional_test.cc b/e2e_tests/functional_test.cc
index 070451e..e54fca8 100644
--- a/e2e_tests/functional_test.cc
+++ b/e2e_tests/functional_test.cc
@@ -1273,6 +1273,28 @@
EXPECT_EQ(1, CountSubstrs(std_err, "<<FixtureTest::~FixtureTest()>>"));
}
+TEST_F(FuzzingModeTest, RpcSessionFuzzingSupportsFuzzingGrpcService) {
+ auto [status, std_out, std_err] = RunBinaryWith(
+ BinaryPath("testdata/fuzz_tests_for_rpc_session_grpc_functional_testing"),
+ "--fuzz=MiniBloggerGrpcTest."
+ "ServiceDoesNotCrashWithAnyRpcSequence",
+ {}, /*timeout=*/absl::Seconds(60));
+ EXPECT_THAT(std_err, Not(HasSubstr("Failed to execute !")));
+ EXPECT_THAT(std_err, HasSubstr("Using an inactive session id!"));
+ EXPECT_THAT(status, Eq(Signal(SIGABRT)));
+}
+
+TEST_F(FuzzingModeTest, RpcSessionOfForGrpcFindsStatefulBugInFuzzingMode) {
+ auto [status, std_out, std_err] = RunBinaryWith(
+ BinaryPath("testdata/fuzz_tests_for_rpc_session_grpc_functional_testing"),
+ "--fuzz=MiniBloggerGrpcTest."
+ "TestRpcSessionOfSetup",
+ {{"FUZZTEST_MAX_FUZZING_RUNS", "-1"}}, /*timeout=*/absl::Seconds(60));
+ EXPECT_THAT(std_err, Not(HasSubstr("Failed to execute !")));
+ EXPECT_THAT(std_err, HasSubstr("Using an inactive session id!"));
+ EXPECT_THAT(status, Eq(Signal(SIGABRT)));
+}
+
TEST_F(FuzzingModeTest,
GoogleTestPerIterationFixtureInstantiatedOncePerIteration) {
auto [status, std_out, std_err] = RunWith(
diff --git a/e2e_tests/testdata/BUILD b/e2e_tests/testdata/BUILD
index cd3896e..6cc5059 100644
--- a/e2e_tests/testdata/BUILD
+++ b/e2e_tests/testdata/BUILD
@@ -59,6 +59,26 @@
],
)
+cc_binary(
+ name = "fuzz_tests_for_rpc_session_grpc_functional_testing",
+ testonly = 1,
+ srcs = [
+ "fuzz_tests_for_rpc_session_grpc_functional_testing.cc",
+ ],
+ deps = [
+ "@com_github_grpc_grpc//:grpc++",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_fuzztest//fuzztest",
+ "@com_google_fuzztest//fuzztest:fuzztest_gtest_main",
+ "@com_google_fuzztest//rpc_fuzzing:grpc_stub",
+ "@com_google_fuzztest//rpc_fuzzing:rpc_session",
+ "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_grpc_proto",
+ "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_proto",
+ "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_grpc_service",
+ ],
+)
+
# Must be a separate binary, as the invalid seeds make the binary crash at
# process start time. Used by `functional_test` only.
cc_binary(
diff --git a/e2e_tests/testdata/fuzz_tests_for_rpc_session_grpc_functional_testing.cc b/e2e_tests/testdata/fuzz_tests_for_rpc_session_grpc_functional_testing.cc
new file mode 100644
index 0000000..fc4f295
--- /dev/null
+++ b/e2e_tests/testdata/fuzz_tests_for_rpc_session_grpc_functional_testing.cc
@@ -0,0 +1,76 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Rpc session fuzz test examples to be used for e2e functional testing.
+//
+// Specifically, used by `functional_test` only.
+
+#include <cstdlib>
+#include <iostream>
+#include <memory>
+
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "./fuzztest/fuzztest.h"
+#include "./rpc_fuzzing/grpc_stub.h"
+#include "./rpc_fuzzing/rpc_session.h"
+#include "./rpc_fuzzing/testdata/grpc/mini_blogger_service.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.grpc.pb.h"
+#include "grpcpp//security/server_credentials.h"
+#include "grpcpp//server.h"
+#include "grpcpp//server_builder.h"
+#include "grpcpp//support/channel_arguments.h"
+
+namespace {
+
+std::string_view ServiceNameFactory() {
+ return "fuzztest.internal.MiniBlogger";
+}
+
+class MiniBloggerGrpcTest {
+ public:
+ MiniBloggerGrpcTest()
+ : server_(grpc::ServerBuilder()
+ .RegisterService(&mini_blogger_service_)
+ .BuildAndStart()),
+ stub_(server_->InProcessChannel(grpc::ChannelArguments())) {}
+
+ ~MiniBloggerGrpcTest() { server_->Shutdown(); }
+
+ void ServiceDoesNotCrashWithAnyRpcSequence(
+ fuzztest::RpcSequence rpc_sequence) {
+ absl::Status status = fuzztest::ExecuteRpcSequence(stub_, rpc_sequence);
+ if (!status.ok()) {
+ std::cerr << "Failed to execute !" << status.message() << "\n";
+ std::abort();
+ }
+ }
+
+ void TestRpcSessionOfSetup(fuzztest::RpcSequence rpc_sequence) {
+ ServiceDoesNotCrashWithAnyRpcSequence(rpc_sequence);
+ }
+
+ private:
+ fuzztest::internal::MiniBloggerGrpcService mini_blogger_service_;
+ std::unique_ptr<grpc::Server> server_ = nullptr;
+ fuzztest::GrpcStub stub_;
+};
+
+FUZZ_TEST_F(MiniBloggerGrpcTest, ServiceDoesNotCrashWithAnyRpcSequence)
+ .WithDomains(fuzztest::RpcSession<fuzztest::internal::MiniBlogger>());
+
+FUZZ_TEST_F(MiniBloggerGrpcTest, TestRpcSessionOfSetup)
+ .WithDomains(fuzztest::RpcSessionOf(ServiceNameFactory));
+
+} // namespace
diff --git a/fuzztest/BUILD b/fuzztest/BUILD
index 041e9cf..e75c0c9 100644
--- a/fuzztest/BUILD
+++ b/fuzztest/BUILD
@@ -128,6 +128,7 @@
visibility = [
"@com_google_fuzztest//centipede:__subpackages__",
"@com_google_fuzztest//domain_tests:__pkg__",
+ "@com_google_fuzztest//rpc_fuzzing:__pkg__",
],
deps = [
":absl_helpers",
diff --git a/rpc_fuzzing/BUILD b/rpc_fuzzing/BUILD
new file mode 100644
index 0000000..1940b24
--- /dev/null
+++ b/rpc_fuzzing/BUILD
@@ -0,0 +1,221 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = [
+ "@com_google_fuzztest//e2e_tests/testdata:__pkg__",
+ "@com_google_fuzztest//fuzztest:__subpackages__",
+])
+
+load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
+
+bool_flag(
+ name = "enable_rpc_fuzzing",
+ build_setting_default = False,
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "rpc_fuzzing_is_enabled",
+ flag_values = {
+ ":enable_rpc_fuzzing": "true",
+ },
+)
+
+cc_library(
+ name = "proto_field_path",
+ srcs = ["proto_field_path.cc"],
+ hdrs = ["proto_field_path.h"],
+ target_compatible_with = select({
+ ":rpc_fuzzing_is_enabled": [],
+ "//conditions:default": ["@platforms//:incompatible"],
+ }),
+ deps = [
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/strings",
+ "@com_google_fuzztest//fuzztest:logging",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "proto_field_path_test",
+ srcs = ["proto_field_path_test.cc"],
+ deps = [
+ ":proto_field_path",
+ "@com_google_fuzztest//fuzztest:test_protobuf_cc_proto",
+ "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_proto",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "rpc_sequence",
+ srcs = ["rpc_sequence.cc"],
+ hdrs = ["rpc_sequence.h"],
+ deps = [
+ ":proto_field_path",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/random",
+ "@com_google_absl//absl/random:bit_gen_ref",
+ "@com_google_fuzztest//fuzztest:logging",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "rpc_sequence_test",
+ srcs = ["rpc_sequence_test.cc"],
+ deps = [
+ ":proto_field_path",
+ ":rpc_sequence",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log:check",
+ "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_proto",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+# TODO(changochen): Combine the stubs into a single lib.
+cc_library(
+ name = "rpc_stub",
+ hdrs = ["rpc_stub.h"],
+ deps = [
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "grpc_stub",
+ srcs = ["grpc_stub.cc"],
+ hdrs = ["grpc_stub.h"],
+ deps = [
+ ":rpc_stub",
+ "@com_github_grpc_grpc//:grpc++",
+ "@com_github_grpc_grpc//:grpc++_codegen_proto",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "rpc_executor",
+ srcs = ["rpc_executor.cc"],
+ hdrs = ["rpc_executor.h"],
+ deps = [
+ ":rpc_sequence",
+ ":rpc_stub",
+ "@com_google_absl//absl/log:die_if_null",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_fuzztest//fuzztest:logging",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "rpc_executor_test",
+ srcs = ["rpc_executor_test.cc"],
+ deps = [
+ ":grpc_stub",
+ ":proto_field_path",
+ ":rpc_executor",
+ ":rpc_sequence",
+ "@com_github_grpc_grpc//:grpc++",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/strings",
+ "@com_google_fuzztest//domain_tests:domain_testing",
+ "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_proto",
+ "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_grpc_service",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "rpc_potential_dfg",
+ srcs = ["rpc_potential_dfg.cc"],
+ hdrs = ["rpc_potential_dfg.h"],
+ deps = [
+ ":proto_field_path",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/strings",
+ "@com_google_fuzztest//fuzztest:logging",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "rpc_potential_dfg_test",
+ srcs = ["rpc_potential_dfg_test.cc"],
+ deps = [
+ ":proto_field_path",
+ ":rpc_potential_dfg",
+ "@com_github_grpc_grpc//:grpc++",
+ "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_grpc_proto",
+ "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_proto",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "rpc_session",
+ hdrs = ["rpc_session.h"],
+ deps = [
+ ":proto_field_path",
+ ":rpc_executor",
+ ":rpc_potential_dfg",
+ ":rpc_sequence",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/random:bit_gen_ref",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_fuzztest//fuzztest:domain",
+ "@com_google_fuzztest//fuzztest:domain_core",
+ "@com_google_fuzztest//fuzztest:logging",
+ "@com_google_fuzztest//fuzztest:serialization",
+ "@com_google_fuzztest//fuzztest:table_of_recent_compares",
+ "@com_google_fuzztest//fuzztest:type_support",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "rpc_session_test",
+ srcs = ["rpc_session_test.cc"],
+ deps = [
+ ":proto_field_path",
+ ":rpc_potential_dfg",
+ ":rpc_sequence",
+ ":rpc_session",
+ "@com_github_grpc_grpc//:grpc++",
+ "@com_google_absl//absl/random",
+ "@com_google_fuzztest//domain_tests:domain_testing",
+ "@com_google_fuzztest//fuzztest:serialization",
+ "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_grpc_proto",
+ "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_proto",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
diff --git a/rpc_fuzzing/grpc_stub.cc b/rpc_fuzzing/grpc_stub.cc
new file mode 100644
index 0000000..fa4a954
--- /dev/null
+++ b/rpc_fuzzing/grpc_stub.cc
@@ -0,0 +1,57 @@
+#include "./rpc_fuzzing/grpc_stub.h"
+
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/synchronization/notification.h"
+#include "grpcpp//impl/proto_utils.h"
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "grpcpp//client_context.h"
+#include "grpcpp//generic/generic_stub.h"
+#include "grpcpp//support/status.h"
+#include "grpcpp//support/stub_options.h"
+
+namespace fuzztest {
+
+absl::StatusOr<std::unique_ptr<google::protobuf::Message>> GrpcStub::CallMethod(
+ const google::protobuf::MethodDescriptor& method_descriptor,
+ const google::protobuf::Message& request) {
+ grpc::ClientContext cli_ctx;
+ const google::protobuf::Message* response_prototype =
+ google::protobuf::MessageFactory::generated_factory()->GetPrototype(
+ method_descriptor.output_type());
+ if (!response_prototype) {
+ return absl::InternalError(
+ absl::StrCat("Cannot find prototype for ",
+ method_descriptor.output_type()->full_name(),
+ " in the generated proto MessageFactory"));
+ }
+ std::unique_ptr<google::protobuf::Message> response =
+ absl::WrapUnique(response_prototype->New());
+ absl::Notification notification;
+ absl::Status status = absl::OkStatus();
+ grpc_stub_->UnaryCall(
+ &cli_ctx,
+ absl::StrCat("/", method_descriptor.service()->full_name(), "/",
+ method_descriptor.name()),
+ grpc::StubOptions(), &request, response.get(),
+ [¬ification, &status](grpc::Status s) {
+ // TODO(changochen): Better conversion from grpc::Status to absl::Status.
+ if (!s.ok()) {
+ status = absl::InternalError(s.error_message());
+ }
+ notification.Notify();
+ });
+ notification.WaitForNotification();
+ if (status.ok()) {
+ return response;
+ } else {
+ return status;
+ }
+}
+
+} // namespace fuzztest
diff --git a/rpc_fuzzing/grpc_stub.h b/rpc_fuzzing/grpc_stub.h
new file mode 100644
index 0000000..1e046b0
--- /dev/null
+++ b/rpc_fuzzing/grpc_stub.h
@@ -0,0 +1,32 @@
+#ifndef FUZZTEST_RPC_FUZZING_GRPC_STUB_H_
+#define FUZZTEST_RPC_FUZZING_GRPC_STUB_H_
+#include <memory>
+
+#include "google/protobuf/message.h"
+#include "absl/status/statusor.h"
+#include "./rpc_fuzzing/rpc_stub.h"
+#include "grpcpp//generic/generic_stub.h"
+
+namespace fuzztest {
+
+using GrpcGenericStub =
+ grpc::TemplatedGenericStub<google::protobuf::Message, google::protobuf::Message>;
+
+class GrpcStub : public RpcStub {
+ public:
+ GrpcStub(std::shared_ptr<grpc::ChannelInterface> channel)
+ : grpc_stub_(
+ std::make_unique<
+ grpc::TemplatedGenericStub<google::protobuf::Message, google::protobuf::Message>>(
+ channel)) {}
+ absl::StatusOr<std::unique_ptr<google::protobuf::Message>> CallMethod(
+ const google::protobuf::MethodDescriptor& method_descriptor,
+ const google::protobuf::Message& request) override;
+
+ private:
+ std::unique_ptr<GrpcGenericStub> grpc_stub_ = nullptr;
+};
+
+} // namespace fuzztest
+
+#endif // FUZZTEST_RPC_FUZZING_GRPC_STUB_H_
diff --git a/rpc_fuzzing/proto_field_path.cc b/rpc_fuzzing/proto_field_path.cc
new file mode 100644
index 0000000..b1d7b37
--- /dev/null
+++ b/rpc_fuzzing/proto_field_path.cc
@@ -0,0 +1,313 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "./rpc_fuzzing/proto_field_path.h"
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "./fuzztest/internal/logging.h"
+
+namespace fuzztest::internal {
+
+bool Contains(const google::protobuf::FieldDescriptor& parent,
+ const google::protobuf::FieldDescriptor& child) {
+ return parent.message_type() == child.containing_type();
+}
+
+bool operator==(const FieldPath& lhs, const FieldPath& rhs) {
+ return lhs.field_descriptors_ == rhs.field_descriptors_;
+}
+
+bool FieldPath::CanAppend(const google::protobuf::FieldDescriptor& field) const {
+ return field_descriptors_.empty() ||
+ Contains(*field_descriptors_.back(), field);
+}
+
+void FieldPath::AppendField(const google::protobuf::FieldDescriptor& field) {
+ FUZZTEST_INTERNAL_CHECK(
+ CanAppend(field),
+ "The current last field in the path must be a message containing the "
+ "appended field.");
+ field_descriptors_.push_back(&field);
+}
+
+void FieldPath::AppendPath(const FieldPath& other) {
+ if (other.field_descriptors_.empty()) {
+ return;
+ }
+ field_descriptors_.reserve(field_descriptors_.size() +
+ other.field_descriptors_.size());
+ FUZZTEST_INTERNAL_CHECK(
+ CanAppend(*other.GetAllFields()[0]),
+ "The current last field in the path must be a message containing the "
+ "first field in the other path.");
+ field_descriptors_.insert(field_descriptors_.end(),
+ other.field_descriptors_.begin(),
+ other.field_descriptors_.end());
+}
+
+std::string FieldPath::ToString() const {
+ return absl::StrJoin(
+ field_descriptors_, /*separator=*/".",
+ [](std::string* result, const google::protobuf::FieldDescriptor* field) {
+ absl::StrAppend(result, field->name());
+ });
+}
+
+FieldPath GetFieldPathWithDescriptor(const google::protobuf::Descriptor& descriptor,
+ std::string_view field_path_str) {
+ std::vector<std::string> parts = absl::StrSplit(field_path_str, '.');
+ FUZZTEST_INTERNAL_CHECK(!parts.empty(), "Invalid path string!");
+
+ FieldPath result;
+ const google::protobuf::Descriptor* descriptor_ptr = &descriptor;
+
+ for (const std::string& part : parts) {
+ const google::protobuf::FieldDescriptor* field =
+ descriptor_ptr->FindFieldByName(part);
+ FUZZTEST_INTERNAL_CHECK(field != nullptr, "Invalid field name!");
+ result.AppendField(*field);
+ descriptor_ptr = field->message_type();
+ }
+ return result;
+}
+
+const google::protobuf::Message* FieldPath::GetContainingMessageOfLastField(
+ const google::protobuf::Message& message) const {
+ FUZZTEST_INTERNAL_CHECK_PRECONDITION(!field_descriptors_.empty(),
+ "Empty field path!");
+ const google::protobuf::Message* parent = &message;
+ for (size_t i = 0; i < field_descriptors_.size() - 1; ++i) {
+ if (field_descriptors_[i]->is_repeated()) {
+ if (parent->GetReflection()->FieldSize(*parent, field_descriptors_[i]) ==
+ 0) {
+ return nullptr;
+ }
+ parent = &(parent->GetReflection()->GetRepeatedMessage(
+ *parent, field_descriptors_[i], 0));
+ } else {
+ parent = &(
+ parent->GetReflection()->GetMessage(*parent, field_descriptors_[i]));
+ }
+ }
+ return parent;
+}
+
+google::protobuf::Message* FieldPath::MutableContainingMessageOfLastField(
+ google::protobuf::Message& message) const {
+ FUZZTEST_INTERNAL_CHECK_PRECONDITION(!field_descriptors_.empty(),
+ "Empty field path!");
+ google::protobuf::Message* parent = &message;
+ for (size_t i = 0; i < field_descriptors_.size() - 1; ++i) {
+ if (field_descriptors_[i]->is_repeated()) {
+ if (parent->GetReflection()->FieldSize(*parent, field_descriptors_[i]) ==
+ 0) {
+ parent->GetReflection()->AddMessage(parent, field_descriptors_[i]);
+ }
+ parent = parent->GetReflection()->MutableRepeatedMessage(
+ parent, field_descriptors_[i], 0);
+ } else {
+ parent = parent->GetReflection()->MutableMessage(parent,
+ field_descriptors_[i]);
+ }
+ }
+ return parent;
+}
+
+const google::protobuf::FieldDescriptor& FieldPath::GetLastField() const {
+ FUZZTEST_INTERNAL_CHECK_PRECONDITION(!field_descriptors_.empty(),
+ "Empty field path!");
+ return *field_descriptors_.back();
+}
+
+const std::vector<const google::protobuf::FieldDescriptor*>& FieldPath::GetAllFields()
+ const {
+ return field_descriptors_;
+}
+
+void CopyField(const FieldPath& from_field, const google::protobuf::Message& from,
+ const FieldPath& to_field, google::protobuf::Message& to) {
+ const google::protobuf::Message* from_inner_most_message =
+ from_field.GetContainingMessageOfLastField(from);
+ if (from_inner_most_message == nullptr) {
+ return;
+ }
+ const google::protobuf::FieldDescriptor& from_last_field = from_field.GetLastField();
+ google::protobuf::Message* to_inner_most_message =
+ to_field.MutableContainingMessageOfLastField(to);
+ if (to_inner_most_message == nullptr) {
+ return;
+ }
+ const google::protobuf::FieldDescriptor& to_last_field = to_field.GetLastField();
+
+ const google::protobuf::Reflection* from_refl =
+ from_inner_most_message->GetReflection();
+ const google::protobuf::Reflection* to_refl = to_inner_most_message->GetReflection();
+ FUZZTEST_INTERNAL_CHECK(from_last_field.type() == to_last_field.type(),
+ "Fields of mismatch types cannot be copied!");
+ // TODO(changochen): We might make this condition optional.
+ FUZZTEST_INTERNAL_CHECK(from_last_field.name() == to_last_field.name(),
+ "Fields of mismatch names cannot be copied!");
+ switch (from_last_field.type()) {
+#define HANDLE_TYPE(UPPERCASE, CAMEL) \
+ case google::protobuf::FieldDescriptor::TYPE_##UPPERCASE: \
+ if (from_last_field.is_repeated() && to_last_field.is_repeated()) { \
+ to_refl->ClearField(to_inner_most_message, &to_last_field); \
+ for (int i = 0; i < from_refl->FieldSize(*from_inner_most_message, \
+ &from_last_field); \
+ ++i) { \
+ to_refl->Add##CAMEL( \
+ to_inner_most_message, &to_last_field, \
+ from_refl->GetRepeated##CAMEL(*from_inner_most_message, \
+ &from_last_field, i)); \
+ } \
+ } else if (from_last_field.is_repeated()) { \
+ if (from_refl->FieldSize(*from_inner_most_message, &from_last_field) == \
+ 0) { \
+ to_refl->ClearField(to_inner_most_message, &to_last_field); \
+ } else { \
+ to_refl->Set##CAMEL( \
+ to_inner_most_message, &to_last_field, \
+ from_refl->GetRepeated##CAMEL(*from_inner_most_message, \
+ &from_last_field, 0)); \
+ } \
+ } else if (to_last_field.is_repeated()) { \
+ to_refl->ClearField(to_inner_most_message, &to_last_field); \
+ to_refl->Add##CAMEL( \
+ to_inner_most_message, &to_last_field, \
+ from_refl->Get##CAMEL(*from_inner_most_message, &from_last_field)); \
+ } else { \
+ to_refl->Set##CAMEL( \
+ to_inner_most_message, &to_last_field, \
+ from_refl->Get##CAMEL(*from_inner_most_message, &from_last_field)); \
+ } \
+ break;
+
+ HANDLE_TYPE(DOUBLE, Double);
+ HANDLE_TYPE(FLOAT, Float);
+ HANDLE_TYPE(INT64, Int64);
+ HANDLE_TYPE(UINT64, UInt64);
+ HANDLE_TYPE(INT32, Int32);
+ HANDLE_TYPE(FIXED64, UInt64);
+ HANDLE_TYPE(FIXED32, UInt32);
+ HANDLE_TYPE(BOOL, Bool);
+ HANDLE_TYPE(STRING, String);
+ HANDLE_TYPE(BYTES, String);
+ HANDLE_TYPE(UINT32, UInt32);
+ HANDLE_TYPE(ENUM, Enum);
+ HANDLE_TYPE(SFIXED64, Int64);
+ HANDLE_TYPE(SFIXED32, Int32);
+ HANDLE_TYPE(SINT64, Int64);
+ HANDLE_TYPE(SINT32, Int32);
+
+#undef HANDLE_TYPE
+ case google::protobuf::FieldDescriptor::TYPE_GROUP:
+ case google::protobuf::FieldDescriptor::TYPE_MESSAGE:
+ if (from_last_field.is_repeated() && to_last_field.is_repeated()) {
+ to_refl->ClearField(to_inner_most_message, &to_last_field);
+ for (int i = 0; i < from_refl->FieldSize(*from_inner_most_message,
+ &from_last_field);
+ ++i) {
+ to_refl->AddMessage(to_inner_most_message, &to_last_field)
+ ->CopyFrom(from_refl->GetRepeatedMessage(*from_inner_most_message,
+ &from_last_field, i));
+ }
+ } else if (from_last_field.is_repeated()) {
+ if (from_refl->FieldSize(*from_inner_most_message, &from_last_field) ==
+ 0) {
+ to_refl->ClearField(to_inner_most_message, &to_last_field);
+ } else {
+ to_refl->MutableMessage(to_inner_most_message, &to_last_field)
+ ->CopyFrom(from_refl->GetRepeatedMessage(*from_inner_most_message,
+ &from_last_field, 0));
+ }
+ } else if (to_last_field.is_repeated()) {
+ to_refl->ClearField(to_inner_most_message, &to_last_field);
+ to_refl->AddMessage(to_inner_most_message, &to_last_field)
+ ->CopyFrom(from_refl->GetMessage(*from_inner_most_message,
+ &from_last_field));
+ } else {
+ to_refl->MutableMessage(to_inner_most_message, &to_last_field)
+ ->CopyFrom(from_refl->GetMessage(*from_inner_most_message,
+ &from_last_field));
+ }
+ break;
+ default:
+ FUZZTEST_INTERNAL_CHECK(
+ false, absl::StrCat("Unexpected type ", from_last_field.type_name()));
+ }
+}
+
+std::vector<FieldPath> CollectAllFieldsImpl(
+ const google::protobuf::Descriptor& message_descriptor,
+ absl::flat_hash_set<const google::protobuf::Descriptor*>& visited_messages) {
+ std::vector<FieldPath> results;
+ if (visited_messages.contains(&message_descriptor)) {
+ return results;
+ }
+ visited_messages.insert(&message_descriptor);
+ for (size_t i = 0; i < message_descriptor.field_count(); ++i) {
+ const google::protobuf::FieldDescriptor& field = *message_descriptor.field(i);
+ FieldPath field_path;
+ field_path.AppendField(field);
+ results.push_back(field_path);
+ // `GROUP` field is a deprecated way of expressing inner message.
+ if (field.type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE ||
+ field.type() == google::protobuf::FieldDescriptor::TYPE_GROUP) {
+ const google::protobuf::Descriptor& inner = *field.message_type();
+ std::vector<FieldPath> inner_fields =
+ CollectAllFieldsImpl(inner, visited_messages);
+ for (const auto& loc : inner_fields) {
+ FieldPath inner_field_path(field_path);
+ inner_field_path.AppendPath(loc);
+ results.push_back(inner_field_path);
+ }
+ }
+ }
+ return results;
+}
+
+std::vector<FieldPath> CollectAllFields(
+ const google::protobuf::Descriptor& message_descriptor) {
+ absl::flat_hash_set<const google::protobuf::Descriptor*> visited_messages;
+ return CollectAllFieldsImpl(message_descriptor, visited_messages);
+}
+
+bool AreDifferentFieldsInSameOneOf(const google::protobuf::FieldDescriptor& a,
+ const google::protobuf::FieldDescriptor& b) {
+ if (&a == &b || a.containing_oneof() == nullptr ||
+ b.containing_oneof() == nullptr)
+ return false;
+ return a.containing_oneof() == b.containing_oneof();
+}
+
+bool AreOneOfAltearnatives(const FieldPath& a, const FieldPath& b) {
+ const std::vector<const google::protobuf::FieldDescriptor*>& a_fields =
+ a.GetAllFields();
+ const std::vector<const google::protobuf::FieldDescriptor*>& b_fields =
+ b.GetAllFields();
+ for (size_t i = 0; i < std::min(a_fields.size(), b_fields.size()); ++i) {
+ const google::protobuf::FieldDescriptor& a_field = *a_fields[i];
+ const google::protobuf::FieldDescriptor& b_field = *b_fields[i];
+ if (AreDifferentFieldsInSameOneOf(a_field, b_field)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/proto_field_path.h b/rpc_fuzzing/proto_field_path.h
new file mode 100644
index 0000000..1042fe6
--- /dev/null
+++ b/rpc_fuzzing/proto_field_path.h
@@ -0,0 +1,120 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FUZZTEST_RPC_FUZZING_PROTO_FIELD_PATH_H_
+#define FUZZTEST_RPC_FUZZING_PROTO_FIELD_PATH_H_
+
+#include <vector>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+
+namespace fuzztest::internal {
+
+// Represent a path to a (sub)field in a message.
+class FieldPath {
+ public:
+ // Append a field as a part to the path.
+ void AppendField(const google::protobuf::FieldDescriptor& field);
+ // Append another path to the path.
+ void AppendPath(const FieldPath& other);
+
+ // Given a message, follow the field path and return the parent message of the
+ // last field. If any repeated field in the middle is empty, return nullptr
+ const google::protobuf::Message* GetContainingMessageOfLastField(
+ const google::protobuf::Message& message) const;
+ // The mutable version fills in empty fields along the path, and thus never
+ // returns a null pointer.
+ google::protobuf::Message* MutableContainingMessageOfLastField(
+ google::protobuf::Message& message) const;
+
+ // Get the last field in the path.
+ const google::protobuf::FieldDescriptor& GetLastField() const;
+ // Get all the parts of the path.
+ const std::vector<const google::protobuf::FieldDescriptor*>& GetAllFields() const;
+
+ // Returns a string concatenating all the field parts by ".". For example, a
+ // field path with two fields "field1" and "field2" will return
+ // "field1.field2".
+ std::string ToString() const;
+
+ private:
+ bool CanAppend(const google::protobuf::FieldDescriptor& field) const;
+
+ // Make FieldPath hashable.
+ template <typename H>
+ friend H AbslHashValue(H h, const FieldPath& m) {
+ h = H::combine_contiguous(std::move(h), m.field_descriptors_.data(),
+ m.field_descriptors_.size());
+ return h;
+ }
+ friend bool operator==(const FieldPath& lhs, const FieldPath& rhs);
+
+ // A list of fields that allow us to locate inner fields.
+ std::vector<const google::protobuf::FieldDescriptor*> field_descriptors_;
+};
+
+bool operator==(const FieldPath& lhs, const FieldPath& rhs);
+
+// Utility function for copying fields in different proto messages, the fields
+// should have the same names and types, otherwise an assertion will be
+// triggered. If the source field is repeated but the sink is not, we copy the
+// first element of the source. If the source field is not repeated but the sink
+// is, we clear the sink and copy the source as the first element of the sink.
+// TODO(changochen): Make the condition of same name optional.
+// TODO(changochen): Select random element of a repeated fields for copying.
+void CopyField(const FieldPath& from_field, const google::protobuf::Message& from,
+ const FieldPath& to_field, google::protobuf::Message& to);
+
+FieldPath GetFieldPathWithDescriptor(const google::protobuf::Descriptor& descriptor,
+ std::string_view field_path_str);
+
+// Get a field path in a message from a path string. The path string should be
+// in format like "field1.subfield2.subfield3.xxx".
+// TODO(changochen): Return absl::StatusOr<FieldPath> instead of aborting. Make
+// this a constructor of field path.
+template <typename MessageT>
+FieldPath GetFieldPath(std::string_view field_path_str) {
+ const google::protobuf::Descriptor* descriptor = MessageT::GetDescriptor();
+ return GetFieldPathWithDescriptor(*descriptor, field_path_str);
+}
+
+// Utility function for collecting all (sub)fields in message type. If the
+// message type is recursive, then we only collect fields in a message when we
+// see its type for the first time.
+std::vector<FieldPath> CollectAllFields(
+ const google::protobuf::Descriptor& message_descriptor);
+
+// Checks whether two fields path are alternatives to each other (i.e., within
+// the same oneof group). If yes, only at most one of them should be set at any
+// given time.
+// For example, the following message: message Message {
+// oneof test_one_of {
+// Message1 a = 1;
+// int b = 2;
+// }
+// }
+//
+// message Message1 {
+// int a1 = 1;
+// int a2 = 2;
+// }
+//
+// We consider that a.a1 and b are within the same oneof, therefore
+// "alternatives", while a.a1 and a.a2 are not.
+bool AreOneOfAltearnatives(const FieldPath& a, const FieldPath& b);
+
+} // namespace fuzztest::internal
+
+#endif // FUZZTEST_RPC_FUZZING_PROTO_FIELD_PATH_H_
diff --git a/rpc_fuzzing/proto_field_path_test.cc b/rpc_fuzzing/proto_field_path_test.cc
new file mode 100644
index 0000000..c00801a
--- /dev/null
+++ b/rpc_fuzzing/proto_field_path_test.cc
@@ -0,0 +1,514 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "./rpc_fuzzing/proto_field_path.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "google/protobuf/descriptor.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "./fuzztest/internal/test_protobuf.pb.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.pb.h"
+
+namespace fuzztest::internal {
+
+using ::testing::ElementsAre;
+using ::testing::IsNull;
+using ::testing::NotNull;
+using ::testing::UnorderedElementsAre;
+
+namespace {
+
+TEST(FieldPathTest, GetFieldPathParseStringIntoAFieldPath) {
+ const FieldPath field_path =
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+ EXPECT_THAT(
+ field_path.GetAllFields(),
+ ElementsAre(
+ LogOutUserRequest::GetDescriptor()->FindFieldByName("log_out_info"),
+ LogOutInfo::GetDescriptor()->FindFieldByName("session_info"),
+ SessionInfo::GetDescriptor()->FindFieldByName("session_id")));
+}
+
+TEST(FieldPathTest, GetFieldPathAbortsAtInvalidField) {
+ EXPECT_DEATH_IF_SUPPORTED(
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.invalid"),
+ "Invalid field name!");
+}
+
+MATCHER_P(FieldPathAsString, path_str, "") {
+ if (arg.GetAllFields().empty()) return path_str == "";
+ std::string result;
+ for (const google::protobuf::FieldDescriptor* part : arg.GetAllFields()) {
+ absl::StrAppend(&result, part->name(), ".");
+ }
+ return result.substr(0, result.size() - 1) == path_str;
+}
+
+TEST(FieldPathTest, HashAndEqualityCheck) {
+ const FieldPath path1 = GetFieldPath<LogInUserResponse>("session_id");
+ const FieldPath path2 = GetFieldPath<LogInUserResponse>("session_id");
+ EXPECT_EQ(path1, path2);
+}
+
+TEST(FieldPathTest, CopyFieldOfSameTypeAndSameNameInMessagesSucceeds) {
+ // Test copying LogInUserResponse.session_id to
+ // LogOutUserRequest.log_out_info.session_info.session_id.
+ LogInUserResponse log_in_response;
+ const FieldPath session_id_source_path =
+ GetFieldPath<LogInUserResponse>("session_id");
+ LogOutUserRequest log_out_request;
+ const FieldPath session_id_sink_path =
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+
+ constexpr int64_t kSessionId = 0xdeadbeef;
+ log_in_response.set_session_id(kSessionId);
+ CopyField(session_id_source_path, log_in_response, session_id_sink_path,
+ log_out_request);
+
+ EXPECT_EQ(kSessionId,
+ log_out_request.log_out_info().session_info().session_id());
+}
+
+TEST(FieldPathTest, CopyingRepeatedFieldToNonRepeatedCopiesTheFirstElement) {
+ // Copy repeated `{val1, val2}` to non repeated `val2` results in `val2`.
+ constexpr int64_t kValue1 = 123;
+ constexpr int64_t kValue2 = 234;
+ MessageContainingRepeatedFields message_containing_repeated_fields;
+ message_containing_repeated_fields.add_field(kValue1);
+ message_containing_repeated_fields.add_field(kValue2);
+ const FieldPath source_path =
+ GetFieldPath<MessageContainingRepeatedFields>("field");
+ MessageNotContainingRepeatedFields message_not_containing_repeated_fields;
+ message_not_containing_repeated_fields.set_field(kValue2);
+ const FieldPath sink_path =
+ GetFieldPath<MessageNotContainingRepeatedFields>("field");
+
+ CopyField(source_path, message_containing_repeated_fields, sink_path,
+ message_not_containing_repeated_fields);
+
+ MessageNotContainingRepeatedFields expected_result;
+ expected_result.set_field(kValue1);
+ EXPECT_EQ(message_not_containing_repeated_fields.DebugString(),
+ expected_result.DebugString());
+}
+
+TEST(FieldPathTest,
+ CopyFieldDoesNothingIfSourceContainsEmptyRepeatedFieldInTheMiddle) {
+ MessageContainingRepeatedFields message_containing_repeated_fields;
+ const FieldPath source_path = GetFieldPath<MessageContainingRepeatedFields>(
+ "repeated_msg_field.field2");
+ MessageNotContainingRepeatedFields message_not_containing_repeated_fields;
+ message_not_containing_repeated_fields.set_field2(123);
+ const MessageNotContainingRepeatedFields expected_result =
+ message_not_containing_repeated_fields;
+
+ const FieldPath sink_path =
+ GetFieldPath<MessageNotContainingRepeatedFields>("field2");
+
+ CopyField(source_path, message_containing_repeated_fields, sink_path,
+ message_not_containing_repeated_fields);
+ EXPECT_EQ(message_not_containing_repeated_fields.DebugString(),
+ expected_result.DebugString());
+}
+
+TEST(FieldPathTest,
+ CopyFieldSucceedsIfSinkContainEmptyRepeatedFieldInTheMiddle) {
+ constexpr int64_t kValue1 = 123;
+ MessageContainingRepeatedFields message_containing_repeated_fields;
+ const FieldPath sink_path = GetFieldPath<MessageContainingRepeatedFields>(
+ "repeated_msg_field.field2");
+ MessageNotContainingRepeatedFields message_not_containing_repeated_fields;
+ message_not_containing_repeated_fields.set_field2(kValue1);
+
+ const FieldPath source_path =
+ GetFieldPath<MessageNotContainingRepeatedFields>("field2");
+
+ CopyField(source_path, message_not_containing_repeated_fields, sink_path,
+ message_containing_repeated_fields);
+ MessageContainingRepeatedFields expected_result;
+ expected_result.add_repeated_msg_field()->set_field2(kValue1);
+ EXPECT_EQ(message_containing_repeated_fields.DebugString(),
+ expected_result.DebugString());
+}
+
+TEST(FieldPathTest, CopyingEmptyRepeatedFieldToNonRepeatedClearsTheField) {
+ // Copy empty repeated `{}` to non repeated `val2` results in cleared field.
+ MessageContainingRepeatedFields message_containing_repeated_fields;
+ const FieldPath source_path =
+ GetFieldPath<MessageContainingRepeatedFields>("field");
+ MessageNotContainingRepeatedFields message_not_containing_repeated_fields;
+ message_not_containing_repeated_fields.set_field(123);
+ const FieldPath sink_path =
+ GetFieldPath<MessageNotContainingRepeatedFields>("field");
+
+ CopyField(source_path, message_containing_repeated_fields, sink_path,
+ message_not_containing_repeated_fields);
+
+ EXPECT_EQ(message_not_containing_repeated_fields.DebugString(),
+ MessageNotContainingRepeatedFields{}.DebugString());
+}
+
+TEST(FieldPathTest, CopyingNonRepeatedFieldToRepeatedFieldSetsFirstElement) {
+ // Copy non repeated `val2` to repeated `{val1, val2}` results in `{val2}`.
+ constexpr int64_t kValue1 = 123;
+ constexpr int64_t kValue2 = 234;
+ MessageNotContainingRepeatedFields message_not_containing_repeated_fields;
+ message_not_containing_repeated_fields.set_field(kValue2);
+ const FieldPath source_path =
+ GetFieldPath<MessageNotContainingRepeatedFields>("field");
+
+ MessageContainingRepeatedFields message_containing_repeated_fields;
+ message_containing_repeated_fields.add_field(kValue1);
+ message_containing_repeated_fields.add_field(kValue2);
+ const FieldPath sink_path =
+ GetFieldPath<MessageContainingRepeatedFields>("field");
+
+ CopyField(source_path, message_not_containing_repeated_fields, sink_path,
+ message_containing_repeated_fields);
+
+ MessageContainingRepeatedFields expected_result;
+ expected_result.add_field(kValue2);
+ EXPECT_EQ(message_containing_repeated_fields.DebugString(),
+ expected_result.DebugString());
+}
+
+TEST(FieldPathTest, CopyingRepeatedFieldHandlesFieldOfMessageTypeCorrectly) {
+ MessageContainingRepeatedFields source_message, sink_message;
+ source_message.add_repeated_msg_field()->set_field2(0x123);
+ source_message.add_repeated_msg_field()->set_field2(0x234);
+ const FieldPath field_path =
+ GetFieldPath<MessageContainingRepeatedFields>("repeated_msg_field");
+
+ CopyField(field_path, source_message, field_path, sink_message);
+
+ EXPECT_EQ(sink_message.DebugString(), source_message.DebugString());
+}
+
+TEST(FieldPathTest,
+ CopyingSIngularSrcToRepeatedDstHandlesFieldOfMessageTypeCorrectly) {
+ MessageWithSingleInnerRepeated source_message;
+ source_message.mutable_repeated_msg_field()->set_field2(123);
+ MessageContainingRepeatedFields sink_message;
+ sink_message.add_repeated_msg_field()->set_field2(1);
+ sink_message.add_repeated_msg_field()->set_field2(1);
+ const FieldPath sink_field_path =
+ GetFieldPath<MessageContainingRepeatedFields>("repeated_msg_field");
+ const FieldPath source_field_path =
+ GetFieldPath<MessageWithSingleInnerRepeated>("repeated_msg_field");
+
+ CopyField(source_field_path, source_message, sink_field_path, sink_message);
+
+ MessageContainingRepeatedFields expected_result;
+ expected_result.add_repeated_msg_field()->set_field2(123);
+
+ EXPECT_EQ(sink_message.DebugString(), expected_result.DebugString());
+}
+
+TEST(FieldPathTest, CopyingRepeatedFieldToRepeatedFieldCopiesTheWholeVector) {
+ // Copy repeated `{val1, val2}` to repeated `{}` results in `{val1, val2}`.
+ constexpr int64_t kValue1 = 123;
+ constexpr int64_t kValue2 = 234;
+
+ MessageContainingRepeatedFields source, sink;
+ source.add_field(kValue1);
+ source.add_field(kValue2);
+
+ const FieldPath source_path =
+ GetFieldPath<MessageContainingRepeatedFields>("field");
+ const FieldPath sink_path =
+ GetFieldPath<MessageContainingRepeatedFields>("field");
+
+ CopyField(source_path, source, sink_path, sink);
+
+ MessageContainingRepeatedFields expected_result;
+ expected_result.add_field(kValue1);
+ expected_result.add_field(kValue2);
+ EXPECT_EQ(sink.DebugString(), expected_result.DebugString());
+}
+
+TEST(FieldPathTest, CopyFieldOfSameTypeAndDifferentNameInMessagesFails) {
+ // Test copying GetUserPostsOptions.order to
+ // GetUserPostsRequest.options.
+ GetUserPostsOptions options;
+ const FieldPath order_source_path =
+ GetFieldPath<GetUserPostsOptions>("order");
+ GetUserPostsOptionsWithDifferentFieldNames options_with_different_field_names;
+ const FieldPath ordering_sink_path =
+ GetFieldPath<GetUserPostsOptionsWithDifferentFieldNames>("ordering");
+
+ EXPECT_DEATH_IF_SUPPORTED(
+ CopyField(order_source_path, options, ordering_sink_path,
+ options_with_different_field_names),
+ "Fields of mismatch names cannot be copied!");
+}
+
+TEST(FieldPathTest, CopyFieldOfDifferentTypeInMessagesFails) {
+ // Test copying LogInUserResponse.session_id to
+ // GetUserPostsRequest.options.
+ const LogInUserResponse log_in_response;
+ const FieldPath session_id_source_path =
+ GetFieldPath<LogInUserResponse>("session_id");
+ GetUserPostsRequest get_user_posts_request;
+ const FieldPath session_id_sink_path =
+ GetFieldPath<GetUserPostsRequest>("options");
+
+ EXPECT_DEATH_IF_SUPPORTED(
+ CopyField(session_id_source_path, log_in_response, session_id_sink_path,
+ get_user_posts_request),
+ "Fields of mismatch types cannot be copied!");
+}
+
+TEST(FieldPathTest, CopyFieldHandlesEnumCorrectly) {
+ GetUserPostsOptions source_options, sink_options;
+ const FieldPath field_path = GetFieldPath<GetUserPostsOptions>("order");
+ source_options.set_order(SortingOrder::ASCENDING);
+
+ CopyField(field_path, source_options, field_path, sink_options);
+
+ EXPECT_EQ(sink_options.order(), SortingOrder::ASCENDING);
+}
+
+TEST(FieldPathTest, GetContainingMessageOfLastFieldReturnsContainingMessage) {
+ const LogOutUserRequest log_out_request;
+ const FieldPath session_id_sink_path =
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+
+ const LogOutInfo& log_out_info = log_out_request.log_out_info();
+ const SessionInfo& session_info = log_out_info.session_info();
+
+ const google::protobuf::Message* parent_msg =
+ session_id_sink_path.GetContainingMessageOfLastField(log_out_request);
+
+ EXPECT_EQ(parent_msg, &session_info);
+}
+
+TEST(FieldPathTest,
+ GetContainingMessageReturnsNonNullForUnsetMiddleSingularField) {
+ LogOutUserRequest log_out_request;
+ const FieldPath session_id_sink_path =
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+
+ EXPECT_THAT(
+ session_id_sink_path.GetContainingMessageOfLastField(log_out_request),
+ NotNull());
+}
+
+TEST(FieldPathTest,
+ GetContainingMessageReturnsNullForMissingMiddleRepeatedField) {
+ MessageContainingRepeatedFields message_containing_repeated_fields;
+ const FieldPath message_containing_repeated_fields_sink_path =
+ GetFieldPath<MessageContainingRepeatedFields>(
+ "repeated_msg_field.field2");
+ EXPECT_THAT(
+ message_containing_repeated_fields_sink_path
+ .GetContainingMessageOfLastField(message_containing_repeated_fields),
+ IsNull());
+}
+
+TEST(FieldPathTest,
+ MutableContainingMessageAddMissingRepeatedFieldAndReturnsNonNull) {
+ MessageContainingRepeatedFields message_containing_repeated_fields;
+ const FieldPath message_containing_repeated_fields_sink_path =
+ GetFieldPath<MessageContainingRepeatedFields>(
+ "repeated_msg_field.field2");
+ EXPECT_THAT(message_containing_repeated_fields_sink_path
+ .MutableContainingMessageOfLastField(
+ message_containing_repeated_fields),
+ NotNull());
+}
+
+TEST(FieldPathTest,
+ MutableContainingMessageAtNonEmptyRepeatedFieldWillNotAddField) {
+ MessageContainingRepeatedFields message_containing_repeated_fields;
+ message_containing_repeated_fields.add_repeated_msg_field()->set_field2(
+ 0x123);
+ const FieldPath message_containing_repeated_fields_sink_path =
+ GetFieldPath<MessageContainingRepeatedFields>(
+ "repeated_msg_field.field2");
+ (void)message_containing_repeated_fields_sink_path
+ .MutableContainingMessageOfLastField(message_containing_repeated_fields);
+ EXPECT_EQ(
+ message_containing_repeated_fields.GetReflection()->FieldSize(
+ message_containing_repeated_fields,
+ message_containing_repeated_fields.GetDescriptor()->FindFieldByName(
+ "repeated_msg_field")),
+ 1);
+}
+
+TEST(FieldPathTest, MutableContainingMessageAtEmptyRepeatedFieldWillAddField) {
+ MessageContainingRepeatedFields message_containing_repeated_fields;
+ const FieldPath message_containing_repeated_fields_sink_path =
+ GetFieldPath<MessageContainingRepeatedFields>(
+ "repeated_msg_field.field2");
+ (void)message_containing_repeated_fields_sink_path
+ .MutableContainingMessageOfLastField(message_containing_repeated_fields);
+ EXPECT_EQ(
+ message_containing_repeated_fields.GetReflection()->FieldSize(
+ message_containing_repeated_fields,
+ message_containing_repeated_fields.GetDescriptor()->FindFieldByName(
+ "repeated_msg_field")),
+ 1);
+}
+
+TEST(FieldPathTest,
+ GetContainingMessageOfLastFieldReturnsFirstIndexOfRepeatedFields) {
+ MessageContainingRepeatedFields message;
+ message.add_repeated_msg_field();
+ message.add_repeated_msg_field();
+
+ FieldPath field_path = GetFieldPath<MessageContainingRepeatedFields>(
+ "repeated_msg_field.field2");
+
+ const google::protobuf::Message* parent_msg =
+ field_path.GetContainingMessageOfLastField(message);
+ EXPECT_EQ(parent_msg, &message.repeated_msg_field(0));
+}
+
+TEST(FieldPathTest, MutableContainingMessageOfLastFieldReturnsMutableMessage) {
+ LogOutUserRequest log_out_request;
+ const FieldPath session_id_sink_path =
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+
+ const LogOutInfo& log_out_info = log_out_request.log_out_info();
+ const SessionInfo& session_info = log_out_info.session_info();
+
+ google::protobuf::Message* parent_msg =
+ session_id_sink_path.MutableContainingMessageOfLastField(log_out_request);
+ // MutableContainingMessageOfLastField might create new message so the address
+ // will be different.
+ EXPECT_NE(parent_msg, &session_info);
+}
+
+TEST(FieldPathTest, GetLastFieldReturnsLastFieldOfThePath) {
+ const LogOutUserRequest log_out_request;
+ const FieldPath session_id_sink_path =
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+
+ EXPECT_EQ(&session_id_sink_path.GetLastField(),
+ SessionInfo::GetDescriptor()->FindFieldByName("session_id"));
+}
+
+TEST(FieldPathTest,
+ AppendFieldSuccessIfTheLastFieldIsTheAppendedFieldParentType) {
+ const FieldPath ground_truth =
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info");
+ FieldPath field_path = GetFieldPath<LogOutUserRequest>("log_out_info");
+ field_path.AppendField(
+ *LogOutInfo::GetDescriptor()->FindFieldByName("session_info"));
+ EXPECT_EQ(ground_truth, field_path);
+}
+
+TEST(FieldPathTest,
+ AppendFieldFailsIfTheLastFieldIsNotTheAppendedFieldParentType) {
+ FieldPath field_path = GetFieldPath<LogInUserResponse>("success");
+
+ EXPECT_DEATH_IF_SUPPORTED(
+ field_path.AppendField(field_path.GetLastField()),
+ "The current last field in the path must be a message "
+ "containing the appended field.");
+}
+
+TEST(FieldPathTest,
+ AppendPathSuccessIfLastFieldOfAppendedPathIsContainedByTheCurrentPath) {
+ const FieldPath ground_truth =
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+ FieldPath field_path1 = GetFieldPath<LogOutUserRequest>("log_out_info");
+ const FieldPath field_path2 =
+ GetFieldPath<LogOutInfo>("session_info.session_id");
+ field_path1.AppendPath(field_path2);
+ EXPECT_EQ(field_path1, ground_truth);
+}
+
+TEST(FieldPathTest,
+ AppendPathFailsIfLastFieldOfAppendedPathIsNotContainedByTheCurrentPath) {
+ FieldPath field_path1 = GetFieldPath<LogOutUserRequest>("log_out_info");
+ const FieldPath field_path2 = GetFieldPath<SessionInfo>("session_id");
+
+ EXPECT_DEATH_IF_SUPPORTED(
+ field_path1.AppendPath(field_path2),
+ "The current last field in the path must be a message "
+ "containing the first field in the other path.");
+}
+
+TEST(FieldPathTest, CollectAllFieldsReturnAllInnerFields) {
+ std::vector<FieldPath> all_fields =
+ CollectAllFields(*LogOutUserRequest::descriptor());
+
+ EXPECT_THAT(all_fields,
+ UnorderedElementsAre(
+ FieldPathAsString("log_out_info"),
+ FieldPathAsString("log_out_info.session_info"),
+ FieldPathAsString("log_out_info.session_id"),
+ FieldPathAsString("log_out_info.session_info.session_id")));
+}
+
+TEST(FieldPathTest, CollectAllFieldsSupportsGroupField) {
+ std::vector<FieldPath> all_fields =
+ CollectAllFields(*MessageWithGroup::descriptor());
+
+ EXPECT_THAT(all_fields,
+ UnorderedElementsAre(FieldPathAsString("groupfield"),
+ FieldPathAsString("groupfield.field1"),
+ FieldPathAsString("groupfield.field2")));
+}
+
+TEST(FieldPathTest, CopyFieldsSupportsGroupField) {
+ MessageWithGroup message_with_group1, message_with_group2;
+
+ message_with_group1.mutable_groupfield()->set_field1(123);
+ CopyField(GetFieldPath<MessageWithGroup>("groupfield"), message_with_group1,
+ GetFieldPath<MessageWithGroup>("groupfield"), message_with_group2);
+ EXPECT_EQ(message_with_group2.groupfield().field1(), 123);
+}
+
+TEST(FieldPathTest, CollectAllFieldsInRecursiveMessageSkipVisitedMessages) {
+ std::vector<FieldPath> all_fields =
+ CollectAllFields(*RecursiveNode::descriptor());
+ EXPECT_THAT(all_fields, ElementsAre(FieldPathAsString("value"),
+ FieldPathAsString("children")));
+}
+
+TEST(FieldPathTest, ToStringConcatAllFieldPartWithDot) {
+ const FieldPath field_path =
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+ EXPECT_EQ(field_path.ToString(), "log_out_info.session_info.session_id");
+}
+
+TEST(FieldPathTest, AlternativeToReturnsTrueForFieldsWithinTheSameOneOf) {
+ EXPECT_TRUE(
+ AreOneOfAltearnatives(GetFieldPath<OneOfMessage>("oneof1_field1.v1"),
+ GetFieldPath<OneOfMessage>("oneof1_field2")));
+ EXPECT_TRUE(
+ AreOneOfAltearnatives(GetFieldPath<OneOfMessage>("oneof1_field1.v2"),
+ GetFieldPath<OneOfMessage>("oneof1_field2")));
+}
+
+TEST(FieldPathTest, AlternativeToReturnsFalseForFieldsNotWithinTheSameOneOf) {
+ EXPECT_FALSE(
+ AreOneOfAltearnatives(GetFieldPath<OneOfMessage>("oneof1_field1"),
+ GetFieldPath<OneOfMessage>("oneof1_field1.v1")));
+ EXPECT_FALSE(
+ AreOneOfAltearnatives(GetFieldPath<OneOfMessage>("oneof1_field1.v1"),
+ GetFieldPath<OneOfMessage>("oneof1_field1.v1")));
+ EXPECT_FALSE(
+ AreOneOfAltearnatives(GetFieldPath<OneOfMessage>("oneof1_field2"),
+ GetFieldPath<OneOfMessage>("oneof2_field2")));
+}
+
+} // namespace
+
+} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_executor.cc b/rpc_fuzzing/rpc_executor.cc
new file mode 100644
index 0000000..c00dc2d
--- /dev/null
+++ b/rpc_fuzzing/rpc_executor.cc
@@ -0,0 +1,58 @@
+#include "./rpc_fuzzing/rpc_executor.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "./fuzztest/internal/logging.h"
+
+namespace fuzztest::internal {
+
+absl::Status RpcExecutor::Execute(RpcSequence& sequence) {
+ responses_.clear();
+ responses_.reserve(sequence.size());
+ absl::Status status;
+ for (RpcNode& node : sequence) {
+ status = ExecuteOne(node);
+ if (!status.ok()) {
+ return status;
+ }
+ }
+ return status;
+}
+
+const google::protobuf::Message& RpcExecutor::GetRpcNodeResponse(RpcNodeID node_id) {
+ FUZZTEST_INTERNAL_CHECK(node_id < responses_.size(),
+ "The dependency source doesn't exist!");
+ return *responses_[node_id];
+}
+
+std::vector<const google::protobuf::Message*> RpcExecutor::GetResponses() const {
+ std::vector<const google::protobuf::Message*> responses;
+ responses.reserve(responses_.size());
+ for (const std::unique_ptr<google::protobuf::Message>& response : responses_) {
+ responses.push_back(response.get());
+ }
+ return responses;
+}
+
+absl::Status RpcExecutor::ExecuteOne(RpcNode& node) {
+ for (const RpcDataFlowEdge& dep : node.dependencies()) {
+ // Assign value to dynamic fields.
+ const google::protobuf::Message& save_response = GetRpcNodeResponse(dep.from_node_id);
+ CopyField(dep.from_field, save_response, dep.to_field, node.request());
+ }
+
+ absl::StatusOr<std::unique_ptr<google::protobuf::Message>> response =
+ stub_->CallMethod(node.method(), node.request());
+ if (!response.ok()) return std::move(response).status();
+ responses_.push_back(*std::move(response));
+ return absl::OkStatus();
+}
+
+} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_executor.h b/rpc_fuzzing/rpc_executor.h
new file mode 100644
index 0000000..31463ab
--- /dev/null
+++ b/rpc_fuzzing/rpc_executor.h
@@ -0,0 +1,38 @@
+#ifndef FUZZTEST_RPC_FUZZING_RPC_EXECUTOR_H_
+#define FUZZTEST_RPC_FUZZING_RPC_EXECUTOR_H_
+
+#include <memory>
+#include <vector>
+
+#include "google/protobuf/message.h"
+#include "absl/log/die_if_null.h"
+#include "./rpc_fuzzing/rpc_sequence.h"
+#include "./rpc_fuzzing/rpc_stub.h"
+namespace fuzztest::internal {
+
+// The RpcExecutor fills the dynamic values from the previous responses
+// according to the dependency in the RpcSequence and sends modified requests to
+// the fuzzed service.
+class RpcExecutor {
+ public:
+ RpcExecutor(RpcStub* stub) : stub_(ABSL_DIE_IF_NULL(stub)) {}
+ // Send the rpc calls specified in `sequence`. The requests of `sequence` will
+ // be modified according to the dependency. Every time this method is called,
+ // the previous responses will be cleared.
+ absl::Status Execute(RpcSequence& sequence);
+ // Get the responses of the last executed sequence.
+ std::vector<const google::protobuf::Message*> GetResponses() const;
+
+ private:
+ // Fill the dynamic fields in the request, execute a single RPC call, saves
+ // and returns the status of the rpc.
+ absl::Status ExecuteOne(RpcNode& node);
+ const google::protobuf::Message& GetRpcNodeResponse(RpcNodeID node_id);
+
+ std::vector<std::unique_ptr<google::protobuf::Message>> responses_;
+ RpcStub* stub_;
+};
+
+} // namespace fuzztest::internal
+
+#endif // FUZZTEST_RPC_FUZZING_RPC_EXECUTOR_H_
diff --git a/rpc_fuzzing/rpc_executor_test.cc b/rpc_fuzzing/rpc_executor_test.cc
new file mode 100644
index 0000000..178daa2
--- /dev/null
+++ b/rpc_fuzzing/rpc_executor_test.cc
@@ -0,0 +1,86 @@
+#include "./rpc_fuzzing/rpc_executor.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/log/check.h"
+#include "absl/strings/string_view.h"
+#include "./domain_tests/domain_testing.h"
+#include "./rpc_fuzzing/grpc_stub.h"
+#include "./rpc_fuzzing/proto_field_path.h"
+#include "./rpc_fuzzing/rpc_sequence.h"
+#include "./rpc_fuzzing/testdata/grpc/mini_blogger_service.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.pb.h"
+#include "grpcpp//security/server_credentials.h"
+#include "grpcpp//server.h"
+#include "grpcpp//server_builder.h"
+#include "grpcpp//support/channel_arguments.h"
+
+namespace fuzztest::internal {
+
+namespace {
+
+using ::testing::AllOf;
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::IsTrue;
+using ::testing::Ne;
+using ::testing::NotNull;
+using ::testing::Pointee;
+using ::testing::ResultOf;
+
+const google::protobuf::FieldDescriptor* GetField(const google::protobuf::Message& message,
+ absl::string_view field_name) {
+ return message.GetDescriptor()->FindFieldByName(field_name);
+}
+
+class RpcExecutorGrpcTest : public ::testing::Test {
+ protected:
+ RpcExecutorGrpcTest()
+ : server_(grpc::ServerBuilder()
+ .RegisterService(&mini_blogger_service_)
+ .BuildAndStart()),
+ stub_(server_->InProcessChannel(grpc::ChannelArguments())) {}
+
+ ~RpcExecutorGrpcTest() override { server_->Shutdown(); }
+
+ RpcNode GetRegisterUserRpcNode(absl::string_view user_name,
+ absl::string_view email,
+ absl::string_view password) {
+ auto request = std::make_unique<RegisterUserRequest>();
+ request->set_user_name(user_name);
+ request->set_password(password);
+ request->set_email(password);
+ const google::protobuf::MethodDescriptor* method =
+ google::protobuf::DescriptorPool::generated_pool()->FindMethodByName(
+ "fuzztest.internal.MiniBlogger.RegisterUser");
+ CHECK(method != nullptr);
+ return RpcNode(*method, std::move(request));
+ }
+
+ MiniBloggerGrpcService mini_blogger_service_;
+ std::unique_ptr<grpc::Server> server_ = nullptr;
+ GrpcStub stub_;
+};
+
+TEST_F(RpcExecutorGrpcTest, ExecutorCorrectlyExecutesRpcCalls) {
+ RpcNode node = GetRegisterUserRpcNode("test", "test_email@gmail.com", "123");
+ RpcSequence sequence = {node};
+ RpcExecutor executor(&stub_);
+ ASSERT_OK(executor.Execute(sequence));
+ EXPECT_THAT(executor.GetResponses(),
+ ElementsAre(Pointee(ResultOf(
+ [](const google::protobuf::Message& response) {
+ return response.GetReflection()->GetBool(
+ response, GetField(response, "success"));
+ },
+ IsTrue()))));
+}
+
+} // namespace
+} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_potential_dfg.cc b/rpc_fuzzing/rpc_potential_dfg.cc
new file mode 100644
index 0000000..ed9b04b
--- /dev/null
+++ b/rpc_fuzzing/rpc_potential_dfg.cc
@@ -0,0 +1,155 @@
+#include "./rpc_fuzzing/rpc_potential_dfg.h"
+
+#include <vector>
+
+#include "google/protobuf/descriptor.h"
+#include "./fuzztest/internal/logging.h"
+#include "./rpc_fuzzing/proto_field_path.h"
+
+namespace fuzztest::internal {
+
+namespace {
+
+// Get the type name of a field. If it is a message type, then return its
+// concrete message type name.
+std::string_view GetTypeNameForField(const google::protobuf::FieldDescriptor& field) {
+ if (field.type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) {
+ return field.message_type()->name();
+ } else {
+ return field.type_name();
+ }
+}
+
+// Describe a field path in a specific method's request or response.
+struct DetailedFieldInfo {
+ // Use pointer to keep it copy assignable.
+ const google::protobuf::MethodDescriptor* method;
+ bool in_request;
+ FieldPath field_path;
+};
+
+// Collect either all the fields in the response/request of a method.
+std::vector<DetailedFieldInfo> CollectFieldInfoInMethod(
+ const google::protobuf::MethodDescriptor& method, bool collect_request) {
+ const google::protobuf::Descriptor& message_descriptor =
+ collect_request ? *method.input_type() : *method.output_type();
+ std::vector<FieldPath> all_field_paths = CollectAllFields(message_descriptor);
+ std::vector<DetailedFieldInfo> result;
+ result.reserve(all_field_paths.size());
+ for (const auto& field_path : all_field_paths) {
+ result.push_back(DetailedFieldInfo{&method, collect_request, field_path});
+ }
+ return result;
+}
+
+std::vector<DetailedFieldInfo> CollectDefinitions(
+ const google::protobuf::MethodDescriptor& method) {
+ return CollectFieldInfoInMethod(method, /*collect_request=*/false);
+}
+
+std::vector<DetailedFieldInfo> CollectUses(
+ const google::protobuf::MethodDescriptor& method) {
+ return CollectFieldInfoInMethod(method, /*collect_request=*/true);
+}
+
+bool InSameMethod(const DetailedFieldInfo& a, const DetailedFieldInfo& b) {
+ return a.method->full_name() == b.method->full_name();
+}
+
+bool HasSameNameAndType(const DetailedFieldInfo& a,
+ const DetailedFieldInfo& b) {
+ const google::protobuf::FieldDescriptor& a_field = a.field_path.GetLastField();
+ const google::protobuf::FieldDescriptor& b_field = b.field_path.GetLastField();
+ return a_field.name() == b_field.name() &&
+ GetTypeNameForField(a_field) == GetTypeNameForField(b_field);
+}
+
+} // namespace
+
+void RpcPotentialDfgNode::AddDependency(
+ const google::protobuf::MethodDescriptor& source_method,
+ const FieldPath& source_field, const FieldPath& sink_field) {
+ dependencies_[sink_field].push_back(
+ PotentialDependencySource{&source_method, source_field});
+}
+
+void RpcPotentialDataFlowGraph::AddDependency(
+ const google::protobuf::MethodDescriptor& source_method,
+ const FieldPath& source_field, const google::protobuf::MethodDescriptor& sink_method,
+ const FieldPath& sink_field) {
+ FUZZTEST_INTERNAL_CHECK_PRECONDITION(
+ nodes_.find(&source_method) != nodes_.end(), "No such source method!");
+ FUZZTEST_INTERNAL_CHECK_PRECONDITION(
+ nodes_.find(&sink_method) != nodes_.end(), "No such sink method!");
+ RpcPotentialDfgNode& sink_node = nodes_.find(&sink_method)->second;
+ sink_node.AddDependency(source_method, source_field, sink_field);
+}
+
+RpcPotentialDataFlowGraph::RpcPotentialDataFlowGraph(
+ const google::protobuf::ServiceDescriptor& service) {
+ for (int i = 0; i < service.method_count(); ++i) {
+ const google::protobuf::MethodDescriptor& method = *service.method(i);
+ nodes_.emplace(&method, RpcPotentialDfgNode(method));
+ }
+}
+
+RpcPotentialDataFlowGraph RpcPotentialDataFlowGraph::Create(
+ const google::protobuf::ServiceDescriptor& service_descriptor) {
+ // Collect all the definitions and uses.
+ std::vector<DetailedFieldInfo> all_definitions, all_uses;
+ for (int i = 0; i < service_descriptor.method_count(); ++i) {
+ const google::protobuf::MethodDescriptor& method = *service_descriptor.method(i);
+ std::vector<DetailedFieldInfo> definitions = CollectDefinitions(method);
+ std::vector<DetailedFieldInfo> uses = CollectUses(method);
+ all_definitions.insert(all_definitions.end(), definitions.begin(),
+ definitions.end());
+ all_uses.insert(all_uses.end(), uses.begin(), uses.end());
+ }
+
+ // For each pair of definition and use with same name and type, add a
+ // dependency (edge) in the graph.
+ RpcPotentialDataFlowGraph graph(service_descriptor);
+ for (const DetailedFieldInfo& define_field_info : all_definitions) {
+ for (const DetailedFieldInfo& use_field_info : all_uses) {
+ if (!InSameMethod(define_field_info, use_field_info) &&
+ HasSameNameAndType(define_field_info, use_field_info)) {
+ graph.AddDependency(*define_field_info.method,
+ define_field_info.field_path,
+ *use_field_info.method, use_field_info.field_path);
+ }
+ }
+ }
+ return graph;
+}
+
+const std::vector<RpcPotentialDfgNode::PotentialDependencySource>&
+RpcPotentialDfgNode::GetDependencies(const FieldPath& sink) const {
+ FUZZTEST_INTERNAL_CHECK_PRECONDITION(
+ dependencies_.find(sink) != dependencies_.end(), "No such sink!");
+ return dependencies_.at(sink);
+}
+
+const RpcPotentialDfgNode& RpcPotentialDataFlowGraph::GetNode(
+ const google::protobuf::MethodDescriptor& method) const {
+ FUZZTEST_INTERNAL_CHECK_PRECONDITION(nodes_.find(&method) != nodes_.end(),
+ "No such method!");
+ return nodes_.at(&method);
+}
+
+const google::protobuf::ServiceDescriptor& GetServiceDescriptorByServiceName(
+ absl::string_view service_name) {
+ const google::protobuf::ServiceDescriptor* service =
+ google::protobuf::DescriptorPool::generated_pool()->FindServiceByName(
+ service_name.data());
+ FUZZTEST_INTERNAL_CHECK(service != nullptr,
+ std::string(service_name) + " Service not found!");
+ return *service;
+}
+
+RpcPotentialDataFlowGraph CreatePotentialDfgByServiceName(
+ absl::string_view service_name) {
+ return RpcPotentialDataFlowGraph::Create(
+ GetServiceDescriptorByServiceName(service_name));
+}
+
+} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_potential_dfg.h b/rpc_fuzzing/rpc_potential_dfg.h
new file mode 100644
index 0000000..ace3f37
--- /dev/null
+++ b/rpc_fuzzing/rpc_potential_dfg.h
@@ -0,0 +1,129 @@
+#ifndef FUZZTEST_RPC_FUZZING_RPC_POTENTIAL_DFG_H_
+#define FUZZTEST_RPC_FUZZING_RPC_POTENTIAL_DFG_H_
+
+#include <type_traits>
+#include <vector>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/service.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "./fuzztest/internal/logging.h"
+#include "./rpc_fuzzing/proto_field_path.h"
+
+namespace fuzztest::internal {
+
+// A `RpcPotentialDfgNode` describes the potential dependencies for the
+// fields in a request of a rpc method.
+class RpcPotentialDfgNode {
+ public:
+ RpcPotentialDfgNode(const google::protobuf::MethodDescriptor& method)
+ : method_(method) {}
+
+ // Establish a dependency from the `source_field` to the `sink_field` of
+ // current method.
+ void AddDependency(const google::protobuf::MethodDescriptor& source_method,
+ const FieldPath& source_field,
+ const FieldPath& sink_field);
+
+ // Get the method representing by the node.
+ const google::protobuf::MethodDescriptor& GetMethod() const { return method_; }
+
+ // Source of dependency, which should be in the response of a method.
+ struct PotentialDependencySource {
+ // `method` will never be null. Use a pointer to keep it copy assignable.
+ const google::protobuf::MethodDescriptor* method;
+ FieldPath field_path;
+ };
+
+ const absl::flat_hash_map<FieldPath, std::vector<PotentialDependencySource>>&
+ GetAllDependencies() const {
+ return dependencies_;
+ }
+
+ const std::vector<PotentialDependencySource>& GetDependencies(
+ const FieldPath& sink) const;
+ bool HasDependency() const { return !dependencies_.empty(); }
+
+ private:
+ const google::protobuf::MethodDescriptor& method_;
+ absl::flat_hash_map<FieldPath /*sink*/,
+ std::vector<PotentialDependencySource> /*source*/>
+ dependencies_;
+};
+
+// A `RpcPotentialDataFlowGraph` describes all the potential dependencies among
+// methods in a service. Each method is represented as a node in the graph.
+class RpcPotentialDataFlowGraph {
+ public:
+ // Get the node representing `method`.
+ const RpcPotentialDfgNode& GetNode(
+ const google::protobuf::MethodDescriptor& method) const;
+
+ // Create an `RpcPotentialDfg` with each method in the service being
+ // a node, and add dependencies between the methods.
+ static RpcPotentialDataFlowGraph Create(
+ const google::protobuf::ServiceDescriptor& service);
+
+ private:
+ RpcPotentialDataFlowGraph(const google::protobuf::ServiceDescriptor& service);
+ // Add a dependency on the field from the response of `source_method` to the
+ // request of `sink_method`.
+ void AddDependency(const google::protobuf::MethodDescriptor& source_method,
+ const FieldPath& source_field,
+ const google::protobuf::MethodDescriptor& sink_method,
+ const FieldPath& sink_field);
+
+ absl::flat_hash_map<const google::protobuf::MethodDescriptor*, RpcPotentialDfgNode>
+ nodes_;
+};
+
+template <typename T, typename = void>
+struct is_stubby_service : std::false_type {};
+
+template <typename ServiceT>
+struct is_stubby_service<ServiceT, std::void_t<decltype(ServiceT::descriptor)>>
+ : std::true_type {};
+
+template <typename T, typename = void>
+struct is_grpc_service : std::false_type {};
+
+template <typename ServiceT>
+struct is_grpc_service<ServiceT,
+ std::void_t<decltype(ServiceT::service_full_name)>>
+ : std::true_type {};
+
+// Get the service descriptor based on the service type. Currently support
+// protobuf services and stubby services.
+template <typename ServiceT>
+const google::protobuf::ServiceDescriptor& GetServiceDescriptor() {
+ if constexpr (std::is_same_v<ServiceT, google::protobuf::Service>) {
+ return *ServiceT::descriptor();
+ } else if constexpr (is_stubby_service<ServiceT>::value) {
+ return *google::protobuf::DescriptorPool::generated_pool()->FindServiceByName(
+ ServiceT::descriptor.full_name());
+ } else if constexpr (is_grpc_service<ServiceT>::value) {
+ return *google::protobuf::DescriptorPool::generated_pool()->FindServiceByName(
+ ServiceT::service_full_name());
+ } else {
+ FUZZTEST_INTERNAL_CHECK(false, "Unsupported service type!");
+ }
+}
+
+// Get the service descriptor based on the fully qualified service name such as
+// "package.ServiceName".
+const google::protobuf::ServiceDescriptor& GetServiceDescriptorByServiceName(
+ absl::string_view service_name);
+
+// Create an potential data flow graph for a service type.
+template <typename ServiceT>
+RpcPotentialDataFlowGraph CreatePotentialDfg() {
+ return RpcPotentialDataFlowGraph::Create(GetServiceDescriptor<ServiceT>());
+}
+
+RpcPotentialDataFlowGraph CreatePotentialDfgByServiceName(
+ absl::string_view service_name);
+
+} // namespace fuzztest::internal
+
+#endif // FUZZTEST_RPC_FUZZING_RPC_POTENTIAL_DFG_H_
diff --git a/rpc_fuzzing/rpc_potential_dfg_test.cc b/rpc_fuzzing/rpc_potential_dfg_test.cc
new file mode 100644
index 0000000..c5c54a1
--- /dev/null
+++ b/rpc_fuzzing/rpc_potential_dfg_test.cc
@@ -0,0 +1,104 @@
+#include "./rpc_fuzzing/rpc_potential_dfg.h"
+
+#include <vector>
+
+#include "google/protobuf/descriptor.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "./rpc_fuzzing/proto_field_path.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.pb.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.grpc.pb.h"
+
+namespace fuzztest::internal {
+
+namespace {
+
+using ::testing::FieldsAre;
+using ::testing::UnorderedElementsAre;
+
+class RpcPotentialDfgTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ const google::protobuf::DescriptorPool* pool =
+ google::protobuf::DescriptorPool::generated_pool();
+ mini_blogger_serivce_ =
+ pool->FindServiceByName("fuzztest.internal.MiniBlogger");
+ ASSERT_TRUE(mini_blogger_serivce_ != nullptr);
+ log_out_user_method_ =
+ mini_blogger_serivce_->FindMethodByName("LogOutUser");
+ ASSERT_TRUE(log_out_user_method_ != nullptr);
+ log_in_user_method_ = mini_blogger_serivce_->FindMethodByName("LogInUser");
+ ASSERT_TRUE(log_in_user_method_ != nullptr);
+ register_user_method_ =
+ mini_blogger_serivce_->FindMethodByName("RegisterUser");
+ ASSERT_TRUE(register_user_method_ != nullptr);
+ get_user_posts_method_ =
+ mini_blogger_serivce_->FindMethodByName("GetUserPosts");
+ ASSERT_TRUE(get_user_posts_method_ != nullptr);
+ }
+ const google::protobuf::ServiceDescriptor* mini_blogger_serivce_;
+ const google::protobuf::MethodDescriptor* log_out_user_method_;
+ const google::protobuf::MethodDescriptor* log_in_user_method_;
+ const google::protobuf::MethodDescriptor* register_user_method_;
+ const google::protobuf::MethodDescriptor* get_user_posts_method_;
+};
+
+TEST_F(RpcPotentialDfgTest,
+ NodesNotDependingOnOthersHaveNoPotentialDependencies) {
+ RpcPotentialDataFlowGraph dfg = CreatePotentialDfg<MiniBlogger>();
+
+ const RpcPotentialDfgNode& register_user_node =
+ dfg.GetNode(*register_user_method_);
+ // RegisterUser doesn't depend on others.
+ EXPECT_FALSE(register_user_node.HasDependency());
+
+ const RpcPotentialDfgNode& log_in_user_node =
+ dfg.GetNode(*log_in_user_method_);
+ // LogInUser doesn't depend on others.
+ EXPECT_FALSE(log_in_user_node.HasDependency());
+}
+
+TEST_F(RpcPotentialDfgTest, DependneciesSetUpBasedOnNameAndType) {
+ RpcPotentialDataFlowGraph dfg = CreatePotentialDfg<MiniBlogger>();
+ const RpcPotentialDfgNode& get_user_posts_node =
+ dfg.GetNode(*get_user_posts_method_);
+ // GetUserPostsRequest.session_id depends on LogInUserResponse.session_id.
+ EXPECT_EQ(get_user_posts_node.GetAllDependencies().size(), 1);
+
+ FieldPath get_user_posts_request_field =
+ GetFieldPath<GetUserPostsRequest>("session_id");
+ std::vector<RpcPotentialDfgNode::PotentialDependencySource> all_dep_sources =
+ get_user_posts_node.GetDependencies(get_user_posts_request_field);
+
+ // Depends on the session_id of the response in `LogInUser`.
+ EXPECT_THAT(
+ all_dep_sources,
+ UnorderedElementsAre(FieldsAre(
+ log_in_user_method_, GetFieldPath<LogInUserResponse>("session_id"))));
+}
+
+TEST_F(RpcPotentialDfgTest, DependneciesSetUpOnInnerFields) {
+ RpcPotentialDataFlowGraph dfg = CreatePotentialDfg<MiniBlogger>();
+ const RpcPotentialDfgNode& log_out_user_node =
+ dfg.GetNode(*log_out_user_method_);
+ // LogOutUserRequest.log_out_info.session_info.session_id and
+ // LogOutUserRequest.log_out_info.session_id depends on
+ // LogInUserResponse.session_id.
+ EXPECT_EQ(log_out_user_node.GetAllDependencies().size(), 2);
+
+ FieldPath log_out_user_request_field =
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+ std::vector<RpcPotentialDfgNode::PotentialDependencySource> all_dep_sources =
+ log_out_user_node.GetDependencies(log_out_user_request_field);
+ EXPECT_EQ(all_dep_sources.size(), 1);
+
+ // Depends on the session_id of the response in `LogInUser`.
+ EXPECT_THAT(
+ all_dep_sources,
+ UnorderedElementsAre(FieldsAre(
+ log_in_user_method_, GetFieldPath<LogInUserResponse>("session_id"))));
+}
+
+} // namespace
+
+} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_sequence.cc b/rpc_fuzzing/rpc_sequence.cc
new file mode 100644
index 0000000..cb69534
--- /dev/null
+++ b/rpc_fuzzing/rpc_sequence.cc
@@ -0,0 +1,184 @@
+#include "./rpc_fuzzing/rpc_sequence.h"
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "google/protobuf/util/message_differencer.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/memory/memory.h"
+#include "absl/random/bit_gen_ref.h"
+#include "absl/random/random.h"
+#include "./fuzztest/internal/logging.h"
+
+namespace fuzztest::internal {
+
+namespace {
+
+std::vector<int> GetShuffledIndices(size_t size, absl::BitGenRef gen) {
+ std::vector<int> shuffled_indices(size);
+ std::iota(shuffled_indices.begin(), shuffled_indices.end(), 0);
+ std::shuffle(shuffled_indices.begin(), shuffled_indices.end(), gen);
+ return shuffled_indices;
+}
+
+std::vector<RpcNodeID> GetShuffledNodeIdsInGraph(const RpcDataFlowGraph& graph,
+ absl::BitGenRef gen) {
+ std::vector<RpcNodeID> shuffled_node_ids;
+ shuffled_node_ids.reserve(graph.NodeNum());
+ for (const auto& [node_id, unused_node] : graph.GetAllNodes()) {
+ shuffled_node_ids.push_back(node_id);
+ }
+ std::shuffle(shuffled_node_ids.begin(), shuffled_node_ids.end(), gen);
+ return shuffled_node_ids;
+}
+
+void RemapRpcNodeIdToSequenceIndex(
+ RpcSequence& sequence,
+ const absl::flat_hash_map<RpcNodeID, RpcNodeID>& id_map) {
+ for (RpcNode& node : sequence) {
+ for (RpcDataFlowEdge& edge : node.dependencies()) {
+ edge.from_node_id = id_map.at(edge.from_node_id);
+ }
+ }
+}
+
+void CollectTopologicalSortingOrder(const RpcDataFlowGraph& graph,
+ RpcNodeID current_id, absl::BitGenRef gen,
+ absl::flat_hash_set<RpcNodeID>& visited,
+ std::vector<RpcNodeID>& sorted_ids) {
+ if (visited.contains(current_id)) return;
+
+ visited.insert(current_id);
+ const RpcNode& node = graph.GetNode(current_id);
+ std::vector<RpcNodeID> shuffled_edge_indices =
+ GetShuffledIndices(node.dependencies().size(), gen);
+ for (int edge_index : shuffled_edge_indices) {
+ CollectTopologicalSortingOrder(graph,
+ node.dependencies()[edge_index].from_node_id,
+ gen, visited, sorted_ids);
+ }
+ sorted_ids.push_back(current_id);
+}
+
+std::vector<RpcNodeID> GetRandomTopologicalOrdering(
+ const RpcDataFlowGraph& graph) {
+ std::vector<RpcNodeID> sorted_node_ids;
+ sorted_node_ids.reserve(graph.NodeNum());
+ absl::flat_hash_set<RpcNodeID> visited;
+ absl::BitGen gen;
+ std::vector<RpcNodeID> shuffled_node_ids =
+ GetShuffledNodeIdsInGraph(graph, gen);
+ for (RpcNodeID node_id : shuffled_node_ids) {
+ CollectTopologicalSortingOrder(graph, node_id, gen, visited,
+ sorted_node_ids);
+ }
+ FUZZTEST_INTERNAL_CHECK(
+ sorted_node_ids.size() == graph.NodeNum(),
+ "Topological sort results in different number of nodes!");
+ return sorted_node_ids;
+}
+
+RpcSequence SequentializeGraphByOrder(
+ const RpcDataFlowGraph& graph, const std::vector<RpcNodeID>& node_order) {
+ RpcSequence result;
+ result.reserve(graph.NodeNum());
+ absl::flat_hash_map<RpcNodeID /* original node id */,
+ RpcNodeID /* sequence index */>
+ id_map;
+ for (size_t seq_idx = 0; seq_idx < node_order.size(); ++seq_idx) {
+ RpcNodeID original_node_id = node_order[seq_idx];
+ id_map[original_node_id] = seq_idx;
+ result.push_back(graph.GetNode(original_node_id));
+ }
+ RemapRpcNodeIdToSequenceIndex(result, id_map);
+ return result;
+}
+
+} // namespace
+
+RpcNode::RpcNode(const RpcNode& other)
+ : method_(other.method_),
+ request_(absl::WrapUnique(other.request_->New())),
+ dependencies_(other.dependencies_) {
+ request_->CopyFrom(*other.request_);
+}
+
+RpcNode& RpcNode::operator=(const RpcNode& other) {
+ method_ = other.method_;
+ request_ = absl::WrapUnique(other.request_->New());
+ request_->CopyFrom(*other.request_);
+ dependencies_ = other.dependencies_;
+ return *this;
+}
+
+bool operator==(const RpcDataFlowEdge& lhs, const RpcDataFlowEdge& rhs) {
+ return lhs.from_node_id == rhs.from_node_id &&
+ lhs.from_field == rhs.from_field && lhs.to_field == rhs.to_field;
+}
+
+bool operator!=(const RpcDataFlowEdge& lhs, const RpcDataFlowEdge& rhs) {
+ return !(lhs == rhs);
+}
+
+bool operator==(const RpcNode& lhs, const RpcNode& rhs) {
+ return lhs.method_->full_name() == rhs.method_->full_name() &&
+ google::protobuf::util::MessageDifferencer::Equals(*lhs.request_,
+ *rhs.request_) &&
+ lhs.dependencies_ == rhs.dependencies_;
+}
+
+bool operator!=(const RpcNode& lhs, const RpcNode& rhs) {
+ return !(lhs == rhs);
+}
+
+void RpcDataFlowGraph::AddNode(RpcNodeID id, RpcNode rpc_node) {
+ InvalidateNodeOrder();
+ rpc_nodes_.emplace(id, rpc_node);
+}
+
+void RpcDataFlowGraph::RemoveNode(RpcNodeID id) {
+ InvalidateNodeOrder();
+ rpc_nodes_.erase(id);
+}
+
+const RpcNode& RpcDataFlowGraph::GetNode(RpcNodeID id) const {
+ FUZZTEST_INTERNAL_CHECK(rpc_nodes_.contains(id), "Invalid id!");
+ return rpc_nodes_.at(id);
+}
+
+size_t RpcDataFlowGraph::NodeNum() const { return rpc_nodes_.size(); }
+
+RpcDataFlowGraph RpcDataFlowGraph::FromSequence(const RpcSequence& sequence) {
+ RpcDataFlowGraph graph;
+ for (int node_idx = 0; node_idx < sequence.size(); ++node_idx) {
+ graph.AddNode(node_idx, sequence[node_idx]);
+ }
+ graph.node_order_ = std::vector<RpcNodeID>(sequence.size());
+ std::iota(graph.node_order_->begin(), graph.node_order_->end(), 0);
+ return graph;
+}
+
+RpcSequence RpcDataFlowGraph::GetSequence() const {
+ if (!node_order_.has_value()) {
+ node_order_ = GetRandomTopologicalOrdering(*this);
+ }
+ return SequentializeGraphByOrder(*this, *node_order_);
+}
+
+std::vector<RpcNodeID>& RpcDataFlowGraph::GetOrderedNodeIds() const {
+ if (!node_order_.has_value()) {
+ node_order_ = GetRandomTopologicalOrdering(*this);
+ }
+ return *node_order_;
+}
+
+bool operator==(const RpcDataFlowGraph& lhs, const RpcDataFlowGraph& rhs) {
+ return lhs.GetSequence() == rhs.GetSequence();
+}
+
+bool operator!=(const RpcDataFlowGraph& lhs, const RpcDataFlowGraph& rhs) {
+ return !(lhs == rhs);
+}
+
+void RpcDataFlowGraph::RandomizeTopologicalOrdering() { InvalidateNodeOrder(); }
+
+} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_sequence.h b/rpc_fuzzing/rpc_sequence.h
new file mode 100644
index 0000000..b014022
--- /dev/null
+++ b/rpc_fuzzing/rpc_sequence.h
@@ -0,0 +1,116 @@
+#ifndef FUZZTEST_RPC_FUZZING_RPC_SEQUENCE_H_
+#define FUZZTEST_RPC_FUZZING_RPC_SEQUENCE_H_
+
+#include <memory>
+#include <optional>
+#include <vector>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "absl/container/flat_hash_map.h"
+#include "./rpc_fuzzing/proto_field_path.h"
+
+namespace fuzztest::internal {
+
+// An RPC call is identified by an id because we might have multiple calls to
+// the same rpc method. In the sequence, the id is also the index of the node.
+using RpcNodeID = std::int32_t;
+
+// Represents a data-flow edge between two RPC calls.
+struct RpcDataFlowEdge {
+ RpcNodeID from_node_id = -1;
+ // The message field in the response.
+ FieldPath from_field;
+ // The dynamic field in the request.
+ FieldPath to_field;
+};
+
+bool operator==(const RpcDataFlowEdge& lhs, const RpcDataFlowEdge& rhs);
+bool operator!=(const RpcDataFlowEdge& lhs, const RpcDataFlowEdge& rhs);
+
+// Represents a remote procedure call (a request). Parts of the request that are
+// coming from a previous response are represented with a data flow edge
+// (`RpcDataFlowEdge`).
+class RpcNode {
+ public:
+ RpcNode(const google::protobuf::MethodDescriptor& method,
+ std::unique_ptr<google::protobuf::Message> request)
+ : method_(&method), request_(std::move(request)) {}
+ RpcNode(const RpcNode&);
+ RpcNode(RpcNode&&) noexcept;
+ RpcNode& operator=(const RpcNode&);
+ RpcNode& operator=(RpcNode&&) = default;
+ const google::protobuf::MethodDescriptor& method() const { return *method_; }
+ google::protobuf::Message& request() { return *request_; }
+ const google::protobuf::Message& request() const { return *request_; }
+ void AddDependency(RpcDataFlowEdge dep) { dependencies_.push_back(dep); }
+ const std::vector<RpcDataFlowEdge>& dependencies() const {
+ return dependencies_;
+ }
+ std::vector<RpcDataFlowEdge>& dependencies() { return dependencies_; }
+
+ friend bool operator==(const RpcNode& lhs, const RpcNode& rhs);
+
+ private:
+ // The `method_` is ensured to be non-null.
+ const google::protobuf::MethodDescriptor* method_;
+ std::unique_ptr<google::protobuf::Message> request_;
+ std::vector<RpcDataFlowEdge> dependencies_;
+};
+
+bool operator==(const RpcNode& lhs, const RpcNode& rhs);
+bool operator!=(const RpcNode& lhs, const RpcNode& rhs);
+
+inline RpcNode::RpcNode(RpcNode&&) noexcept = default;
+
+// The RpcSequence represents a topologically sorted RpcDataFlowGraph.
+using RpcSequence = std::vector<RpcNode>;
+
+// The RpcDataFlowGraph represents a "remote procedure call session", i.e., a
+// set of RPC method calls and the data-flow dependencies between them.
+class RpcDataFlowGraph {
+ public:
+ void AddNode(RpcNodeID id, RpcNode rpc_node);
+ void RemoveNode(RpcNodeID id);
+ const RpcNode& GetNode(RpcNodeID id) const;
+ size_t NodeNum() const;
+
+ absl::flat_hash_map<RpcNodeID, RpcNode>& GetAllNodes() { return rpc_nodes_; }
+ const absl::flat_hash_map<RpcNodeID, RpcNode>& GetAllNodes() const {
+ return rpc_nodes_;
+ }
+
+ // Returns a topologically sorted sequence representation of the graph. Note
+ // that there are multiple possible topological orderings, from which it
+ // returns a random one. Multiple calls to the method return the same ordering
+ // (unless RandomizeTopologicalOrdering() is called).
+ RpcSequence GetSequence() const;
+
+ // Converts a sequence to a graph with the sequence node ordering.
+ static RpcDataFlowGraph FromSequence(const RpcSequence& sequence);
+
+ // Randomizes the topological ordering of the nodes.
+ void RandomizeTopologicalOrdering();
+
+ // Returns the ordering of the nodes corresponding to that returned by
+ // `GetSequence`.
+ std::vector<RpcNodeID>& GetOrderedNodeIds() const;
+
+ private:
+ // Invalidates the node order so that if the order is needed again it will be
+ // re-computed on demand. This optimization avoids unnecessary recomputation
+ // every time the graph is changed because the recomputation can be costly.
+ void InvalidateNodeOrder() const { node_order_.reset(); }
+ absl::flat_hash_map<RpcNodeID, RpcNode> rpc_nodes_;
+ // `GetSequence` might change the order so we need `mutable` for `GetSequence`
+ // to be used in `FromValue` in the RpcSession domain.
+ mutable std::optional<std::vector<RpcNodeID>> node_order_;
+};
+
+// Returns true if the graphs have the same nodes and ordering.
+bool operator==(const RpcDataFlowGraph& lhs, const RpcDataFlowGraph& rhs);
+bool operator!=(const RpcDataFlowGraph& lhs, const RpcDataFlowGraph& rhs);
+
+} // namespace fuzztest::internal
+
+#endif // FUZZTEST_RPC_FUZZING_RPC_SEQUENCE_H_
diff --git a/rpc_fuzzing/rpc_sequence_test.cc b/rpc_fuzzing/rpc_sequence_test.cc
new file mode 100644
index 0000000..768d630
--- /dev/null
+++ b/rpc_fuzzing/rpc_sequence_test.cc
@@ -0,0 +1,280 @@
+#include "./rpc_fuzzing/rpc_sequence.h"
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/util/message_differencer.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "./rpc_fuzzing/proto_field_path.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.pb.h"
+
+namespace fuzztest::internal {
+
+namespace {
+
+using ::testing::Address;
+using ::testing::ElementsAre;
+using ::testing::Field;
+using ::testing::Property;
+using ::testing::UnorderedElementsAreArray;
+
+TEST(RpcNodeTest, AssignOperator) {
+ constexpr RpcNodeID kFromNodeId = 1;
+ const google::protobuf::MethodDescriptor* method_descriptor =
+ google::protobuf::DescriptorPool::generated_pool()->FindMethodByName(
+ "fuzztest.internal.MiniBlogger.RegisterUser");
+ CHECK(method_descriptor != nullptr);
+ RpcNode node1 =
+ RpcNode(*method_descriptor, std::make_unique<RegisterUserRequest>());
+ node1.AddDependency(RpcDataFlowEdge{kFromNodeId});
+
+ RpcNode node2 = node1;
+
+ // Check that the two objects are equal.
+ EXPECT_EQ(node1.method().full_name(), node2.method().full_name());
+ EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals(node1.request(),
+ node2.request()));
+ EXPECT_THAT(node2.dependencies(),
+ ElementsAre(Field(&RpcDataFlowEdge::from_node_id, kFromNodeId)));
+}
+
+TEST(RpcNodeTest, MoveAssignOperator) {
+ constexpr RpcNodeID kFromNodeId = 1;
+ const google::protobuf::MethodDescriptor* method_descriptor =
+ google::protobuf::DescriptorPool::generated_pool()->FindMethodByName(
+ "fuzztest.internal.MiniBlogger.RegisterUser");
+ CHECK(method_descriptor != nullptr);
+ RpcNode node1 =
+ RpcNode(*method_descriptor, std::make_unique<RegisterUserRequest>());
+ node1.AddDependency(RpcDataFlowEdge{kFromNodeId});
+
+ RpcNode clone = node1;
+ RpcNode node2 = std::move(clone);
+
+ // Check that the two objects are equal.
+ EXPECT_EQ(node1.method().full_name(), node2.method().full_name());
+ EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals(node1.request(),
+ node2.request()));
+ EXPECT_THAT(node2.dependencies(),
+ ElementsAre(Field(&RpcDataFlowEdge::from_node_id, kFromNodeId)));
+}
+
+TEST(RpcNodeTest, CopyConstructor) {
+ constexpr RpcNodeID kFromNodeId = 1;
+ const google::protobuf::MethodDescriptor* method_descriptor =
+ google::protobuf::DescriptorPool::generated_pool()->FindMethodByName(
+ "fuzztest.internal.MiniBlogger.RegisterUser");
+ CHECK(method_descriptor != nullptr);
+ RpcNode node1 =
+ RpcNode(*method_descriptor, std::make_unique<RegisterUserRequest>());
+ node1.AddDependency(RpcDataFlowEdge{kFromNodeId});
+
+ RpcNode node2(node1);
+
+ // Check that the two objects are equal.
+ EXPECT_EQ(node1.method().full_name(), node2.method().full_name());
+ EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals(node1.request(),
+ node2.request()));
+ EXPECT_THAT(node2.dependencies(),
+ ElementsAre(Field(&RpcDataFlowEdge::from_node_id, kFromNodeId)));
+}
+
+TEST(RpcNodeTest, EqualityTest) {
+ constexpr RpcNodeID kFromNodeId = 1;
+ const google::protobuf::MethodDescriptor* method_descriptor =
+ google::protobuf::DescriptorPool::generated_pool()->FindMethodByName(
+ "fuzztest.internal.MiniBlogger.RegisterUser");
+ ASSERT_NE(method_descriptor, nullptr);
+ RpcNode node1 =
+ RpcNode(*method_descriptor, std::make_unique<RegisterUserRequest>());
+ node1.AddDependency(RpcDataFlowEdge{kFromNodeId});
+
+ RpcNode node2(node1);
+
+ EXPECT_EQ(node1, node2);
+}
+
+TEST(RpcNodeTest, InEqualityTest) {
+ constexpr RpcNodeID kFromNodeId = 1;
+ const google::protobuf::MethodDescriptor* method_descriptor =
+ google::protobuf::DescriptorPool::generated_pool()->FindMethodByName(
+ "fuzztest.internal.MiniBlogger.RegisterUser");
+ ASSERT_NE(method_descriptor, nullptr);
+ RpcNode node1 =
+ RpcNode(*method_descriptor, std::make_unique<RegisterUserRequest>());
+ node1.AddDependency(RpcDataFlowEdge{kFromNodeId});
+
+ RpcNode node2 =
+ RpcNode(*method_descriptor, std::make_unique<RegisterUserRequest>());
+
+ EXPECT_NE(node1, node2);
+}
+
+TEST(RpcDataFlowEdgeTest, EqualityTest) {
+ RpcDataFlowEdge edge1 =
+ RpcDataFlowEdge{0, GetFieldPath<LogInUserResponse>("session_id"),
+ GetFieldPath<GetUserPostsRequest>("session_id")};
+ RpcDataFlowEdge edge2 =
+ RpcDataFlowEdge{0, GetFieldPath<LogInUserResponse>("session_id"),
+ GetFieldPath<GetUserPostsRequest>("session_id")};
+
+ EXPECT_EQ(edge1, edge2);
+}
+
+TEST(RpcDataFlowEdgeTest, InEqualityTest) {
+ RpcDataFlowEdge edge1 =
+ RpcDataFlowEdge{0, GetFieldPath<LogInUserResponse>("session_id"),
+ GetFieldPath<GetUserPostsRequest>("session_id")};
+ RpcDataFlowEdge edge2 = RpcDataFlowEdge{
+ 0, GetFieldPath<LogInUserResponse>("session_id"),
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")};
+
+ EXPECT_NE(edge1, edge2);
+}
+
+class RpcDataFlowGraphTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ log_in_user_method_ =
+ google::protobuf::DescriptorPool::generated_pool()->FindMethodByName(
+ "fuzztest.internal.MiniBlogger.LogInUser");
+ log_out_user_method_ =
+ google::protobuf::DescriptorPool::generated_pool()->FindMethodByName(
+ "fuzztest.internal.MiniBlogger.LogOutUser");
+ register_user_method_ =
+ google::protobuf::DescriptorPool::generated_pool()->FindMethodByName(
+ "fuzztest.internal.MiniBlogger.RegisterUser");
+ }
+
+ RpcNode CreateLogInUserNode(RpcNodeID log_in_node_id) {
+ return RpcNode(*log_in_user_method_, std::make_unique<LogInUserRequest>());
+ }
+
+ RpcNode CreateRegisterUserNode(RpcNodeID log_out_node_id) {
+ return RpcNode(*register_user_method_,
+ std::make_unique<RegisterUserRequest>());
+ }
+
+ std::pair<RpcNode, RpcNode> CreateLogInOutUserNodePairWithDep(
+ RpcNodeID log_in_node_id) {
+ // Set up a node for LogInUser and LogOutUser. Establish a dependency from
+ // LogInUserResponse.session_id to
+ // LogOutUserRequest.log_out_info.session_info.session_id.
+ RpcNode log_in_user_node(*log_in_user_method_,
+ std::make_unique<LogInUserResponse>());
+ RpcNode log_out_user_node(*log_out_user_method_,
+ std::make_unique<LogOutUserRequest>());
+ log_out_user_node.AddDependency(RpcDataFlowEdge{
+ log_in_node_id, GetFieldPath<LogInUserResponse>("session_id"),
+ GetFieldPath<LogOutUserRequest>(
+ "log_out_info.session_info.session_id")});
+ return std::make_pair(log_in_user_node, log_out_user_node);
+ }
+
+ const google::protobuf::MethodDescriptor *log_in_user_method_, *log_out_user_method_,
+ *register_user_method_;
+ RpcNodeID log_in_node_id_ = 0x123;
+ RpcNodeID log_out_node_id_ = 0x234;
+};
+
+TEST_F(RpcDataFlowGraphTest, GetSequenceReturnTologicalSortedSequence) {
+ RpcDataFlowGraph graph;
+ auto [log_in_user_node, log_out_user_node] =
+ CreateLogInOutUserNodePairWithDep(log_in_node_id_);
+ // Add LogOutUser first, so that the topological sort should change its order.
+ graph.AddNode(log_out_node_id_, log_out_user_node);
+ graph.AddNode(log_in_node_id_, log_in_user_node);
+
+ RpcSequence sequence = graph.GetSequence();
+ EXPECT_THAT(
+ sequence,
+ ElementsAre(Property(&RpcNode::method, Address(log_in_user_method_)),
+ Property(&RpcNode::method, Address(log_out_user_method_))));
+}
+
+TEST_F(RpcDataFlowGraphTest, GetSequenceReMapAllTheRpcNodeIdToSequenceIndex) {
+ auto [log_in_user_node, log_out_user_node] =
+ CreateLogInOutUserNodePairWithDep(log_in_node_id_);
+ RpcDataFlowGraph graph;
+ graph.AddNode(log_in_node_id_, log_in_user_node);
+ graph.AddNode(log_out_node_id_, log_out_user_node);
+
+ RpcSequence sequence = graph.GetSequence();
+ EXPECT_EQ(&sequence[1].method(), log_out_user_method_);
+ EXPECT_THAT(sequence[1].dependencies(),
+ ElementsAre(Field(&RpcDataFlowEdge::from_node_id,
+ 0 /* remap `kNode1Id` to 0*/)));
+}
+
+TEST_F(RpcDataFlowGraphTest, GetSequenceReturnRandomizedTopoSortedSequence) {
+ const std::string_view expected_methods_of_first_nodes[] = {"LogInUser",
+ "RegisterUser"};
+ absl::flat_hash_set<std::string> methods_of_first_nodes;
+ for (int i = 0; i < 100; ++i) {
+ RpcDataFlowGraph graph;
+ graph.AddNode(0, CreateLogInUserNode(log_in_node_id_));
+ graph.AddNode(1, CreateRegisterUserNode(log_out_node_id_));
+ methods_of_first_nodes.insert(graph.GetSequence()[0].method().name());
+ }
+
+ EXPECT_THAT(methods_of_first_nodes,
+ UnorderedElementsAreArray(expected_methods_of_first_nodes));
+}
+
+TEST_F(RpcDataFlowGraphTest, FromSequenceRecoversNodesAndOrdering) {
+ RpcSequence sequence = {CreateLogInUserNode(log_in_node_id_),
+ CreateRegisterUserNode(log_out_node_id_)};
+ RpcDataFlowGraph graph = RpcDataFlowGraph::FromSequence(sequence);
+ EXPECT_EQ(graph.GetSequence(), sequence);
+}
+
+TEST_F(RpcDataFlowGraphTest,
+ GetSequenceReturnTheSameRandomizedTopoSortedSequence) {
+ RpcDataFlowGraph graph;
+ graph.AddNode(0, CreateLogInUserNode(log_in_node_id_));
+ graph.AddNode(1, CreateRegisterUserNode(log_out_node_id_));
+ absl::string_view method_name = graph.GetSequence()[0].method().name();
+ for (int i = 0; i < 100; ++i) {
+ EXPECT_EQ(method_name, graph.GetSequence()[0].method().name());
+ }
+}
+
+TEST_F(
+ RpcDataFlowGraphTest,
+ GetSequenceReturnDifferentRandomizedTopoSortedSequenceAfterModification) {
+ const std::string_view expected_methods_of_first_nodes[] = {"LogInUser",
+ "RegisterUser"};
+ absl::flat_hash_set<std::string> methods_of_first_nodes;
+ RpcDataFlowGraph graph;
+ graph.AddNode(0, CreateLogInUserNode(log_in_node_id_));
+ graph.AddNode(1, CreateRegisterUserNode(log_out_node_id_));
+ for (int i = 0; i < 100; ++i) {
+ graph.RemoveNode(1);
+ graph.AddNode(1, CreateRegisterUserNode(log_out_node_id_));
+ methods_of_first_nodes.insert(graph.GetSequence()[0].method().name());
+ }
+ EXPECT_THAT(methods_of_first_nodes,
+ UnorderedElementsAreArray(expected_methods_of_first_nodes));
+}
+
+TEST_F(RpcDataFlowGraphTest,
+ GetRandomSequenceAlwaysReturnRandomizedTopoSortedSequence) {
+ const std::string_view expected_methods_of_first_nodes[] = {"LogInUser",
+ "RegisterUser"};
+ absl::flat_hash_set<std::string> methods_of_first_nodes;
+ RpcDataFlowGraph graph;
+ graph.AddNode(0, CreateLogInUserNode(log_in_node_id_));
+ graph.AddNode(1, CreateRegisterUserNode(log_out_node_id_));
+ for (int i = 0; i < 100; ++i) {
+ graph.RandomizeTopologicalOrdering();
+ methods_of_first_nodes.insert(graph.GetSequence()[0].method().name());
+ }
+
+ EXPECT_THAT(methods_of_first_nodes,
+ UnorderedElementsAreArray(expected_methods_of_first_nodes));
+}
+
+} // namespace
+
+} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_session.h b/rpc_fuzzing/rpc_session.h
new file mode 100644
index 0000000..588ace8
--- /dev/null
+++ b/rpc_fuzzing/rpc_session.h
@@ -0,0 +1,608 @@
+#ifndef FUZZTEST_RPC_FUZZING_RPC_SESSION_H_
+#define FUZZTEST_RPC_FUZZING_RPC_SESSION_H_
+
+#include <algorithm>
+#include <memory>
+#include <optional>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/random/bit_gen_ref.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "./fuzztest/internal/domains/domain_base.h"
+#include "./fuzztest/internal/domains/protobuf_domain_impl.h"
+#include "./fuzztest/internal/domains/value_mutation_helpers.h"
+#include "./fuzztest/internal/logging.h"
+#include "./fuzztest/internal/serialization.h"
+#include "./fuzztest/internal/table_of_recent_compares.h"
+#include "./fuzztest/internal/type_support.h"
+#include "./rpc_fuzzing/proto_field_path.h"
+#include "./rpc_fuzzing/rpc_executor.h"
+#include "./rpc_fuzzing/rpc_potential_dfg.h"
+#include "./rpc_fuzzing/rpc_sequence.h"
+
+namespace fuzztest {
+
+namespace internal {
+
+// A helper domain that allows its inner domain to be lazily initialized.
+template <typename DomainT, typename... Args>
+class Lazy : public DomainBase<Lazy<DomainT, Args...>, value_type_t<DomainT>,
+ corpus_type_t<DomainT>> {
+ public:
+ using typename Lazy::DomainBase::corpus_type;
+ using typename Lazy::DomainBase::value_type;
+
+ Lazy(Args&&... args) : args_(std::forward<Args>(args)...) {}
+
+ Lazy(const Lazy& other) {
+ if (other.inner_ != nullptr) {
+ inner_ = std::make_unique<DomainT>(*other.inner_);
+ }
+ args_ = other.args_;
+ }
+
+ Lazy(Lazy&& other) noexcept = default;
+ Lazy& operator=(Lazy&& other) = default;
+ corpus_type Init(absl::BitGenRef prng) { return GetInnerDomain().Init(prng); }
+
+ void Mutate(corpus_type& corpus_value, absl::BitGenRef prng,
+ bool shrink_only) {
+ GetInnerDomain().Mutate(corpus_value, prng, shrink_only);
+ }
+
+ value_type GetValue(const corpus_type& corpus_value) const {
+ return GetInnerDomain().GetValue(corpus_value);
+ }
+
+ std::optional<corpus_type> FromValue(const value_type& v) const {
+ return GetInnerDomain().FromValue(v);
+ }
+
+ std::optional<corpus_type> ParseCorpus(const IRObject& obj) const {
+ return GetInnerDomain().ParseCorpus(obj);
+ }
+
+ IRObject SerializeCorpus(const corpus_type& corpus_value) const {
+ return GetInnerDomain().SerializeCorpus(corpus_value);
+ }
+
+ absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const {
+ return GetInnerDomain().ValidateCorpusValue(corpus_value);
+ }
+
+ auto GetPrinter() const { return GetInnerDomain().GetPrinter(); }
+
+ private:
+ template <std::size_t... Is>
+ std::unique_ptr<DomainT> CreateInner(std::tuple<Args...> tup,
+ std::index_sequence<Is...>) const {
+ return std::make_unique<DomainT>(std::move(std::get<Is>(tup))...);
+ }
+
+ const DomainT& GetInnerDomain() const {
+ if (inner_ == nullptr) {
+ inner_ = CreateInner(args_, std::index_sequence_for<Args...>{});
+ }
+ return *inner_;
+ }
+
+ DomainT& GetInnerDomain() {
+ if (inner_ == nullptr) {
+ inner_ = CreateInner(args_, std::index_sequence_for<Args...>{});
+ }
+ return *inner_;
+ }
+ mutable std::unique_ptr<DomainT> inner_ = nullptr;
+ mutable std::tuple<Args...> args_;
+};
+
+template <class T>
+const T& PickRandomElement(const std::vector<T>& vec, absl::BitGenRef prng) {
+ FUZZTEST_INTERNAL_CHECK(!vec.empty(),
+ "Cannot pick elements from an empy vector.");
+ return vec[absl::Uniform(prng, size_t{0}, vec.size())];
+}
+
+// Rpc session domain. It generates RPC sequences with data-flows between
+// requests and responses of calls. The provided RpcSequence can be consumed by
+// the RpcExecutor.
+template <typename ServiceT = void>
+class RpcSessionImpl : public DomainBase<RpcSessionImpl<ServiceT>, RpcSequence,
+ RpcDataFlowGraph> {
+ public:
+ using typename RpcSessionImpl::DomainBase::corpus_type;
+ using typename RpcSessionImpl::DomainBase::value_type;
+
+ // `service_factory` should return the fully qualified service name such as
+ // `package.Service`.
+ RpcSessionImpl(std::function<absl::string_view()> service_factory)
+ : abstract_dfg_(CreatePotentialDfgByServiceName(service_factory())) {
+ const google::protobuf::ServiceDescriptor& desc =
+ GetServiceDescriptorByServiceName(service_factory());
+ Initialize(desc);
+ }
+
+ template <typename T = ServiceT,
+ typename = std::enable_if_t<is_stubby_service<T>::value ||
+ is_grpc_service<T>::value ||
+ std::is_same_v<T, google::protobuf::Service>>>
+ RpcSessionImpl() : abstract_dfg_(CreatePotentialDfg<ServiceT>()) {
+ const google::protobuf::ServiceDescriptor& desc = GetServiceDescriptor<ServiceT>();
+ Initialize(desc);
+ }
+
+ // Generate a sequence containing only a single rpc call. Such calls should be
+ // the "roots" in the data flow graph and depend on no other nodes.
+ corpus_type Init(absl::BitGenRef prng) {
+ const google::protobuf::MethodDescriptor& method =
+ *PickRandomElement(all_methods_, prng);
+ RpcDataFlowGraph result;
+ auto& request_domain = GetRequestDomain(method);
+ RpcNode call_node(method,
+ request_domain.GetValue(request_domain.Init(prng)));
+ result.AddNode(0 /* the first node */, std::move(call_node));
+ return result;
+ }
+
+ void Mutate(corpus_type& graph, absl::BitGenRef prng, bool only_shrink) {
+ if (only_shrink) {
+ RunOne(
+ prng, [&] { DeleteTailCall(graph, prng); },
+ [&] { MutateStaticField(graph, prng, only_shrink); });
+ } else {
+ RunOne(
+ prng, [&] { InsertTailCall(graph, prng); },
+ [&] { DeleteTailCall(graph, prng); },
+ [&] { MutateStaticField(graph, prng, only_shrink); });
+ }
+ }
+
+ value_type GetValue(const corpus_type& graph) const {
+ return graph.GetSequence();
+ }
+
+ std::optional<corpus_type> FromValue(const value_type& v) const {
+ corpus_type result;
+ for (size_t i = 0; i < v.size(); ++i) {
+ result.AddNode(i, v[i]);
+ }
+ return result;
+ }
+
+ std::optional<corpus_type> ParseCorpus(const IRObject& obj) const {
+ auto subs = obj.Subs();
+ if (!subs || subs->empty()) {
+ return std::nullopt;
+ }
+ RpcSequence sequence;
+ sequence.reserve(subs->size());
+ for (const auto& node_obj : *subs) {
+ if (!ValidateRpcNodeObjFormat(node_obj)) return std::nullopt;
+ auto node = ParseRpcNodeMethodAndRequest(node_obj);
+ if (!node) return std::nullopt;
+ auto edges_sub = (*node_obj.Subs())[2].Subs();
+ for (const auto& edge_obj : *edges_sub) {
+ std::optional<RpcDataFlowEdge> edge =
+ ParseRpcDataFlowEdge(edge_obj, sequence, node->method());
+ if (!edge) return std::nullopt;
+ node->AddDependency(*edge);
+ }
+ sequence.push_back(*std::move(node));
+ }
+ return RpcDataFlowGraph::FromSequence(sequence);
+ }
+
+ // Serializes the graph in the topological-sorting order.
+ IRObject SerializeCorpus(const corpus_type& graph) const {
+ IRObject result;
+ std::vector<IRObject>& result_subs = result.MutableSubs();
+ // We serialize the sequence instead to avoid storing the RpcNodeID.
+ for (const RpcNode& node : graph.GetSequence()) {
+ result_subs.push_back(SerializeRpcNode(node));
+ }
+ return result;
+ }
+
+ // TODO(changochen): Implemented in later CL.
+ auto GetPrinter() const { return MonostatePrinter{}; }
+
+ // Check whether every RpcNode:
+ // 1. Only depend on previous nodes. (No cyclic dependencies)
+ // 2. Every dependency matches one of the potential dependencies.
+ absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const {
+ absl::flat_hash_set<RpcNodeID> previous_node_ids;
+ for (RpcNodeID node_id : corpus_value.GetOrderedNodeIds()) {
+ const RpcNode& node = corpus_value.GetNode(node_id);
+ const RpcPotentialDfgNode& dfg_node =
+ abstract_dfg_.GetNode(node.method());
+ absl::flat_hash_set<FieldPath> all_sink_paths;
+ for (const RpcDataFlowEdge& edge : node.dependencies()) {
+ // The current node should only depend on previous nodes.
+ if (!previous_node_ids.contains(edge.from_node_id))
+ return absl::InvalidArgumentError(
+ "The dependencies should only come from previously executed "
+ "nodes.");
+ if (!dfg_node.GetAllDependencies().contains(edge.to_field)) {
+ return absl::InvalidArgumentError(
+ "The sink field should be defined in the potential data flow "
+ "graph.");
+ }
+
+ if (all_sink_paths.contains(edge.to_field)) {
+ return absl::InvalidArgumentError(
+ "One sink field should have at most one concrete dependency!");
+ }
+ all_sink_paths.insert(edge.to_field);
+
+ // Check whether the dependency source from `edge` actually matches one
+ // of the potential dependency source.
+ if (!HasPotentialSourceForEdge(edge, corpus_value, dfg_node)) {
+ return absl::InvalidArgumentError(
+ "The dependency is not defined in the potential data flow "
+ "graph.");
+ }
+ }
+
+ previous_node_ids.insert(node_id);
+ }
+ return absl::OkStatus();
+ }
+
+ private:
+ void Initialize(const google::protobuf::ServiceDescriptor& desc) {
+ for (int i = 0; i < desc.method_count(); ++i) {
+ const google::protobuf::MethodDescriptor* method = desc.method(i);
+ all_methods_.push_back(method);
+
+ std::function<const google::protobuf::Message*()> get_request_prototype =
+ [method]() {
+ return google::protobuf::MessageFactory::generated_factory()->GetPrototype(
+ method->input_type());
+ };
+ request_domains_.emplace(
+ method, ProtobufDomainUntypedImpl<google::protobuf::Message>(
+ PrototypePtr<google::protobuf::Message>(get_request_prototype),
+ /*use_lazy_initialization=*/false));
+ }
+ }
+
+ bool AllPotentialDepsInDfgNodeCanBeSatisfied(
+ const absl::flat_hash_set<FieldPath>& satisfied_field_paths,
+ const RpcPotentialDfgNode& dfg_node) const {
+ for (const auto& [field_path, unused] : dfg_node.GetAllDependencies()) {
+ if (!satisfied_field_paths.contains(field_path)) {
+ bool can_be_satisfied = false;
+ for (const FieldPath& satisfied_field_path : satisfied_field_paths) {
+ if (AreOneOfAltearnatives(satisfied_field_path, field_path)) {
+ can_be_satisfied = true;
+ break;
+ }
+ }
+ if (!can_be_satisfied) return false;
+ }
+ }
+ return true;
+ }
+ // Check whether the source of the edge is defined in the potential data flow
+ // graph.
+ bool HasPotentialSourceForEdge(const RpcDataFlowEdge& edge,
+ const RpcDataFlowGraph& graph,
+ const RpcPotentialDfgNode& dfg_node) const {
+ for (const RpcPotentialDfgNode::PotentialDependencySource&
+ potential_source : dfg_node.GetDependencies(edge.to_field)) {
+ const RpcNode& from_node = graph.GetNode(edge.from_node_id);
+ if (potential_source.method == &from_node.method() &&
+ potential_source.field_path == edge.from_field) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ bool ValidateRpcNodeObjFormat(const IRObject& obj) const {
+ auto node_subs = obj.Subs();
+ // Each node has a sub of size 3: method, request, dependencies.
+ if (!node_subs || node_subs->size() != 3) {
+ return false;
+ }
+ auto edge_subs = (*node_subs)[2].Subs();
+ if (!edge_subs) {
+ return false;
+ }
+ // Each edge has a sub of size 3: from_id, from_field, to_field.
+ for (const auto& edge_sub : *edge_subs) {
+ if (!edge_sub.Subs() || edge_sub.Subs()->size() != 3) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ IRObject SerializeRpcDataFlowEdge(const RpcDataFlowEdge& edge) const {
+ IRObject edge_obj;
+ std::vector<IRObject>& edge_subs = edge_obj.MutableSubs();
+ edge_subs.push_back(IRObject::FromCorpus(edge.from_node_id));
+ edge_subs.push_back(IRObject::FromCorpus(edge.from_field.ToString()));
+ edge_subs.push_back(IRObject::FromCorpus(edge.to_field.ToString()));
+ return edge_obj;
+ }
+
+ IRObject SerializeRpcNode(const RpcNode& node) const {
+ // The serialized object will be like
+ // [ method_name, request_proto_buf_string, [edges]]
+ IRObject node_obj;
+ std::vector<IRObject>& node_subs = node_obj.MutableSubs();
+ node_subs.push_back(IRObject::FromCorpus(node.method().full_name()));
+ node_subs.push_back(
+ IRObject::FromCorpus(node.request().SerializeAsString()));
+ IRObject edges_obj;
+ if (!node.dependencies().empty()) {
+ std::vector<IRObject>& edges_subs = edges_obj.MutableSubs();
+ for (const RpcDataFlowEdge& edge : node.dependencies()) {
+ edges_subs.push_back(SerializeRpcDataFlowEdge(edge));
+ }
+ }
+ node_subs.push_back(std::move(edges_obj));
+ return node_obj;
+ }
+
+ std::optional<RpcDataFlowEdge> ParseRpcDataFlowEdge(
+ const IRObject& edge_obj, const RpcSequence& sequence,
+ const google::protobuf::MethodDescriptor& to_method) const {
+ auto edge_subs = edge_obj.Subs();
+ FUZZTEST_INTERNAL_CHECK(edge_subs && edge_subs->size() == 3,
+ "Invalid edge format!");
+ auto from_node_id = (*edge_subs)[0].ToCorpus<RpcNodeID>();
+ auto from_field_str = (*edge_subs)[1].ToCorpus<std::string>();
+ auto to_field_str = (*edge_subs)[2].ToCorpus<std::string>();
+ if (!from_node_id || !from_field_str || !to_field_str) {
+ return std::nullopt;
+ }
+ if (*from_node_id >= sequence.size()) return std::nullopt;
+ FieldPath from_field_path = GetFieldPathWithDescriptor(
+ *sequence[*from_node_id].method().output_type(), *from_field_str);
+ FieldPath to_field_path =
+ GetFieldPathWithDescriptor(*to_method.input_type(), *to_field_str);
+ return RpcDataFlowEdge{*from_node_id, from_field_path, to_field_path};
+ }
+
+ std::optional<RpcNode> ParseRpcNodeMethodAndRequest(
+ const IRObject& obj) const {
+ const google::protobuf::DescriptorPool& pool =
+ *google::protobuf::DescriptorPool::generated_pool();
+ auto node_subs = obj.Subs();
+ std::optional<std::string> method_name =
+ (*node_subs)[0].ToCorpus<std::string>();
+ if (!method_name) return std::nullopt;
+ const google::protobuf::MethodDescriptor* method =
+ pool.FindMethodByName(*method_name);
+
+ if (!method) return std::nullopt;
+
+ std::optional<std::string> request_str =
+ (*node_subs)[1].ToCorpus<std::string>();
+
+ std::unique_ptr<google::protobuf::Message> request =
+ absl::WrapUnique(google::protobuf::MessageFactory::generated_factory()
+ ->GetPrototype(method->input_type())
+ ->New());
+ if (!request->ParseFromString(*request_str)) return std::nullopt;
+
+ return RpcNode(*method, std::move(request));
+ }
+
+ absl::flat_hash_set<FieldPath> CollectSatisfiableFields(
+ const RpcPotentialDfgNode& dfg_node,
+ const absl::flat_hash_map<const google::protobuf::MethodDescriptor*,
+ std::vector<RpcNodeID>>& existing_nodes) {
+ absl::flat_hash_set<FieldPath> satisfied_fields;
+ for (const auto& [field_path, source_vec] : dfg_node.GetAllDependencies()) {
+ for (const RpcPotentialDfgNode::PotentialDependencySource& source :
+ source_vec) {
+ if (existing_nodes.contains(source.method)) {
+ satisfied_fields.insert(field_path);
+ break;
+ }
+ }
+ }
+ return satisfied_fields;
+ }
+
+ std::vector<FieldPath> RemoveAlternativeFieldPath(
+ const absl::flat_hash_set<FieldPath>& field_paths, absl::BitGenRef prng) {
+ std::vector<FieldPath> result;
+ std::vector<FieldPath> field_paths_vec(field_paths.begin(),
+ field_paths.end());
+ std::shuffle(field_paths_vec.begin(), field_paths_vec.end(), prng);
+ for (const FieldPath& field_path : field_paths_vec) {
+ bool found_same_one_of_field = false;
+ for (const FieldPath& other_field_path : result) {
+ if (AreOneOfAltearnatives(field_path, other_field_path)) {
+ found_same_one_of_field = true;
+ break;
+ }
+ }
+ if (!found_same_one_of_field) {
+ result.push_back(field_path);
+ }
+ }
+ return result;
+ }
+
+ // Checks whether all the dependencies of `dfg_node` can be satisfied by the
+ // nodes in `existing_nodes`. "Satisfy" means there is at least one
+ // potential dependency of any field in the node that has its source in the
+ // rpc data flow graph.
+ bool DependsOnAndCanBeSatisfiedBy(
+ const RpcPotentialDfgNode& dfg_node,
+ const absl::flat_hash_map<const google::protobuf::MethodDescriptor*,
+ std::vector<RpcNodeID>>& existing_nodes) {
+ if (!dfg_node.HasDependency()) return false;
+ absl::flat_hash_set<FieldPath> satisfied_fields =
+ CollectSatisfiableFields(dfg_node, existing_nodes);
+ // As long as we can satisfy any of the sink field, we consider it as
+ // satisfiable.
+ return !satisfied_fields.empty();
+ }
+
+ // Randomly selects a potential dependency for each dynamic field in `node`
+ // and establishes the dependency between the node and the dependency source.
+ // As we might have multiple calls to the same method in the existing graph
+ // that can serve as the source, we randomly select one of them.
+ void ConnectToExistingNodes(
+ const absl::flat_hash_map<const google::protobuf::MethodDescriptor*,
+ std::vector<RpcNodeID>>
+ existing_nodes,
+ RpcNode& node, absl::BitGenRef prng) {
+ const RpcPotentialDfgNode& dfg_node = abstract_dfg_.GetNode(node.method());
+
+ // If we have multiple satisfiable field paths that are alternative to each
+ // other (e.g., in the same `oneof`), we only need to pick one of them and
+ // establish the concrete dependency.
+ std::vector<FieldPath> satisfied_fields = RemoveAlternativeFieldPath(
+ CollectSatisfiableFields(dfg_node, existing_nodes), prng);
+ FUZZTEST_INTERNAL_CHECK(!satisfied_fields.empty(), "Impossible");
+ for (const FieldPath& sink_field : satisfied_fields) {
+ const auto& source_vec = dfg_node.GetDependencies(sink_field);
+ std::vector<const RpcPotentialDfgNode::PotentialDependencySource*>
+ satisfiable_sources;
+ for (const RpcPotentialDfgNode::PotentialDependencySource& source :
+ source_vec) {
+ if (existing_nodes.contains(source.method))
+ satisfiable_sources.push_back(&source);
+ }
+ if (satisfiable_sources.empty()) continue;
+ FUZZTEST_INTERNAL_CHECK(!satisfiable_sources.empty(),
+ "Some dependencies are not satisified!");
+ const RpcPotentialDfgNode::PotentialDependencySource& chosen_source =
+ *PickRandomElement(satisfiable_sources, prng);
+ const std::vector<RpcNodeID>& candidate_nodes =
+ existing_nodes.find(chosen_source.method)->second;
+ node.AddDependency(
+ RpcDataFlowEdge{PickRandomElement(candidate_nodes, prng),
+ chosen_source.field_path, sink_field});
+ }
+ }
+
+ // Inserts a tail call to the graph. Returns true if a call is inserted.
+ bool InsertTailCall(corpus_type& graph, absl::BitGenRef prng) {
+ absl::flat_hash_map<const google::protobuf::MethodDescriptor*, std::vector<RpcNodeID>>
+ existing_nodes;
+ RpcNodeID max_node_id = 0;
+ for (const auto& [node_id, node] : graph.GetAllNodes()) {
+ existing_nodes[&node.method()].push_back(node_id);
+ max_node_id = std::max(max_node_id, node_id);
+ }
+ RpcNodeID new_node_id = max_node_id + 1;
+
+ // Collect all the methods that can be inserted: They depend on existing
+ // nodes and all their their dependencies can be satisfied by existing
+ // nodes.
+ std::vector<const google::protobuf::MethodDescriptor*> candidate_methods;
+ for (auto& [method, unused_domain] : request_domains_) {
+ const RpcPotentialDfgNode& dfg_node = abstract_dfg_.GetNode(*method);
+ if (DependsOnAndCanBeSatisfiedBy(dfg_node, existing_nodes)) {
+ candidate_methods.push_back(method);
+ }
+ }
+
+ if (candidate_methods.empty()) return false;
+
+ const google::protobuf::MethodDescriptor& method_to_insert =
+ *PickRandomElement(candidate_methods, prng);
+ auto& request_domain = GetRequestDomain(method_to_insert);
+ RpcNode node_to_insert(method_to_insert,
+ request_domain.GetValue(request_domain.Init(prng)));
+ ConnectToExistingNodes(existing_nodes, node_to_insert, prng);
+ graph.AddNode(new_node_id, std::move(node_to_insert));
+ return true;
+ }
+
+ bool DeleteTailCall(corpus_type& graph, absl::BitGenRef prng) {
+ if (graph.NodeNum() == 1) return false;
+ // Non-tail nodes (nodes with dependents).
+ absl::flat_hash_set<RpcNodeID> non_tail_nodes;
+ for (const auto& [node_id, node] : graph.GetAllNodes()) {
+ for (const auto& edge : node.dependencies()) {
+ non_tail_nodes.insert(edge.from_node_id);
+ }
+ }
+ FUZZTEST_INTERNAL_CHECK(non_tail_nodes.size() != graph.NodeNum(),
+ "Loop dependence!");
+ std::vector<RpcNodeID> tail_nodes;
+ for (const auto& [node_id, node] : graph.GetAllNodes()) {
+ if (!non_tail_nodes.contains(node_id)) tail_nodes.push_back(node_id);
+ }
+
+ graph.RemoveNode(PickRandomElement(tail_nodes, prng));
+ return true;
+ }
+
+ bool MutateStaticField(corpus_type& graph, absl::BitGenRef prng,
+ bool only_shrink) {
+ RpcNode& selected_node = std::next(graph.GetAllNodes().begin(),
+ ChooseOffset(graph.NodeNum(), prng))
+ ->second;
+ auto& request_domain = GetRequestDomain(selected_node.method());
+ auto request_domain_corpus_value =
+ request_domain.FromValue(selected_node.request());
+ FUZZTEST_INTERNAL_CHECK(request_domain_corpus_value.has_value(),
+ "Invalid proto!");
+ request_domain.Mutate(*request_domain_corpus_value, prng, only_shrink);
+ selected_node.request().CopyFrom(
+ *request_domain.GetValue(*request_domain_corpus_value));
+ return true;
+ }
+
+ ProtobufDomainUntypedImpl<google::protobuf::Message>& GetRequestDomain(
+ const google::protobuf::MethodDescriptor& method) {
+ auto iter = request_domains_.find(&method);
+ FUZZTEST_INTERNAL_CHECK(iter != request_domains_.end(), "Invalid method!");
+ return iter->second;
+ }
+
+ // We keep a domain for the request type for each method in the service to
+ // generate the request message in the RpcNode.
+ // TODO(changochen): We are now directly storing the concrete requests in
+ // RpcNode, and serialize them as string. We should use the corpus_type of the
+ // proto domain in the future to allow proto customization.
+ absl::flat_hash_map<const google::protobuf::MethodDescriptor*,
+ ProtobufDomainUntypedImpl<google::protobuf::Message>>
+ request_domains_;
+ std::vector<const google::protobuf::MethodDescriptor*> all_methods_;
+ RpcPotentialDataFlowGraph abstract_dfg_;
+};
+} // namespace internal
+
+using internal::RpcSequence;
+
+inline absl::Status ExecuteRpcSequence(RpcStub& stub, RpcSequence& sequence) {
+ return internal::RpcExecutor(&stub).Execute(sequence);
+}
+
+template <typename ServiceT>
+inline auto RpcSession() {
+ return internal::Lazy<internal::RpcSessionImpl<ServiceT>>();
+}
+
+// Creates an RpcSession with a service name factory. This is for when the
+// service name is unknown at compile time (i.e., getting the name from the
+// generated pool.). This factory function will be called after main().
+// If the service name/type is known at compile time, please use the template
+// version above: RpcSession<ServiceT>().
+inline auto RpcSessionOf(
+ std::function<absl::string_view()> service_name_factory) {
+ FUZZTEST_INTERNAL_CHECK(service_name_factory, "Invalid service factory!");
+ return internal::Lazy<internal::RpcSessionImpl<>,
+ std::function<absl::string_view()>>(
+ std::move(service_name_factory));
+}
+
+} // namespace fuzztest
+#endif // FUZZTEST_RPC_FUZZING_RPC_SESSION_H_
diff --git a/rpc_fuzzing/rpc_session_test.cc b/rpc_fuzzing/rpc_session_test.cc
new file mode 100644
index 0000000..3b07361
--- /dev/null
+++ b/rpc_fuzzing/rpc_session_test.cc
@@ -0,0 +1,501 @@
+#include "./rpc_fuzzing/rpc_session.h"
+
+#include <cstdint>
+#include <optional>
+#include <variant>
+
+#include "google/protobuf/descriptor.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/random/random.h"
+#include "./domain_tests/domain_testing.h"
+#include "./fuzztest/internal/serialization.h"
+#include "./rpc_fuzzing/proto_field_path.h"
+#include "./rpc_fuzzing/rpc_potential_dfg.h"
+#include "./rpc_fuzzing/rpc_sequence.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.pb.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.grpc.pb.h"
+
+namespace fuzztest::internal {
+
+namespace {
+
+using ::testing::AnyOf;
+using ::testing::Conditional;
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::FieldsAre;
+using ::testing::IsTrue;
+using ::testing::NanSensitiveDoubleEq;
+using ::testing::Optional;
+using ::testing::Pair;
+using ::testing::ResultOf;
+using ::testing::UnorderedElementsAre;
+using ::testing::UnorderedElementsAreArray;
+using ::testing::VariantWith;
+
+template <typename T>
+auto ValueIs(const T& v) {
+ if constexpr (std::is_same_v<T, double>) {
+ return FieldsAre(VariantWith<double>(NanSensitiveDoubleEq(v)));
+ } else {
+ return FieldsAre(VariantWith<T>(v));
+ }
+}
+
+template <typename... T>
+auto SubsAre(const T&... v) {
+ return FieldsAre(VariantWith<std::vector<IRObject>>(ElementsAre(v...)));
+}
+
+class RpcDomainTest : public ::testing::Test {
+ public:
+ RpcDomainTest()
+ : pool_(*ABSL_DIE_IF_NULL(google::protobuf::DescriptorPool::generated_pool())),
+ mini_blogger_serivce_(*ABSL_DIE_IF_NULL(
+ pool_.FindServiceByName("fuzztest.internal.MiniBlogger"))),
+ log_in_user_method_(*ABSL_DIE_IF_NULL(
+ mini_blogger_serivce_.FindMethodByName("LogInUser"))),
+ log_out_user_method_(*ABSL_DIE_IF_NULL(
+ mini_blogger_serivce_.FindMethodByName("LogOutUser"))),
+ get_user_posts_method_(*ABSL_DIE_IF_NULL(
+ mini_blogger_serivce_.FindMethodByName("GetUserPosts"))) {}
+
+ protected:
+ RpcNode GetLogInUserNode() const {
+ return RpcNode(log_in_user_method_, std::make_unique<LogInUserRequest>());
+ }
+
+ RpcNode GetLogOutUserNode() const {
+ return RpcNode(log_out_user_method_, std::make_unique<LogOutUserRequest>());
+ }
+
+ RpcNode GetGetUserPostsNode() const {
+ return RpcNode(get_user_posts_method_,
+ std::make_unique<GetUserPostsRequest>());
+ }
+
+ RpcDataFlowGraph CreateRandomGraph(RpcSessionImpl<MiniBlogger>& rpc_domain) {
+ absl::BitGen gen;
+ RpcDataFlowGraph graph = rpc_domain.Init(gen);
+ for (int i = 0; i < 100; ++i) {
+ rpc_domain.Mutate(graph, gen, false);
+ }
+ // Generate an order.
+ (void)graph.GetOrderedNodeIds();
+ return graph;
+ }
+
+ const google::protobuf::DescriptorPool& pool_;
+ const google::protobuf::ServiceDescriptor& mini_blogger_serivce_;
+ const google::protobuf::MethodDescriptor& log_in_user_method_;
+ const google::protobuf::MethodDescriptor& log_out_user_method_;
+ const google::protobuf::MethodDescriptor& get_user_posts_method_;
+};
+
+TEST_F(RpcDomainTest, InitGeneratesARandomSingleCallWithoutDependencies) {
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ absl::BitGen bitgen;
+ for (int i = 0; i < 100; ++i) {
+ RpcDataFlowGraph value = rpc_domain.Init(bitgen);
+ EXPECT_THAT(value.GetAllNodes(),
+ ElementsAre(Pair(0, ResultOf(
+ [](const RpcNode& node) {
+ return node.dependencies().empty();
+ },
+ IsTrue()))));
+ }
+}
+
+TEST_F(RpcDomainTest, MutationOnlyInsertsNodesThatDependOnExistingNodes) {
+ constexpr std::string_view kMethodsDependingLogInUser[] = {"GetUserPosts",
+ "LogOutUser"};
+
+ RpcDataFlowGraph graph;
+ graph.AddNode(0, GetLogInUserNode());
+ absl::flat_hash_set<std::string> inserted_methods;
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ absl::BitGen bitgen;
+ for (int j = 0; j < 100; ++j) {
+ RpcDataFlowGraph mutated_graph = graph;
+ rpc_domain.Mutate(mutated_graph, bitgen, false);
+ if (mutated_graph.NodeNum() != 2) continue;
+ // Mutated through insertion.
+ inserted_methods.insert(mutated_graph.GetSequence()[1].method().name());
+ }
+ EXPECT_THAT(inserted_methods,
+ UnorderedElementsAreArray(kMethodsDependingLogInUser));
+}
+
+TEST_F(RpcDomainTest, InsertedNodesHaveDependenciesOnExistingNodes) {
+ constexpr RpcNodeID kFromNodeID = 0;
+ RpcDataFlowGraph graph;
+ graph.AddNode(kFromNodeID, GetLogInUserNode());
+
+ bool insertion_triggerred = false;
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ absl::BitGen bitgen;
+ for (int j = 0; j < 100; ++j) {
+ RpcDataFlowGraph mutated_graph = graph;
+ rpc_domain.Mutate(mutated_graph, bitgen, false);
+ if (mutated_graph.NodeNum() != 2) continue;
+ // Mutated through insertion.
+ insertion_triggerred = true;
+
+ const RpcNode inserted_node = mutated_graph.GetSequence()[1];
+ EXPECT_THAT(
+ inserted_node.dependencies(),
+ ElementsAre(FieldsAre(
+ kFromNodeID, GetFieldPath<LogInUserResponse>("session_id"),
+ Conditional(inserted_node.method().name() == "LogOutUser",
+ AnyOf(GetFieldPath<LogOutUserRequest>(
+ "log_out_info.session_info.session_id"),
+ GetFieldPath<LogOutUserRequest>(
+ "log_out_info.session_id")),
+ GetFieldPath<GetUserPostsRequest>("session_id")))));
+ }
+
+ EXPECT_TRUE(insertion_triggerred);
+}
+
+TEST_F(RpcDomainTest, InsertedNodeSelectsRandomAlternativeDepWithinSameOneOf) {
+ constexpr RpcNodeID kFromNodeID = 0;
+ RpcDataFlowGraph graph;
+ graph.AddNode(kFromNodeID, GetLogInUserNode());
+
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ absl::flat_hash_set<std::string> sink_fields;
+ absl::BitGen bitgen;
+ for (int j = 0; j < 200; ++j) {
+ RpcDataFlowGraph mutated_graph = graph;
+ rpc_domain.Mutate(mutated_graph, bitgen, false);
+ if (mutated_graph.NodeNum() != 2 ||
+ mutated_graph.GetSequence()[1].method().name() != "LogOutUser")
+ continue;
+ // Either `log_out_info.session_info.session_id` or
+ // `log_out_info.session_id` will be selected as the sink.
+ EXPECT_EQ(mutated_graph.GetSequence()[1].dependencies().size(), 1);
+ sink_fields.insert(
+ mutated_graph.GetSequence()[1].dependencies()[0].to_field.ToString());
+ }
+ EXPECT_THAT(sink_fields,
+ UnorderedElementsAre("log_out_info.session_info.session_id",
+ "log_out_info.session_id"));
+}
+
+TEST_F(RpcDomainTest, MutationOnlyDeletesTailNodes) {
+ constexpr std::string_view kMethodsDependingLogInUser[] = {"GetUserPosts",
+ "LogOutUser"};
+ // Node id 0.
+ RpcNode log_in_user_node = GetLogInUserNode();
+ // Node id 1.
+ RpcNode log_out_user_node = GetLogOutUserNode();
+ // Node id 2.
+ RpcNode get_user_posts_node = GetGetUserPostsNode();
+
+ log_out_user_node.AddDependency(RpcDataFlowEdge{
+ 0, GetFieldPath<LogInUserResponse>("session_id"),
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")});
+
+ get_user_posts_node.AddDependency(
+ RpcDataFlowEdge{0, GetFieldPath<LogInUserResponse>("session_id"),
+ GetFieldPath<GetUserPostsRequest>("session_id")});
+
+ RpcDataFlowGraph graph;
+ graph.AddNode(0, log_in_user_node);
+ graph.AddNode(1, log_out_user_node);
+ graph.AddNode(2, get_user_posts_node);
+
+ absl::flat_hash_set<std::string> deleted_methods;
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ absl::BitGen bitgen;
+
+ for (int j = 0; j < 100; ++j) {
+ RpcDataFlowGraph mutated_graph = graph;
+ rpc_domain.Mutate(mutated_graph, bitgen, false);
+ if (mutated_graph.NodeNum() != 2) continue;
+ // Mutated through deletion.
+ if (mutated_graph.GetSequence()[1].method().name() == "LogOutUser") {
+ deleted_methods.insert("GetUserPosts");
+ } else {
+ deleted_methods.insert("LogOutUser");
+ }
+ }
+ EXPECT_THAT(deleted_methods,
+ UnorderedElementsAreArray(kMethodsDependingLogInUser));
+}
+
+TEST_F(RpcDomainTest, OnlyShrinkMutationDecreasesNodeNumOrRequest) {
+ RpcNode log_in_user_node = GetLogInUserNode();
+ auto get_user_post_request = std::make_unique<GetUserPostsRequest>();
+ get_user_post_request->set_max_posts(100);
+ RpcNode log_out_user_node(get_user_posts_method_,
+ std::move(get_user_post_request));
+ log_out_user_node.AddDependency(RpcDataFlowEdge{
+ 0, GetFieldPath<LogInUserResponse>("session_id"),
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")});
+
+ RpcDataFlowGraph graph;
+ graph.AddNode(0, log_in_user_node);
+ graph.AddNode(1, log_out_user_node);
+
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ absl::BitGen bitgen;
+ bool static_field_shrinked = false;
+ for (int i = 0; i < 100; ++i) {
+ RpcDataFlowGraph mutated_graph = graph;
+ rpc_domain.Mutate(mutated_graph, bitgen, true);
+
+ if (mutated_graph.NodeNum() == 2) {
+ const RpcNode& node = mutated_graph.GetNode(1);
+ int new_max_posts = node.request().GetReflection()->GetInt32(
+ node.request(),
+ node.request().GetDescriptor()->FindFieldByName("max_posts"));
+ EXPECT_LE(new_max_posts, 100);
+ if (new_max_posts < 100) {
+ static_field_shrinked = true;
+ }
+ } else {
+ EXPECT_EQ(mutated_graph.NodeNum(), 1);
+ }
+ }
+ EXPECT_TRUE(static_field_shrinked);
+}
+
+TEST_F(RpcDomainTest, MutateEventuallyChangesStaticFieldInRequest) {
+ RpcDataFlowGraph graph;
+ graph.AddNode(0, GetLogInUserNode());
+
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ absl::BitGen bitgen;
+ bool static_field_changed = false;
+ for (int i = 0; i < 100; ++i) {
+ RpcDataFlowGraph mutated_graph = graph;
+ rpc_domain.Mutate(mutated_graph, bitgen, false);
+ if (mutated_graph.NodeNum() == graph.NodeNum() &&
+ mutated_graph.GetNode(0).method().name() ==
+ graph.GetNode(0).method().name()) {
+ // Mutated through static field mutation.
+ if (!google::protobuf::util::MessageDifferencer::Equals(
+ mutated_graph.GetNode(0).request(), graph.GetNode(0).request())) {
+ static_field_changed = true;
+ break;
+ }
+ }
+ }
+
+ EXPECT_TRUE(static_field_changed);
+}
+
+TEST_F(RpcDomainTest, FromValueTransformsSequenceToRpcGraph) {
+ RpcSequence sequence = {GetLogInUserNode(), GetLogOutUserNode(),
+ GetGetUserPostsNode()};
+
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ std::optional<RpcDataFlowGraph> graph = rpc_domain.FromValue(sequence);
+
+ ASSERT_TRUE(graph.has_value());
+
+ ASSERT_EQ(graph->NodeNum(), 3);
+ std::vector<RpcNode> all_nodes{graph->GetNode(0), graph->GetNode(1),
+ graph->GetNode(2)};
+ EXPECT_EQ(all_nodes, sequence);
+}
+
+TEST_F(RpcDomainTest, ParseCorpusInClearTextFromReturnsRpcDataFlowGraph) {
+ constexpr absl::string_view kObjectText =
+ R"(FUZZTESTv1
+ sub {
+ sub { s: "fuzztest.internal.MiniBlogger.LogInUser" }
+ sub { s: "" }
+ sub { }
+ }
+ sub {
+ sub { s: "fuzztest.internal.MiniBlogger.LogOutUser" }
+ sub { s: "" }
+ sub {
+ sub {
+ sub { i: 0 }
+ sub { s: "session_id" }
+ sub { s: "log_out_info.session_info.session_id" }
+ }
+ }
+ })";
+ std::optional<IRObject> obj = IRObject::FromString(kObjectText);
+ ASSERT_TRUE(obj.has_value());
+
+ RpcNode log_in_user_node = GetLogInUserNode();
+ RpcNode log_out_user_node = GetLogOutUserNode();
+ RpcDataFlowEdge edge = {
+ 0, GetFieldPath<LogInUserResponse>("session_id"),
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")};
+ log_out_user_node.AddDependency(edge);
+ RpcDataFlowGraph graph;
+ graph.AddNode(0, log_in_user_node);
+ graph.AddNode(1, log_out_user_node);
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ auto obj2 = rpc_domain.SerializeCorpus(graph);
+ EXPECT_EQ(obj->ToString(), obj2.ToString());
+}
+
+TEST_F(RpcDomainTest, ParseCorpusReturnsRpcDataFlowGraph) {
+ RpcNode log_in_user_node = GetLogInUserNode();
+ RpcNode log_out_user_node = GetLogOutUserNode();
+ RpcDataFlowEdge edge = {
+ 0, GetFieldPath<LogInUserResponse>("session_id"),
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")};
+ log_out_user_node.AddDependency(edge);
+
+ /* Set up log_in_user_node */
+ IRObject log_in_user_node_obj;
+ auto& log_in_user_node_subs = log_in_user_node_obj.MutableSubs();
+ log_in_user_node_subs.push_back(
+ IRObject::FromCorpus(log_in_user_node.method().full_name()));
+ log_in_user_node_subs.push_back(
+ IRObject::FromCorpus(log_in_user_node.request().SerializeAsString()));
+ log_in_user_node_subs.push_back(IRObject{});
+
+ /* Set up log_out_user_node*/
+ IRObject log_out_user_node_obj;
+ auto& log_out_user_node_subs = log_out_user_node_obj.MutableSubs();
+ log_out_user_node_subs.push_back(
+ IRObject::FromCorpus(log_out_user_node.method().full_name()));
+ log_out_user_node_subs.push_back(
+ IRObject::FromCorpus(log_out_user_node.request().SerializeAsString()));
+ IRObject edge_obj;
+ auto& edge_subs = edge_obj.MutableSubs();
+ edge_subs.push_back(IRObject::FromCorpus(edge.from_node_id));
+ edge_subs.push_back(IRObject::FromCorpus(edge.from_field.ToString()));
+ edge_subs.push_back(IRObject::FromCorpus(edge.to_field.ToString()));
+ IRObject edges;
+ edges.MutableSubs().push_back(edge_obj);
+ log_out_user_node_subs.push_back(edges);
+
+ IRObject obj;
+ auto& subs = obj.MutableSubs();
+ subs.push_back(log_in_user_node_obj);
+ subs.push_back(log_out_user_node_obj);
+
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ auto graph = rpc_domain.ParseCorpus(obj);
+ ASSERT_TRUE(graph.has_value());
+ EXPECT_EQ(graph->NodeNum(), 2);
+ EXPECT_EQ(graph->GetNode(0), log_in_user_node);
+ EXPECT_EQ(graph->GetNode(1), log_out_user_node);
+}
+
+TEST_F(RpcDomainTest, SerializesCorpusReturnsIRObjectOfSpecificStructure) {
+ RpcNode log_in_user_node = GetLogInUserNode();
+ RpcNode log_out_user_node = GetLogOutUserNode();
+ log_out_user_node.AddDependency(RpcDataFlowEdge{
+ 0, GetFieldPath<LogInUserResponse>("session_id"),
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")});
+ RpcDataFlowGraph graph;
+ graph.AddNode(0, log_in_user_node);
+ graph.AddNode(1, log_out_user_node);
+
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ IRObject obj = rpc_domain.SerializeCorpus(graph);
+
+ EXPECT_THAT(
+ obj,
+ SubsAre(
+ /*log_in_user_node*/
+ SubsAre(ValueIs<std::string>(log_in_user_method_.full_name()),
+ ValueIs<std::string>(
+ log_in_user_node.request().SerializeAsString()),
+ ValueIs<std::monostate>({})),
+ /*log_out_user_node*/
+ SubsAre(
+ ValueIs<std::string>(log_out_user_method_.full_name()),
+ ValueIs<std::string>(
+ log_out_user_node.request().SerializeAsString()),
+ /*dependencies*/
+ SubsAre(SubsAre(
+ ValueIs<std::uint64_t>(0) /*from_node_id*/,
+ ValueIs<std::string>("session_id") /*from_field*/,
+ ValueIs<
+ std::string>(/*to_field*/
+ "log_out_info.session_info.session_id"))))));
+}
+
+TEST_F(RpcDomainTest, SerializesCorpusAndParsesCorpusReturnTheSameObject) {
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ for (int i = 0; i < 100; ++i) {
+ RpcDataFlowGraph graph = CreateRandomGraph(rpc_domain);
+ EXPECT_THAT(rpc_domain.ParseCorpus(rpc_domain.SerializeCorpus(graph)),
+ Optional(Eq(graph)));
+ }
+}
+
+TEST_F(RpcDomainTest, ValidRpcNodeShouldOnlyDependOnPreviousNodes) {
+ RpcNode log_out_user_node = GetLogOutUserNode();
+ log_out_user_node.AddDependency(RpcDataFlowEdge{
+ 0, GetFieldPath<LogInUserResponse>("session_id"),
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")});
+
+ RpcDataFlowGraph graph;
+ graph.AddNode(0, log_out_user_node);
+
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ EXPECT_THAT(
+ rpc_domain.ValidateCorpusValue(graph),
+ IsInvalid(
+ "The dependencies should only come from previously executed nodes."));
+}
+
+TEST_F(RpcDomainTest, ValidRpcNodeDependencyMatchesPotentialDependency) {
+ RpcNode log_out_user_node = GetLogOutUserNode();
+ log_out_user_node.AddDependency(RpcDataFlowEdge{
+ 0, GetFieldPath<RegisterUserResponse>("success"),
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")});
+
+ RpcDataFlowGraph graph;
+ graph.AddNode(0, GetLogInUserNode());
+ graph.AddNode(1, log_out_user_node);
+
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ EXPECT_THAT(
+ rpc_domain.ValidateCorpusValue(graph),
+ IsInvalid(
+ "The dependency is not defined in the potential data flow graph."));
+}
+
+TEST_F(RpcDomainTest,
+ ValidRpcNodeShouldHaveAtMostOneDependencyForEachDynamicField) {
+ RpcNode log_in_user_node = GetLogInUserNode();
+ RpcNode log_out_user_node = GetLogOutUserNode();
+ log_out_user_node.AddDependency(RpcDataFlowEdge{
+ 0, GetFieldPath<LogInUserResponse>("session_id"),
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")});
+ log_out_user_node.AddDependency(RpcDataFlowEdge{
+ 0, GetFieldPath<LogInUserResponse>("session_id"),
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")});
+
+ RpcDataFlowGraph graph;
+ graph.AddNode(0, log_in_user_node);
+ graph.AddNode(1, log_out_user_node);
+
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ EXPECT_THAT(
+ rpc_domain.ValidateCorpusValue(graph),
+ IsInvalid("One sink field should have at most one concrete dependency!"));
+}
+
+TEST_F(RpcDomainTest, ValidRpcNodePassValidationTest) {
+ RpcNode log_in_user_node = GetLogInUserNode();
+ RpcNode log_out_user_node = GetLogOutUserNode();
+ log_out_user_node.AddDependency(RpcDataFlowEdge{
+ 0, GetFieldPath<LogInUserResponse>("session_id"),
+ GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")});
+
+ RpcDataFlowGraph graph;
+ graph.AddNode(0, log_in_user_node);
+ graph.AddNode(1, log_out_user_node);
+
+ RpcSessionImpl<MiniBlogger> rpc_domain;
+ EXPECT_OK(rpc_domain.ValidateCorpusValue(graph));
+}
+} // namespace
+
+} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_stub.h b/rpc_fuzzing/rpc_stub.h
new file mode 100644
index 0000000..3969912
--- /dev/null
+++ b/rpc_fuzzing/rpc_stub.h
@@ -0,0 +1,25 @@
+#ifndef FUZZTEST_RPC_FUZZING_RPC_STUB_H_
+#define FUZZTEST_RPC_FUZZING_RPC_STUB_H_
+
+#include <memory>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "absl/status/statusor.h"
+
+namespace fuzztest {
+
+// RPC service stub interface.
+class RpcStub {
+ public:
+ virtual ~RpcStub() = default;
+ // Calls the RPC method described by `method_descriptor`, sending it `request`
+ // and returning the obtained response.
+ virtual absl::StatusOr<std::unique_ptr<google::protobuf::Message>> CallMethod(
+ const google::protobuf::MethodDescriptor& method_descriptor,
+ const google::protobuf::Message& request) = 0;
+};
+
+} // namespace fuzztest
+
+#endif // FUZZTEST_RPC_FUZZING_RPC_STUB_H_
diff --git a/rpc_fuzzing/scaffolding_stub.cc b/rpc_fuzzing/scaffolding_stub.cc
new file mode 100644
index 0000000..b008bb7
--- /dev/null
+++ b/rpc_fuzzing/scaffolding_stub.cc
@@ -0,0 +1,58 @@
+#include "./rpc_fuzzing/scaffolding_stub.h"
+
+#include <memory>
+#include <string>
+
+#include "net/base/sslconstant.h"
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "net/rpc/anonymous-stub.h"
+#include "net/rpc2/rpc2.h"
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+
+namespace fuzztest {
+
+absl::StatusOr<std::unique_ptr<google::protobuf::Message>> ScaffoldingStub::CallMethod(
+ const google::protobuf::MethodDescriptor& method_descriptor,
+ const google::protobuf::Message& request) {
+ const google::protobuf::Message* response_prototype =
+ google::protobuf::MessageFactory::generated_factory()->GetPrototype(
+ method_descriptor.output_type());
+ if (!response_prototype) {
+ return absl::InternalError(
+ absl::StrCat("Cannot find prototype for ",
+ method_descriptor.output_type()->full_name(),
+ " in the generated proto MessageFactory"));
+ }
+ std::unique_ptr<google::protobuf::Message> response =
+ absl::WrapUnique(response_prototype->New());
+ if (!response) {
+ return absl::InternalError(
+ absl::StrCat("Cannot create a new instance of response type ",
+ method_descriptor.output_type()->full_name()));
+ }
+
+ RPC rpc;
+ // AnonymousStub doesn't have access to the method options, so we need to set
+ // this manually.
+ if (method_descriptor.options().has_security_level()) {
+ rpc.set_requested_security_level(net_base::SSLSecurityLevel(
+ method_descriptor.options().security_level()));
+ }
+ const google::protobuf::ServiceDescriptor& service_descriptor =
+ *method_descriptor.service();
+ const std::string method_name = absl::StrCat("/", service_descriptor.name(),
+ ".", method_descriptor.name());
+ stub_->Send(/*package_name=*/"", method_name, &rpc, &request, response.get(),
+ /* done= */ nullptr);
+ rpc.WaitRespectingFiberCancellation();
+ if (rpc.util_status().ok()) {
+ return response;
+ }
+ return rpc.util_status();
+}
+
+} // namespace fuzztest
diff --git a/rpc_fuzzing/scaffolding_stub.h b/rpc_fuzzing/scaffolding_stub.h
new file mode 100644
index 0000000..807c8d6
--- /dev/null
+++ b/rpc_fuzzing/scaffolding_stub.h
@@ -0,0 +1,31 @@
+#ifndef FUZZTEST_RPC_FUZZING_SCAFFOLDING_STUB_H_
+#define FUZZTEST_RPC_FUZZING_SCAFFOLDING_STUB_H_
+
+#include <memory>
+#include <utility>
+
+#include "net/rpc/anonymous-stub.h"
+#include "net/rpc2/rpc2.h"
+#include "absl/status/statusor.h"
+#include "./rpc_fuzzing/rpc_stub.h"
+
+namespace fuzztest {
+
+class ScaffoldingStub : public RpcStub {
+ public:
+ ScaffoldingStub(absl::string_view socket)
+ : stub_(std::make_unique<AnonymousStub>(
+ rpc2::CreateClientChannel(socket))) {}
+ ScaffoldingStub(std::unique_ptr<AnonymousStub> stub)
+ : stub_(std::move(stub)) {}
+ absl::StatusOr<std::unique_ptr<google::protobuf::Message>> CallMethod(
+ const google::protobuf::MethodDescriptor& method_descriptor,
+ const google::protobuf::Message& request) override;
+
+ private:
+ std::unique_ptr<AnonymousStub> stub_;
+};
+
+} // namespace fuzztest
+
+#endif // FUZZTEST_RPC_FUZZING_SCAFFOLDING_STUB_H_
diff --git a/rpc_fuzzing/testdata/BUILD b/rpc_fuzzing/testdata/BUILD
new file mode 100644
index 0000000..ba59853
--- /dev/null
+++ b/rpc_fuzzing/testdata/BUILD
@@ -0,0 +1,54 @@
+load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library")
+
+package(default_visibility = [
+ "@com_google_fuzztest//e2e_tests/testdata:__pkg__",
+ "@com_google_fuzztest//fuzztest:__subpackages__",
+ "@com_google_fuzztest//rpc_fuzzing:__pkg__",
+])
+
+proto_library(
+ name = "mini_blogger",
+ srcs = ["mini_blogger.proto"],
+ target_compatible_with = select({
+ "@com_google_fuzztest//rpc_fuzzing:rpc_fuzzing_is_enabled": [],
+ "//conditions:default": ["@platforms//:incompatible"],
+ }),
+)
+
+cc_proto_library(
+ name = "mini_blogger_cc_proto",
+ deps = [":mini_blogger"],
+)
+
+cc_grpc_library(
+ name = "mini_blogger_cc_grpc_proto",
+ srcs = [":mini_blogger"],
+ grpc_only = True,
+ deps = [":mini_blogger_cc_proto"],
+)
+
+cc_library(
+ name = "mini_blogger_grpc_service",
+ srcs = ["grpc/mini_blogger_service.cc"],
+ hdrs = ["grpc/mini_blogger_service.h"],
+ deps = [
+ ":mini_blogger_cc_grpc_proto",
+ ":mini_blogger_cc_proto",
+ ":mini_blogger_context",
+ "@com_github_grpc_grpc//:grpc++",
+ ],
+)
+
+cc_library(
+ name = "mini_blogger_context",
+ srcs = ["mini_blogger_context.cc"],
+ hdrs = ["mini_blogger_context.h"],
+ deps = [
+ ":mini_blogger_cc_proto",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/random",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
diff --git a/rpc_fuzzing/testdata/grpc/mini_blogger_server.cc b/rpc_fuzzing/testdata/grpc/mini_blogger_server.cc
new file mode 100644
index 0000000..a49bd61
--- /dev/null
+++ b/rpc_fuzzing/testdata/grpc/mini_blogger_server.cc
@@ -0,0 +1,36 @@
+// A grpc server that can run locally for testing.
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "base/init_google.h"
+#include "absl/base/log_severity.h"
+#include "absl/flags/flag.h"
+#include "absl/log/globals.h"
+#include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
+#include "./rpc_fuzzing/testdata/grpc/mini_blogger_service.h"
+#include "grpcpp//security/server_credentials.h"
+#include "grpcpp//server.h"
+#include "grpcpp//server_builder.h"
+
+ABSL_FLAG(int32_t, port, 5000, "port to listen on");
+
+int main(int argc, char** argv) {
+ InitGoogle(argv[0], &argc, &argv, /*remove_flags=*/true);
+
+ absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);
+
+ std::string server_address = absl::StrCat("[::]:", absl::GetFlag(FLAGS_port));
+
+ fuzztest::internal::MiniBloggerGrpcService mini_blogger;
+ grpc::ServerBuilder builder;
+ builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
+ builder.RegisterService(&mini_blogger);
+ std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
+ LOG(INFO) << "Server listening on " << server_address;
+
+ server->Wait();
+
+ return 0;
+}
diff --git a/rpc_fuzzing/testdata/grpc/mini_blogger_service.cc b/rpc_fuzzing/testdata/grpc/mini_blogger_service.cc
new file mode 100644
index 0000000..d6cc705
--- /dev/null
+++ b/rpc_fuzzing/testdata/grpc/mini_blogger_service.cc
@@ -0,0 +1,36 @@
+#include "./rpc_fuzzing/testdata/grpc/mini_blogger_service.h"
+
+#include "./rpc_fuzzing/testdata/mini_blogger.pb.h"
+#include "grpcpp//server_context.h"
+#include "grpcpp//support/status.h"
+
+namespace fuzztest::internal {
+
+grpc::Status MiniBloggerGrpcService::RegisterUser(
+ grpc::ServerContext*, const RegisterUserRequest* request,
+ RegisterUserResponse* response) {
+ mini_blogger_context_.RegisterUser(*request, *response);
+ return grpc::Status::OK;
+}
+
+grpc::Status MiniBloggerGrpcService::LogInUser(grpc::ServerContext*,
+ const LogInUserRequest* request,
+ LogInUserResponse* response) {
+ mini_blogger_context_.LogInUser(*request, *response);
+ return grpc::Status::OK;
+}
+
+grpc::Status MiniBloggerGrpcService::GetUserPosts(
+ grpc::ServerContext* unused_service_context,
+ const GetUserPostsRequest* request, GetUserPostsResponse* response) {
+ mini_blogger_context_.GetUserPosts(*request, *response);
+ return grpc::Status::OK;
+}
+
+grpc::Status MiniBloggerGrpcService::LogOutUser(
+ grpc::ServerContext* unused_service_context,
+ const LogOutUserRequest* request, LogOutUserResponse* response) {
+ mini_blogger_context_.LogOutUser(*request, *response);
+ return grpc::Status::OK;
+}
+} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/testdata/grpc/mini_blogger_service.h b/rpc_fuzzing/testdata/grpc/mini_blogger_service.h
new file mode 100644
index 0000000..d627b37
--- /dev/null
+++ b/rpc_fuzzing/testdata/grpc/mini_blogger_service.h
@@ -0,0 +1,34 @@
+#ifndef FUZZTEST_RPC_FUZZING_TESTDATA_MINI_BLOGGER_GRPC_SERVICE_H_
+#define FUZZTEST_RPC_FUZZING_TESTDATA_MINI_BLOGGER_GRPC_SERVICE_H_
+
+#include "./rpc_fuzzing/testdata/mini_blogger.grpc.pb.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.pb.h"
+#include "./rpc_fuzzing/testdata/mini_blogger_context.h"
+#include "grpcpp//support/status.h"
+
+namespace fuzztest::internal {
+
+// The gRpc server implementation of MiniBlogger.
+class MiniBloggerGrpcService : public MiniBlogger::Service {
+ public:
+ grpc::Status RegisterUser(grpc::ServerContext*,
+ const RegisterUserRequest* request,
+ RegisterUserResponse* response) override;
+
+ grpc::Status LogInUser(grpc::ServerContext*, const LogInUserRequest* request,
+ LogInUserResponse* response) override;
+
+ grpc::Status GetUserPosts(grpc::ServerContext*,
+ const GetUserPostsRequest* request,
+ GetUserPostsResponse* response) override;
+
+ grpc::Status LogOutUser(grpc::ServerContext*,
+ const LogOutUserRequest* request,
+ LogOutUserResponse* response) override;
+
+ private:
+ MiniBloggerContext mini_blogger_context_;
+};
+} // namespace fuzztest::internal
+
+#endif // FUZZTEST_RPC_FUZZING_TESTDATA_MINI_BLOGGER_GRPC_SERVICE_H_
diff --git a/rpc_fuzzing/testdata/mini_blogger.proto b/rpc_fuzzing/testdata/mini_blogger.proto
new file mode 100644
index 0000000..50c2da0
--- /dev/null
+++ b/rpc_fuzzing/testdata/mini_blogger.proto
@@ -0,0 +1,127 @@
+syntax = "proto3";
+
+package fuzztest.internal;
+
+service MiniBlogger {
+ // Register a user.
+ rpc RegisterUser(RegisterUserRequest) returns (RegisterUserResponse);
+ // Log in a user and get a session id.
+ rpc LogInUser(LogInUserRequest) returns (LogInUserResponse);
+ // Get a user's posts by session id.
+ rpc GetUserPosts(GetUserPostsRequest) returns (GetUserPostsResponse);
+ // Log out the user (invaliding the session id).
+ rpc LogOutUser(LogOutUserRequest) returns (LogOutUserResponse);
+}
+
+message RegisterUserRequest {
+ string user_name = 1;
+ string email = 2;
+ string password = 3;
+}
+
+message RegisterUserResponse {
+ bool success = 1;
+}
+
+message LogInUserRequest {
+ string name = 1;
+ string password = 2;
+}
+
+message LogInUserResponse {
+ bool success = 1;
+ // A unique id for the user logging in session.
+ int64 session_id = 2;
+}
+
+enum SortingOrder {
+ ASCENDING = 0;
+ DESCENDING = 1;
+}
+
+enum SortedBy {
+ DATE = 0;
+ POPULARITY = 1;
+}
+
+// For testing only.
+message GetUserPostsOptionsWithDifferentFieldNames {
+ SortingOrder ordering = 1;
+}
+
+message GetUserPostsOptions {
+ SortingOrder order = 1;
+ SortedBy sort_by = 2;
+}
+
+message GetUserPostsRequest {
+ int64 session_id = 1;
+ int32 max_posts = 2; // Maximum number of posts returned.
+ GetUserPostsOptions options = 3;
+}
+
+message GetUserPostsResponse {
+ repeated string posts = 1;
+}
+
+// For testing dependency established at inner fields.
+message LogOutInfo {
+ oneof log_out_info {
+ SessionInfo session_info = 1;
+ int64 session_id = 2;
+ }
+}
+
+message SessionInfo {
+ int64 session_id = 1;
+}
+
+message LogOutUserRequest {
+ LogOutInfo log_out_info = 1;
+}
+
+message LogOutUserResponse {
+ bool success = 1;
+}
+
+// This is only for testing that the util of collecting all fields in a message
+// won't break with recursive messages.
+message RecursiveNode {
+ int64 value = 1;
+ repeated RecursiveNode children = 2;
+}
+
+// For testing only.
+message OneOfMessage {
+ oneof oneof1 {
+ DummyMessage oneof1_field1 = 1;
+ int64 oneof1_field2 = 2;
+ }
+ oneof oneof2 {
+ DummyMessage oneof2_field1 = 3;
+ int64 oneof2_field2 = 4;
+ }
+}
+
+message DummyMessage {
+ int64 v1 = 1;
+ int64 v2 = 2;
+}
+
+message InnerRepeated {
+ int64 field2 = 1;
+}
+
+message MessageWithSingleInnerRepeated {
+ InnerRepeated repeated_msg_field = 1;
+}
+
+message MessageContainingRepeatedFields {
+ repeated int64 field = 1;
+ repeated InnerRepeated repeated_msg_field = 2;
+}
+
+message MessageNotContainingRepeatedFields {
+ int64 field = 1;
+ int64 field2 = 2;
+}
diff --git a/rpc_fuzzing/testdata/mini_blogger_context.cc b/rpc_fuzzing/testdata/mini_blogger_context.cc
new file mode 100644
index 0000000..d46863b
--- /dev/null
+++ b/rpc_fuzzing/testdata/mini_blogger_context.cc
@@ -0,0 +1,102 @@
+#include "./rpc_fuzzing/testdata/mini_blogger_context.h"
+
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/random/random.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.pb.h"
+
+namespace fuzztest::internal {
+
+void MiniBloggerContext::RegisterUser(const RegisterUserRequest& request,
+ RegisterUserResponse& response) {
+ bool result = RegisterUserHandler(request.user_name(), request.email(),
+ request.password());
+ response.set_success(result);
+}
+
+void MiniBloggerContext::LogInUser(const LogInUserRequest& request,
+ LogInUserResponse& response) {
+ std::optional<std::int64_t> result =
+ LoginUserHandler(request.name(), request.password());
+ if (!result) {
+ response.set_success(false);
+ } else {
+ response.set_success(true);
+ response.set_session_id(result.value());
+ }
+}
+
+void MiniBloggerContext::GetUserPosts(const GetUserPostsRequest& request,
+ GetUserPostsResponse& response) {
+ for (const std::string& post : GetUserPostsHandler(request.session_id())) {
+ response.add_posts(post);
+ }
+}
+
+void MiniBloggerContext::LogOutUser(const LogOutUserRequest& request,
+ LogOutUserResponse& response) {
+ if (!request.has_log_out_info() ||
+ !request.log_out_info().has_session_info()) {
+ response.set_success(false);
+ } else {
+ bool result =
+ LogOutUserHandler(request.log_out_info().session_info().session_id());
+ response.set_success(result);
+ }
+}
+bool MiniBloggerContext::RegisterUserHandler(absl::string_view name,
+ absl::string_view email,
+ absl::string_view passwd) {
+ if (name.empty() || email.empty() || passwd.empty()) return false;
+ absl::WriterMutexLock l(&lock_);
+ if (users_.contains(name)) return false;
+ users_.emplace(std::string(name),
+ MiniBloggerUser{std::string(name), std::string(email),
+ std::string(passwd)});
+ return true;
+}
+
+std::optional<std::int64_t> MiniBloggerContext::LoginUserHandler(
+ absl::string_view name, absl::string_view passwd) {
+ absl::WriterMutexLock l(&lock_);
+ if (!name.empty()) {
+ if (!users_.contains(name)) return std::nullopt;
+ if (users_[name].passwd != passwd) return std::nullopt;
+ }
+ absl::BitGen gen;
+ std::int64_t sid =
+ absl::Uniform<int64_t>(gen, 0, std::numeric_limits<int64_t>::max());
+ active_session_ids_.insert(sid);
+ return sid;
+}
+
+bool MiniBloggerContext::LogOutUserHandler(std::int64_t sid) {
+ absl::WriterMutexLock l(&lock_);
+ if (!active_session_ids_.contains(sid)) return false;
+ inactive_session_ids_.insert(sid);
+ active_session_ids_.erase(sid);
+ return true;
+}
+
+std::vector<std::string> MiniBloggerContext::GetUserPostsHandler(
+ std::int64_t sid) {
+ std::vector<std::string> posts;
+ absl::ReaderMutexLock l(&lock_);
+ if (inactive_session_ids_.contains(sid)) {
+ std::cerr << "Using an inactive session id!\n";
+ std::abort();
+ }
+ if (active_session_ids_.contains(sid)) {
+ posts.push_back("Random post");
+ }
+ return posts;
+}
+
+} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/testdata/mini_blogger_context.h b/rpc_fuzzing/testdata/mini_blogger_context.h
new file mode 100644
index 0000000..7bde0ac
--- /dev/null
+++ b/rpc_fuzzing/testdata/mini_blogger_context.h
@@ -0,0 +1,59 @@
+#ifndef THIRD_PARTY_GOOGLEFUZZTEST_rpc_fuzzing_RCP_TEST_SERVER_CONTEXT_H_
+#define THIRD_PARTY_GOOGLEFUZZTEST_rpc_fuzzing_RCP_TEST_SERVER_CONTEXT_H_
+
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.pb.h"
+
+namespace fuzztest::internal {
+
+struct MiniBloggerUser {
+ std::string name;
+ std::string email;
+ std::string passwd;
+};
+
+// The real handling logics of MiniBlogger, shared by by gRPC and other
+// frameworks.
+class MiniBloggerContext {
+ public:
+ // Adds a new user into the database.
+ void RegisterUser(const RegisterUserRequest& request,
+ RegisterUserResponse& response);
+
+ // Logs in a user if the user info exists in the database and generates a
+ // session id.
+ void LogInUser(const LogInUserRequest& request, LogInUserResponse& response);
+
+ // Gets the posts for a user given the session id.
+ void GetUserPosts(const GetUserPostsRequest& request,
+ GetUserPostsResponse& response);
+
+ // Logs out a user given the session id. The session id will be invalidated.
+ void LogOutUser(const LogOutUserRequest& request,
+ LogOutUserResponse& response);
+
+ private:
+ bool RegisterUserHandler(absl::string_view name, absl::string_view email,
+ absl::string_view passwd);
+ std::optional<std::int64_t> LoginUserHandler(absl::string_view name,
+ absl::string_view passwd);
+ bool LogOutUserHandler(std::int64_t sid);
+ std::vector<std::string> GetUserPostsHandler(std::int64_t sid);
+
+ absl::flat_hash_map<std::string /*name*/, MiniBloggerUser> users_;
+ absl::flat_hash_set<std::int64_t> active_session_ids_;
+ absl::flat_hash_set<std::int64_t> inactive_session_ids_;
+ absl::Mutex lock_; // Serialize operation.
+};
+
+} // namespace fuzztest::internal
+
+#endif // THIRD_PARTY_GOOGLEFUZZTEST_rpc_fuzzing_RCP_TEST_SERVER_CONTEXT_H_