No public description PiperOrigin-RevId: 554371233
diff --git a/.github/workflows/bazel_test.yml b/.github/workflows/bazel_test.yml index eca5cf6..580e546 100644 --- a/.github/workflows/bazel_test.yml +++ b/.github/workflows/bazel_test.yml
@@ -28,7 +28,7 @@ run_tests: name: Run tests runs-on: ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 60 strategy: matrix: config: ['default', 'fuzztest'] @@ -62,10 +62,10 @@ if: matrix.config == 'default' run: | bazel test --build_tests_only --test_output=errors \ - -c ${{ matrix.compilation_mode }} -- //... -//centipede/... + -c ${{ matrix.compilation_mode }} --//rpc_fuzzing:enable_rpc_fuzzing -- //... -//centipede/... - name: Run end-to-end tests with --config=fuzztest if: matrix.config == 'fuzztest' run: | bazel test --build_tests_only --test_output=errors \ - -c ${{ matrix.compilation_mode }} --config=fuzztest //e2e_tests:all + -c ${{ matrix.compilation_mode }} --config=fuzztest //e2e_tests:all --//rpc_fuzzing:enable_rpc_fuzzing
diff --git a/.github/workflows/cmake_test.yml b/.github/workflows/cmake_test.yml index 55bc145..257b8ca 100644 --- a/.github/workflows/cmake_test.yml +++ b/.github/workflows/cmake_test.yml
@@ -26,7 +26,7 @@ run_tests: name: Run CMake tests runs-on: ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 60 strategy: matrix: mode: ['default', 'fuzzing']
diff --git a/WORKSPACE b/WORKSPACE index b685cdc..d8a02a2 100644 --- a/WORKSPACE +++ b/WORKSPACE
@@ -36,9 +36,9 @@ http_archive( name = "com_google_absl", - sha256 = "3ea49a7d97421b88a8c48a0de16c16048e17725c7ec0f1d3ea2683a2a75adc21", - strip_prefix = "abseil-cpp-20230125.0", - url = "https://github.com/abseil/abseil-cpp/archive/refs/tags/20230125.0.tar.gz", + sha256 = "5366d7e7fa7ba0d915014d387b66d0d002c03236448e1ba9ef98122c13b35c36", + strip_prefix = "abseil-cpp-20230125.3", + url = "https://github.com/abseil/abseil-cpp/archive/refs/tags/20230125.3.tar.gz", ) http_archive( @@ -55,6 +55,18 @@ url = "https://www.antlr.org/download/antlr4-cpp-runtime-4.12.0-source.zip", ) +http_archive( + name = "com_github_grpc_grpc", + urls = [ + "https://github.com/grpc/grpc/archive/49dcbce9aba5f5f0959871354df90e38f0c5ed00.tar.gz", # v1.55.3 + ], + strip_prefix = "grpc-49dcbce9aba5f5f0959871354df90e38f0c5ed00", +) +load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") +grpc_deps() +load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps") +grpc_extra_deps() + ################################################################################ # Direct dependencies that are only required for running tests ################################################################################
diff --git a/domain_tests/BUILD b/domain_tests/BUILD index 4e24aa6..28c201f 100644 --- a/domain_tests/BUILD +++ b/domain_tests/BUILD
@@ -22,7 +22,10 @@ name = "domain_testing", testonly = 1, hdrs = ["domain_testing.h"], - visibility = ["@com_google_fuzztest//fuzztest:__subpackages__"], + visibility = [ + "@com_google_fuzztest//fuzztest:__subpackages__", + "@com_google_fuzztest//rpc_fuzzing:__pkg__", + ], deps = [ "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash",
diff --git a/e2e_tests/BUILD b/e2e_tests/BUILD index b146dd3..3df931e 100644 --- a/e2e_tests/BUILD +++ b/e2e_tests/BUILD
@@ -26,6 +26,7 @@ srcs = ["functional_test.cc"], data = [ "@com_google_fuzztest//e2e_tests/testdata:fuzz_tests_for_functional_testing.stripped", + "@com_google_fuzztest//e2e_tests/testdata:fuzz_tests_for_rpc_session_grpc_functional_testing.stripped", "@com_google_fuzztest//e2e_tests/testdata:fuzz_tests_with_invalid_seeds.stripped", ], shard_count = 50,
diff --git a/e2e_tests/functional_test.cc b/e2e_tests/functional_test.cc index 070451e..e54fca8 100644 --- a/e2e_tests/functional_test.cc +++ b/e2e_tests/functional_test.cc
@@ -1273,6 +1273,28 @@ EXPECT_EQ(1, CountSubstrs(std_err, "<<FixtureTest::~FixtureTest()>>")); } +TEST_F(FuzzingModeTest, RpcSessionFuzzingSupportsFuzzingGrpcService) { + auto [status, std_out, std_err] = RunBinaryWith( + BinaryPath("testdata/fuzz_tests_for_rpc_session_grpc_functional_testing"), + "--fuzz=MiniBloggerGrpcTest." + "ServiceDoesNotCrashWithAnyRpcSequence", + {}, /*timeout=*/absl::Seconds(60)); + EXPECT_THAT(std_err, Not(HasSubstr("Failed to execute !"))); + EXPECT_THAT(std_err, HasSubstr("Using an inactive session id!")); + EXPECT_THAT(status, Eq(Signal(SIGABRT))); +} + +TEST_F(FuzzingModeTest, RpcSessionOfForGrpcFindsStatefulBugInFuzzingMode) { + auto [status, std_out, std_err] = RunBinaryWith( + BinaryPath("testdata/fuzz_tests_for_rpc_session_grpc_functional_testing"), + "--fuzz=MiniBloggerGrpcTest." + "TestRpcSessionOfSetup", + {{"FUZZTEST_MAX_FUZZING_RUNS", "-1"}}, /*timeout=*/absl::Seconds(60)); + EXPECT_THAT(std_err, Not(HasSubstr("Failed to execute !"))); + EXPECT_THAT(std_err, HasSubstr("Using an inactive session id!")); + EXPECT_THAT(status, Eq(Signal(SIGABRT))); +} + TEST_F(FuzzingModeTest, GoogleTestPerIterationFixtureInstantiatedOncePerIteration) { auto [status, std_out, std_err] = RunWith(
diff --git a/e2e_tests/testdata/BUILD b/e2e_tests/testdata/BUILD index cd3896e..6cc5059 100644 --- a/e2e_tests/testdata/BUILD +++ b/e2e_tests/testdata/BUILD
@@ -59,6 +59,26 @@ ], ) +cc_binary( + name = "fuzz_tests_for_rpc_session_grpc_functional_testing", + testonly = 1, + srcs = [ + "fuzz_tests_for_rpc_session_grpc_functional_testing.cc", + ], + deps = [ + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_fuzztest//fuzztest", + "@com_google_fuzztest//fuzztest:fuzztest_gtest_main", + "@com_google_fuzztest//rpc_fuzzing:grpc_stub", + "@com_google_fuzztest//rpc_fuzzing:rpc_session", + "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_grpc_proto", + "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_proto", + "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_grpc_service", + ], +) + # Must be a separate binary, as the invalid seeds make the binary crash at # process start time. Used by `functional_test` only. cc_binary(
diff --git a/e2e_tests/testdata/fuzz_tests_for_rpc_session_grpc_functional_testing.cc b/e2e_tests/testdata/fuzz_tests_for_rpc_session_grpc_functional_testing.cc
new file mode 100644
index 0000000..fc4f295
--- /dev/null
+++ b/e2e_tests/testdata/fuzz_tests_for_rpc_session_grpc_functional_testing.cc
@@ -0,0 +1,91 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// RPC session fuzz test examples to be used for e2e functional testing.
+//
+// Specifically, used by `functional_test` only.
+
+#include <cstdlib>
+#include <iostream>
+#include <memory>
+#include <string_view>
+
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "./fuzztest/fuzztest.h"
+#include "./rpc_fuzzing/grpc_stub.h"
+#include "./rpc_fuzzing/rpc_session.h"
+#include "./rpc_fuzzing/testdata/grpc/mini_blogger_service.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.grpc.pb.h"
+#include "grpcpp//security/server_credentials.h"
+#include "grpcpp//server.h"
+#include "grpcpp//server_builder.h"
+#include "grpcpp//support/channel_arguments.h"
+
+namespace {
+
+std::string_view ServiceNameFactory() {
+  return "fuzztest.internal.MiniBlogger";
+}
+
+// Builds and starts a server for `service` (used only via its in-process
+// channel, so no listening port is added). `BuildAndStart()` returns nullptr
+// on failure; abort with a diagnostic rather than dereferencing it.
+std::unique_ptr<grpc::Server> StartServerOrDie(grpc::Service& service) {
+  std::unique_ptr<grpc::Server> server =
+      grpc::ServerBuilder().RegisterService(&service).BuildAndStart();
+  if (server == nullptr) {
+    std::cerr << "Failed to start the MiniBlogger gRPC server.\n";
+    std::abort();
+  }
+  return server;
+}
+
+// Fixture that serves MiniBlogger in-process and talks to it via a GrpcStub.
+class MiniBloggerGrpcTest {
+ public:
+  MiniBloggerGrpcTest()
+      : server_(StartServerOrDie(mini_blogger_service_)),
+        stub_(server_->InProcessChannel(grpc::ChannelArguments())) {}
+
+  ~MiniBloggerGrpcTest() { server_->Shutdown(); }
+
+  void ServiceDoesNotCrashWithAnyRpcSequence(
+      fuzztest::RpcSequence rpc_sequence) {
+    absl::Status status = fuzztest::ExecuteRpcSequence(stub_, rpc_sequence);
+    if (!status.ok()) {
+      std::cerr << "Failed to execute !" << status.message() << "\n";
+      std::abort();
+    }
+  }
+
+  void TestRpcSessionOfSetup(fuzztest::RpcSequence rpc_sequence) {
+    ServiceDoesNotCrashWithAnyRpcSequence(rpc_sequence);
+  }
+
+ private:
+  // Declaration order matters: members are initialized in this order, and the
+  // server needs the service while the stub needs the running server.
+  fuzztest::internal::MiniBloggerGrpcService mini_blogger_service_;
+  std::unique_ptr<grpc::Server> server_ = nullptr;
+  fuzztest::GrpcStub stub_;
+};
+
+FUZZ_TEST_F(MiniBloggerGrpcTest, ServiceDoesNotCrashWithAnyRpcSequence)
+    .WithDomains(fuzztest::RpcSession<fuzztest::internal::MiniBlogger>());
+
+FUZZ_TEST_F(MiniBloggerGrpcTest, TestRpcSessionOfSetup)
+    .WithDomains(fuzztest::RpcSessionOf(ServiceNameFactory));
+
+}  // namespace
diff --git a/fuzztest/BUILD b/fuzztest/BUILD index 041e9cf..e75c0c9 100644 --- a/fuzztest/BUILD +++ b/fuzztest/BUILD
@@ -128,6 +128,7 @@ visibility = [ "@com_google_fuzztest//centipede:__subpackages__", "@com_google_fuzztest//domain_tests:__pkg__", + "@com_google_fuzztest//rpc_fuzzing:__pkg__", ], deps = [ ":absl_helpers",
diff --git a/rpc_fuzzing/BUILD b/rpc_fuzzing/BUILD new file mode 100644 index 0000000..1940b24 --- /dev/null +++ b/rpc_fuzzing/BUILD
@@ -0,0 +1,221 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = [ + "@com_google_fuzztest//e2e_tests/testdata:__pkg__", + "@com_google_fuzztest//fuzztest:__subpackages__", +]) + +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") + +bool_flag( + name = "enable_rpc_fuzzing", + build_setting_default = False, + visibility = ["//visibility:public"], +) + +config_setting( + name = "rpc_fuzzing_is_enabled", + flag_values = { + ":enable_rpc_fuzzing": "true", + }, +) + +cc_library( + name = "proto_field_path", + srcs = ["proto_field_path.cc"], + hdrs = ["proto_field_path.h"], + target_compatible_with = select({ + ":rpc_fuzzing_is_enabled": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_fuzztest//fuzztest:logging", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_field_path_test", + srcs = ["proto_field_path_test.cc"], + deps = [ + ":proto_field_path", + "@com_google_fuzztest//fuzztest:test_protobuf_cc_proto", + "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_proto", + "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "rpc_sequence", + srcs = ["rpc_sequence.cc"], + hdrs = ["rpc_sequence.h"], + deps = [ + ":proto_field_path", + "@com_google_absl//absl/container:flat_hash_map", + 
"@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/random", + "@com_google_absl//absl/random:bit_gen_ref", + "@com_google_fuzztest//fuzztest:logging", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "rpc_sequence_test", + srcs = ["rpc_sequence_test.cc"], + deps = [ + ":proto_field_path", + ":rpc_sequence", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_proto", + "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:protobuf", + ], +) + +# TODO(changochen): Combine the stubs into a single lib. +cc_library( + name = "rpc_stub", + hdrs = ["rpc_stub.h"], + deps = [ + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "grpc_stub", + srcs = ["grpc_stub.cc"], + hdrs = ["grpc_stub.h"], + deps = [ + ":rpc_stub", + "@com_github_grpc_grpc//:grpc++", + "@com_github_grpc_grpc//:grpc++_codegen_proto", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "rpc_executor", + srcs = ["rpc_executor.cc"], + hdrs = ["rpc_executor.h"], + deps = [ + ":rpc_sequence", + ":rpc_stub", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_fuzztest//fuzztest:logging", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "rpc_executor_test", + srcs = ["rpc_executor_test.cc"], + deps = [ + ":grpc_stub", + ":proto_field_path", + ":rpc_executor", + ":rpc_sequence", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + 
"@com_google_fuzztest//domain_tests:domain_testing", + "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_proto", + "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_grpc_service", + "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "rpc_potential_dfg", + srcs = ["rpc_potential_dfg.cc"], + hdrs = ["rpc_potential_dfg.h"], + deps = [ + ":proto_field_path", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_fuzztest//fuzztest:logging", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "rpc_potential_dfg_test", + srcs = ["rpc_potential_dfg_test.cc"], + deps = [ + ":proto_field_path", + ":rpc_potential_dfg", + "@com_github_grpc_grpc//:grpc++", + "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_grpc_proto", + "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_proto", + "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "rpc_session", + hdrs = ["rpc_session.h"], + deps = [ + ":proto_field_path", + ":rpc_executor", + ":rpc_potential_dfg", + ":rpc_sequence", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/random:bit_gen_ref", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_fuzztest//fuzztest:domain", + "@com_google_fuzztest//fuzztest:domain_core", + "@com_google_fuzztest//fuzztest:logging", + "@com_google_fuzztest//fuzztest:serialization", + "@com_google_fuzztest//fuzztest:table_of_recent_compares", + "@com_google_fuzztest//fuzztest:type_support", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "rpc_session_test", + srcs = ["rpc_session_test.cc"], + deps = [ + ":proto_field_path", + ":rpc_potential_dfg", + ":rpc_sequence", + ":rpc_session", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/random", + "@com_google_fuzztest//domain_tests:domain_testing", + 
"@com_google_fuzztest//fuzztest:serialization", + "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_grpc_proto", + "@com_google_fuzztest//rpc_fuzzing/testdata:mini_blogger_cc_proto", + "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:protobuf", + ], +)
diff --git a/rpc_fuzzing/grpc_stub.cc b/rpc_fuzzing/grpc_stub.cc
new file mode 100644
index 0000000..fa4a954
--- /dev/null
+++ b/rpc_fuzzing/grpc_stub.cc
@@ -0,0 +1,73 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "./rpc_fuzzing/grpc_stub.h"
+
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/synchronization/notification.h"
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "grpcpp//client_context.h"
+#include "grpcpp//generic/generic_stub.h"
+#include "grpcpp//impl/proto_utils.h"
+#include "grpcpp//support/status.h"
+#include "grpcpp//support/stub_options.h"
+
+namespace fuzztest {
+
+// Sends `request` as a unary call to `method_descriptor` and blocks until the
+// call completes. Any non-OK gRPC status is reported as an internal error.
+absl::StatusOr<std::unique_ptr<google::protobuf::Message>> GrpcStub::CallMethod(
+    const google::protobuf::MethodDescriptor& method_descriptor,
+    const google::protobuf::Message& request) {
+  grpc::ClientContext cli_ctx;
+  const google::protobuf::Message* response_prototype =
+      google::protobuf::MessageFactory::generated_factory()->GetPrototype(
+          method_descriptor.output_type());
+  if (!response_prototype) {
+    return absl::InternalError(
+        absl::StrCat("Cannot find prototype for ",
+                     method_descriptor.output_type()->full_name(),
+                     " in the generated proto MessageFactory"));
+  }
+  std::unique_ptr<google::protobuf::Message> response =
+      absl::WrapUnique(response_prototype->New());
+  absl::Notification notification;
+  absl::Status status = absl::OkStatus();
+  grpc_stub_->UnaryCall(
+      &cli_ctx,
+      absl::StrCat("/", method_descriptor.service()->full_name(), "/",
+                   method_descriptor.name()),
+      grpc::StubOptions(), &request, response.get(),
+      [&notification, &status](grpc::Status s) {
+        // TODO(changochen): Better conversion from grpc::Status to absl::Status.
+        if (!s.ok()) {
+          status = absl::InternalError(s.error_message());
+        }
+        notification.Notify();
+      });
+  // The result is delivered through the callback; wait until it has run.
+  notification.WaitForNotification();
+  if (!status.ok()) {
+    return status;
+  }
+  return response;
+}
+
+}  // namespace fuzztest
diff --git a/rpc_fuzzing/grpc_stub.h b/rpc_fuzzing/grpc_stub.h
new file mode 100644
index 0000000..1e046b0
--- /dev/null
+++ b/rpc_fuzzing/grpc_stub.h
@@ -0,0 +1,54 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FUZZTEST_RPC_FUZZING_GRPC_STUB_H_
+#define FUZZTEST_RPC_FUZZING_GRPC_STUB_H_
+
+#include <memory>
+#include <utility>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "absl/status/statusor.h"
+#include "./rpc_fuzzing/rpc_stub.h"
+#include "grpcpp//generic/generic_stub.h"
+
+namespace fuzztest {
+
+// A generic gRPC stub whose requests and responses are arbitrary protobuf
+// messages, so methods can be invoked by descriptor instead of generated code.
+using GrpcGenericStub =
+    grpc::TemplatedGenericStub<google::protobuf::Message, google::protobuf::Message>;
+
+// An `RpcStub` that forwards every call over a gRPC channel.
+class GrpcStub : public RpcStub {
+ public:
+  // Creates a stub that sends all calls over `channel`.
+  explicit GrpcStub(std::shared_ptr<grpc::ChannelInterface> channel)
+      : grpc_stub_(std::make_unique<GrpcGenericStub>(std::move(channel))) {}
+
+  // Invokes `method_descriptor` with `request` and blocks until the call
+  // completes. Returns the response message, or an error status if the output
+  // type has no generated prototype or the call failed.
+  absl::StatusOr<std::unique_ptr<google::protobuf::Message>> CallMethod(
+      const google::protobuf::MethodDescriptor& method_descriptor,
+      const google::protobuf::Message& request) override;
+
+ private:
+  std::unique_ptr<GrpcGenericStub> grpc_stub_ = nullptr;
+};
+
+}  // namespace fuzztest
+
+#endif  // FUZZTEST_RPC_FUZZING_GRPC_STUB_H_
diff --git a/rpc_fuzzing/proto_field_path.cc b/rpc_fuzzing/proto_field_path.cc
new file mode 100644
index 0000000..b1d7b37
--- /dev/null
+++ b/rpc_fuzzing/proto_field_path.cc
@@ -0,0 +1,329 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "./rpc_fuzzing/proto_field_path.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <string>
+#include <string_view>
+#include <vector>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "./fuzztest/internal/logging.h"
+
+namespace fuzztest::internal {
+
+bool Contains(const google::protobuf::FieldDescriptor& parent,
+              const google::protobuf::FieldDescriptor& child) {
+  return parent.message_type() == child.containing_type();
+}
+
+bool operator==(const FieldPath& lhs, const FieldPath& rhs) {
+  return lhs.field_descriptors_ == rhs.field_descriptors_;
+}
+
+bool FieldPath::CanAppend(const google::protobuf::FieldDescriptor& field) const {
+  return field_descriptors_.empty() ||
+         Contains(*field_descriptors_.back(), field);
+}
+
+void FieldPath::AppendField(const google::protobuf::FieldDescriptor& field) {
+  FUZZTEST_INTERNAL_CHECK(
+      CanAppend(field),
+      "The current last field in the path must be a message containing the "
+      "appended field.");
+  field_descriptors_.push_back(&field);
+}
+
+void FieldPath::AppendPath(const FieldPath& other) {
+  if (other.field_descriptors_.empty()) {
+    return;
+  }
+  field_descriptors_.reserve(field_descriptors_.size() +
+                             other.field_descriptors_.size());
+  FUZZTEST_INTERNAL_CHECK(
+      CanAppend(*other.field_descriptors_.front()),
+      "The current last field in the path must be a message containing the "
+      "first field in the other path.");
+  field_descriptors_.insert(field_descriptors_.end(),
+                            other.field_descriptors_.begin(),
+                            other.field_descriptors_.end());
+}
+
+std::string FieldPath::ToString() const {
+  return absl::StrJoin(
+      field_descriptors_, /*separator=*/".",
+      [](std::string* result, const google::protobuf::FieldDescriptor* field) {
+        absl::StrAppend(result, field->name());
+      });
+}
+
+FieldPath GetFieldPathWithDescriptor(const google::protobuf::Descriptor& descriptor,
+                                     std::string_view field_path_str) {
+  std::vector<std::string> parts = absl::StrSplit(field_path_str, '.');
+  FUZZTEST_INTERNAL_CHECK(!parts.empty(), "Invalid path string!");
+
+  FieldPath result;
+  const google::protobuf::Descriptor* descriptor_ptr = &descriptor;
+
+  for (const std::string& part : parts) {
+    // Only message-typed fields have subfields. Without this check a path such
+    // as "scalar_field.x" would dereference the null `message_type()` of the
+    // previous part.
+    FUZZTEST_INTERNAL_CHECK(descriptor_ptr != nullptr,
+                            "Only message fields can have subfields!");
+    const google::protobuf::FieldDescriptor* field =
+        descriptor_ptr->FindFieldByName(part);
+    FUZZTEST_INTERNAL_CHECK(field != nullptr, "Invalid field name!");
+    result.AppendField(*field);
+    descriptor_ptr = field->message_type();
+  }
+  return result;
+}
+
+const google::protobuf::Message* FieldPath::GetContainingMessageOfLastField(
+    const google::protobuf::Message& message) const {
+  FUZZTEST_INTERNAL_CHECK_PRECONDITION(!field_descriptors_.empty(),
+                                       "Empty field path!");
+  const google::protobuf::Message* parent = &message;
+  for (size_t i = 0; i < field_descriptors_.size() - 1; ++i) {
+    if (field_descriptors_[i]->is_repeated()) {
+      if (parent->GetReflection()->FieldSize(*parent, field_descriptors_[i]) ==
+          0) {
+        return nullptr;
+      }
+      parent = &(parent->GetReflection()->GetRepeatedMessage(
+          *parent, field_descriptors_[i], 0));
+    } else {
+      parent = &(
+          parent->GetReflection()->GetMessage(*parent, field_descriptors_[i]));
+    }
+  }
+  return parent;
+}
+
+google::protobuf::Message* FieldPath::MutableContainingMessageOfLastField(
+    google::protobuf::Message& message) const {
+  FUZZTEST_INTERNAL_CHECK_PRECONDITION(!field_descriptors_.empty(),
+                                       "Empty field path!");
+  google::protobuf::Message* parent = &message;
+  for (size_t i = 0; i < field_descriptors_.size() - 1; ++i) {
+    if (field_descriptors_[i]->is_repeated()) {
+      if (parent->GetReflection()->FieldSize(*parent, field_descriptors_[i]) ==
+          0) {
+        parent->GetReflection()->AddMessage(parent, field_descriptors_[i]);
+      }
+      parent = parent->GetReflection()->MutableRepeatedMessage(
+          parent, field_descriptors_[i], 0);
+    } else {
+      parent = parent->GetReflection()->MutableMessage(parent,
+                                                       field_descriptors_[i]);
+    }
+  }
+  return parent;
+}
+
+const google::protobuf::FieldDescriptor& FieldPath::GetLastField() const {
+  FUZZTEST_INTERNAL_CHECK_PRECONDITION(!field_descriptors_.empty(),
+                                       "Empty field path!");
+  return *field_descriptors_.back();
+}
+
+const std::vector<const google::protobuf::FieldDescriptor*>& FieldPath::GetAllFields()
+    const {
+  return field_descriptors_;
+}
+
+void CopyField(const FieldPath& from_field, const google::protobuf::Message& from,
+               const FieldPath& to_field, google::protobuf::Message& to) {
+  const google::protobuf::Message* from_inner_most_message =
+      from_field.GetContainingMessageOfLastField(from);
+  if (from_inner_most_message == nullptr) {
+    return;
+  }
+  const google::protobuf::FieldDescriptor& from_last_field = from_field.GetLastField();
+  google::protobuf::Message* to_inner_most_message =
+      to_field.MutableContainingMessageOfLastField(to);
+  if (to_inner_most_message == nullptr) {
+    return;
+  }
+  const google::protobuf::FieldDescriptor& to_last_field = to_field.GetLastField();
+
+  const google::protobuf::Reflection* from_refl =
+      from_inner_most_message->GetReflection();
+  const google::protobuf::Reflection* to_refl = to_inner_most_message->GetReflection();
+  FUZZTEST_INTERNAL_CHECK(from_last_field.type() == to_last_field.type(),
+                          "Fields of mismatch types cannot be copied!");
+  // TODO(changochen): We might make this condition optional.
+  FUZZTEST_INTERNAL_CHECK(from_last_field.name() == to_last_field.name(),
+                          "Fields of mismatch names cannot be copied!");
+  switch (from_last_field.type()) {
+#define HANDLE_TYPE(UPPERCASE, CAMEL)                                         \
+  case google::protobuf::FieldDescriptor::TYPE_##UPPERCASE:                             \
+    if (from_last_field.is_repeated() && to_last_field.is_repeated()) {       \
+      to_refl->ClearField(to_inner_most_message, &to_last_field);             \
+      for (int i = 0; i < from_refl->FieldSize(*from_inner_most_message,     \
+                                               &from_last_field);             \
+           ++i) {                                                             \
+        to_refl->Add##CAMEL(                                                  \
+            to_inner_most_message, &to_last_field,                            \
+            from_refl->GetRepeated##CAMEL(*from_inner_most_message,           \
+                                          &from_last_field, i));              \
+      }                                                                       \
+    } else if (from_last_field.is_repeated()) {                               \
+      if (from_refl->FieldSize(*from_inner_most_message, &from_last_field) == \
+          0) {                                                                \
+        to_refl->ClearField(to_inner_most_message, &to_last_field);           \
+      } else {                                                                \
+        to_refl->Set##CAMEL(                                                  \
+            to_inner_most_message, &to_last_field,                            \
+            from_refl->GetRepeated##CAMEL(*from_inner_most_message,           \
+                                          &from_last_field, 0));              \
+      }                                                                       \
+    } else if (to_last_field.is_repeated()) {                                 \
+      to_refl->ClearField(to_inner_most_message, &to_last_field);             \
+      to_refl->Add##CAMEL(                                                    \
+          to_inner_most_message, &to_last_field,                              \
+          from_refl->Get##CAMEL(*from_inner_most_message, &from_last_field)); \
+    } else {                                                                  \
+      to_refl->Set##CAMEL(                                                    \
+          to_inner_most_message, &to_last_field,                              \
+          from_refl->Get##CAMEL(*from_inner_most_message, &from_last_field)); \
+    }                                                                         \
+    break;
+
+    HANDLE_TYPE(DOUBLE, Double);
+    HANDLE_TYPE(FLOAT, Float);
+    HANDLE_TYPE(INT64, Int64);
+    HANDLE_TYPE(UINT64, UInt64);
+    HANDLE_TYPE(INT32, Int32);
+    HANDLE_TYPE(FIXED64, UInt64);
+    HANDLE_TYPE(FIXED32, UInt32);
+    HANDLE_TYPE(BOOL, Bool);
+    HANDLE_TYPE(STRING, String);
+    HANDLE_TYPE(BYTES, String);
+    HANDLE_TYPE(UINT32, UInt32);
+    HANDLE_TYPE(ENUM, Enum);
+    HANDLE_TYPE(SFIXED64, Int64);
+    HANDLE_TYPE(SFIXED32, Int32);
+    HANDLE_TYPE(SINT64, Int64);
+    HANDLE_TYPE(SINT32, Int32);
+
+#undef HANDLE_TYPE
+    case google::protobuf::FieldDescriptor::TYPE_GROUP:
+    case google::protobuf::FieldDescriptor::TYPE_MESSAGE:
+      if (from_last_field.is_repeated() && to_last_field.is_repeated()) {
+        to_refl->ClearField(to_inner_most_message, &to_last_field);
+        for (int i = 0; i < from_refl->FieldSize(*from_inner_most_message,
+                                                 &from_last_field);
+             ++i) {
+          to_refl->AddMessage(to_inner_most_message, &to_last_field)
+              ->CopyFrom(from_refl->GetRepeatedMessage(*from_inner_most_message,
+                                                       &from_last_field, i));
+        }
+      } else if (from_last_field.is_repeated()) {
+        if (from_refl->FieldSize(*from_inner_most_message, &from_last_field) ==
+            0) {
+          to_refl->ClearField(to_inner_most_message, &to_last_field);
+        } else {
+          to_refl->MutableMessage(to_inner_most_message, &to_last_field)
+              ->CopyFrom(from_refl->GetRepeatedMessage(*from_inner_most_message,
+                                                       &from_last_field, 0));
+        }
+      } else if (to_last_field.is_repeated()) {
+        to_refl->ClearField(to_inner_most_message, &to_last_field);
+        to_refl->AddMessage(to_inner_most_message, &to_last_field)
+            ->CopyFrom(from_refl->GetMessage(*from_inner_most_message,
+                                             &from_last_field));
+      } else {
+        to_refl->MutableMessage(to_inner_most_message, &to_last_field)
+            ->CopyFrom(from_refl->GetMessage(*from_inner_most_message,
+                                             &from_last_field));
+      }
+      break;
+    default:
+      FUZZTEST_INTERNAL_CHECK(
+          false, absl::StrCat("Unexpected type ", from_last_field.type_name()));
+  }
+}
+
+std::vector<FieldPath> CollectAllFieldsImpl(
+    const google::protobuf::Descriptor& message_descriptor,
+    absl::flat_hash_set<const google::protobuf::Descriptor*>& visited_messages) {
+  // `visited_messages` is shared by the whole traversal and never shrinks: it
+  // stops infinite recursion on recursive messages, but it also means a message
+  // type reachable through several fields is only expanded under the first one.
+  // NOTE(review): confirm that sibling fields of the same type are meant to be
+  // collapsed this way.
+  std::vector<FieldPath> results;
+  if (!visited_messages.insert(&message_descriptor).second) {
+    return results;
+  }
+  for (int i = 0; i < message_descriptor.field_count(); ++i) {
+    const google::protobuf::FieldDescriptor& field = *message_descriptor.field(i);
+    FieldPath field_path;
+    field_path.AppendField(field);
+    results.push_back(field_path);
+    // `GROUP` field is a deprecated way of expressing inner message.
+    if (field.type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE ||
+        field.type() == google::protobuf::FieldDescriptor::TYPE_GROUP) {
+      const google::protobuf::Descriptor& inner = *field.message_type();
+      std::vector<FieldPath> inner_fields =
+          CollectAllFieldsImpl(inner, visited_messages);
+      for (const auto& loc : inner_fields) {
+        FieldPath inner_field_path(field_path);
+        inner_field_path.AppendPath(loc);
+        results.push_back(inner_field_path);
+      }
+    }
+  }
+  return results;
+}
+
+std::vector<FieldPath> CollectAllFields(
+    const google::protobuf::Descriptor& message_descriptor) {
+  absl::flat_hash_set<const google::protobuf::Descriptor*> visited_messages;
+  return CollectAllFieldsImpl(message_descriptor, visited_messages);
+}
+
+bool AreDifferentFieldsInSameOneOf(const google::protobuf::FieldDescriptor& a,
+                                   const google::protobuf::FieldDescriptor& b) {
+  if (&a == &b || a.containing_oneof() == nullptr ||
+      b.containing_oneof() == nullptr)
+    return false;
+  return a.containing_oneof() == b.containing_oneof();
+}
+
+bool AreOneOfAltearnatives(const FieldPath& a, const FieldPath& b) {
+  const std::vector<const google::protobuf::FieldDescriptor*>& a_fields =
+      a.GetAllFields();
+  const std::vector<const google::protobuf::FieldDescriptor*>& b_fields =
+      b.GetAllFields();
+  for (size_t i = 0; i < std::min(a_fields.size(), b_fields.size()); ++i) {
+    const google::protobuf::FieldDescriptor& a_field = *a_fields[i];
+    const google::protobuf::FieldDescriptor& b_field = *b_fields[i];
+    if (AreDifferentFieldsInSameOneOf(a_field, b_field)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+}  // namespace fuzztest::internal
diff --git a/rpc_fuzzing/proto_field_path.h b/rpc_fuzzing/proto_field_path.h new file mode 100644 index 0000000..1042fe6 --- /dev/null +++ b/rpc_fuzzing/proto_field_path.h
@@ -0,0 +1,139 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FUZZTEST_RPC_FUZZING_PROTO_FIELD_PATH_H_
+#define FUZZTEST_RPC_FUZZING_PROTO_FIELD_PATH_H_
+
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+
+namespace fuzztest::internal {
+
+// Represents a path to a (sub)field in a message, stored as the sequence of
+// field descriptors leading from the outermost message to the target field.
+class FieldPath {
+ public:
+  // Appends a field as the next part of the path.
+  void AppendField(const google::protobuf::FieldDescriptor& field);
+  // Appends all the parts of another path to the path.
+  void AppendPath(const FieldPath& other);
+
+  // Given a message, follows the field path and returns the parent message of
+  // the last field. If any repeated field in the middle is empty, returns
+  // nullptr.
+  const google::protobuf::Message* GetContainingMessageOfLastField(
+      const google::protobuf::Message& message) const;
+  // The mutable version fills in empty fields along the path, and thus never
+  // returns a null pointer.
+  google::protobuf::Message* MutableContainingMessageOfLastField(
+      google::protobuf::Message& message) const;
+
+  // Returns the last field in the path.
+  const google::protobuf::FieldDescriptor& GetLastField() const;
+  // Returns all the parts of the path, in order.
+  const std::vector<const google::protobuf::FieldDescriptor*>& GetAllFields() const;
+
+  // Returns a string concatenating all the field parts with ".". For example,
+  // a field path with two fields "field1" and "field2" will return
+  // "field1.field2".
+  std::string ToString() const;
+
+ private:
+  bool CanAppend(const google::protobuf::FieldDescriptor& field) const;
+
+  // Makes FieldPath hashable with absl::Hash.
+  template <typename H>
+  friend H AbslHashValue(H h, const FieldPath& m) {
+    h = H::combine_contiguous(std::move(h), m.field_descriptors_.data(),
+                              m.field_descriptors_.size());
+    return h;
+  }
+  friend bool operator==(const FieldPath& lhs, const FieldPath& rhs);
+
+  // A list of fields that allow us to locate inner fields.
+  std::vector<const google::protobuf::FieldDescriptor*> field_descriptors_;
+};
+
+bool operator==(const FieldPath& lhs, const FieldPath& rhs);
+
+// Copies the field at `from_field` in `from` to the field at `to_field` in
+// `to`. The two fields should have the same names and types, otherwise an
+// assertion will be triggered. If the source field is repeated but the sink is
+// not, we copy the first element of the source. If the source field is not
+// repeated but the sink is, we clear the sink and copy the source as the first
+// element of the sink.
+// TODO(changochen): Make the condition of same name optional.
+// TODO(changochen): Select a random element of a repeated field for copying.
+void CopyField(const FieldPath& from_field, const google::protobuf::Message& from,
+               const FieldPath& to_field, google::protobuf::Message& to);
+
+// Same as `GetFieldPath` below, but takes the descriptor of the message type
+// explicitly instead of as a template argument.
+FieldPath GetFieldPathWithDescriptor(const google::protobuf::Descriptor& descriptor,
+                                     std::string_view field_path_str);
+
+// Gets a field path in a message from a path string. The path string should be
+// in format like "field1.subfield2.subfield3.xxx".
+// TODO(changochen): Return absl::StatusOr<FieldPath> instead of aborting. Make
+// this a constructor of field path.
+template <typename MessageT>
+FieldPath GetFieldPath(std::string_view field_path_str) {
+  const google::protobuf::Descriptor* descriptor = MessageT::GetDescriptor();
+  return GetFieldPathWithDescriptor(*descriptor, field_path_str);
+}
+
+// Collects all (sub)fields of a message type, in depth-first order. Each
+// message type is expanded only the first time it is encountered, which is
+// what makes recursive message types terminate; as a consequence, a message
+// type reached through several fields only has its subfields listed under the
+// first of them.
+std::vector<FieldPath> CollectAllFields(
+    const google::protobuf::Descriptor& message_descriptor);
+
+// Checks whether two field paths are alternatives to each other (i.e., within
+// the same oneof group). If yes, at most one of them should be set at any
+// given time.
+// For example, given the following messages:
+//
+// message Message {
+//   oneof test_one_of {
+//     Message1 a = 1;
+//     int32 b = 2;
+//   }
+// }
+//
+// message Message1 {
+//   int32 a1 = 1;
+//   int32 a2 = 2;
+// }
+//
+// We consider that a.a1 and b are within the same oneof, therefore
+// "alternatives", while a.a1 and a.a2 are not.
+//
+// NOTE: The function name is misspelled ("Altearnatives"); it is kept as-is
+// because existing callers reference it.
+// NOTE(review): The implementation compares the fields at every common index,
+// even after the two paths have diverged. E.g. for `x.p` and `y.q`, where `x`
+// and `y` are distinct fields sharing a message type whose oneof holds `p` and
+// `q`, it reports alternatives although both can be set; confirm the intent.
+bool AreOneOfAltearnatives(const FieldPath& a, const FieldPath& b);
+
+}  // namespace fuzztest::internal
+
+#endif  // FUZZTEST_RPC_FUZZING_PROTO_FIELD_PATH_H_
diff --git a/rpc_fuzzing/proto_field_path_test.cc b/rpc_fuzzing/proto_field_path_test.cc new file mode 100644 index 0000000..c00801a --- /dev/null +++ b/rpc_fuzzing/proto_field_path_test.cc
@@ -0,0 +1,516 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "./rpc_fuzzing/proto_field_path.h"
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "google/protobuf/descriptor.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "./fuzztest/internal/test_protobuf.pb.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.pb.h"
+
+namespace fuzztest::internal {
+
+using ::testing::ElementsAre;
+using ::testing::IsNull;
+using ::testing::NotNull;
+using ::testing::UnorderedElementsAre;
+
+namespace {
+
+TEST(FieldPathTest, GetFieldPathParseStringIntoAFieldPath) {
+  const FieldPath field_path =
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+  EXPECT_THAT(
+      field_path.GetAllFields(),
+      ElementsAre(
+          LogOutUserRequest::GetDescriptor()->FindFieldByName("log_out_info"),
+          LogOutInfo::GetDescriptor()->FindFieldByName("session_info"),
+          SessionInfo::GetDescriptor()->FindFieldByName("session_id")));
+}
+
+TEST(FieldPathTest, GetFieldPathAbortsAtInvalidField) {
+  EXPECT_DEATH_IF_SUPPORTED(
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.invalid"),
+      "Invalid field name!");
+}
+
+// Matches a FieldPath whose part names, joined with ".", equal `path_str`.
+MATCHER_P(FieldPathAsString, path_str, "") {
+  if (arg.GetAllFields().empty()) return path_str == "";
+  std::string result;
+  for (const google::protobuf::FieldDescriptor* part : arg.GetAllFields()) {
+    absl::StrAppend(&result, part->name(), ".");
+  }
+  return result.substr(0, result.size() - 1) == path_str;
+}
+
+TEST(FieldPathTest, HashAndEqualityCheck) {
+  const FieldPath path1 = GetFieldPath<LogInUserResponse>("session_id");
+  const FieldPath path2 = GetFieldPath<LogInUserResponse>("session_id");
+  EXPECT_EQ(path1, path2);
+}
+
+TEST(FieldPathTest, CopyFieldOfSameTypeAndSameNameInMessagesSucceeds) {
+  // Test copying LogInUserResponse.session_id to
+  // LogOutUserRequest.log_out_info.session_info.session_id.
+  LogInUserResponse log_in_response;
+  const FieldPath session_id_source_path =
+      GetFieldPath<LogInUserResponse>("session_id");
+  LogOutUserRequest log_out_request;
+  const FieldPath session_id_sink_path =
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+
+  constexpr int64_t kSessionId = 0xdeadbeef;
+  log_in_response.set_session_id(kSessionId);
+  CopyField(session_id_source_path, log_in_response, session_id_sink_path,
+            log_out_request);
+
+  EXPECT_EQ(kSessionId,
+            log_out_request.log_out_info().session_info().session_id());
+}
+
+TEST(FieldPathTest, CopyingRepeatedFieldToNonRepeatedCopiesTheFirstElement) {
+  // Copy repeated `{val1, val2}` to non repeated `val2` results in `val1`.
+  constexpr int64_t kValue1 = 123;
+  constexpr int64_t kValue2 = 234;
+  MessageContainingRepeatedFields message_containing_repeated_fields;
+  message_containing_repeated_fields.add_field(kValue1);
+  message_containing_repeated_fields.add_field(kValue2);
+  const FieldPath source_path =
+      GetFieldPath<MessageContainingRepeatedFields>("field");
+  MessageNotContainingRepeatedFields message_not_containing_repeated_fields;
+  message_not_containing_repeated_fields.set_field(kValue2);
+  const FieldPath sink_path =
+      GetFieldPath<MessageNotContainingRepeatedFields>("field");
+
+  CopyField(source_path, message_containing_repeated_fields, sink_path,
+            message_not_containing_repeated_fields);
+
+  MessageNotContainingRepeatedFields expected_result;
+  expected_result.set_field(kValue1);
+  EXPECT_EQ(message_not_containing_repeated_fields.DebugString(),
+            expected_result.DebugString());
+}
+
+TEST(FieldPathTest,
+     CopyFieldDoesNothingIfSourceContainsEmptyRepeatedFieldInTheMiddle) {
+  MessageContainingRepeatedFields message_containing_repeated_fields;
+  const FieldPath source_path = GetFieldPath<MessageContainingRepeatedFields>(
+      "repeated_msg_field.field2");
+  MessageNotContainingRepeatedFields message_not_containing_repeated_fields;
+  message_not_containing_repeated_fields.set_field2(123);
+  const MessageNotContainingRepeatedFields expected_result =
+      message_not_containing_repeated_fields;
+
+  const FieldPath sink_path =
+      GetFieldPath<MessageNotContainingRepeatedFields>("field2");
+
+  CopyField(source_path, message_containing_repeated_fields, sink_path,
+            message_not_containing_repeated_fields);
+  EXPECT_EQ(message_not_containing_repeated_fields.DebugString(),
+            expected_result.DebugString());
+}
+
+TEST(FieldPathTest,
+     CopyFieldSucceedsIfSinkContainsEmptyRepeatedFieldInTheMiddle) {
+  constexpr int64_t kValue1 = 123;
+  MessageContainingRepeatedFields message_containing_repeated_fields;
+  const FieldPath sink_path = GetFieldPath<MessageContainingRepeatedFields>(
+      "repeated_msg_field.field2");
+  MessageNotContainingRepeatedFields message_not_containing_repeated_fields;
+  message_not_containing_repeated_fields.set_field2(kValue1);
+
+  const FieldPath source_path =
+      GetFieldPath<MessageNotContainingRepeatedFields>("field2");
+
+  CopyField(source_path, message_not_containing_repeated_fields, sink_path,
+            message_containing_repeated_fields);
+  MessageContainingRepeatedFields expected_result;
+  expected_result.add_repeated_msg_field()->set_field2(kValue1);
+  EXPECT_EQ(message_containing_repeated_fields.DebugString(),
+            expected_result.DebugString());
+}
+
+TEST(FieldPathTest, CopyingEmptyRepeatedFieldToNonRepeatedClearsTheField) {
+  // Copy empty repeated `{}` to non repeated `val2` results in cleared field.
+  MessageContainingRepeatedFields message_containing_repeated_fields;
+  const FieldPath source_path =
+      GetFieldPath<MessageContainingRepeatedFields>("field");
+  MessageNotContainingRepeatedFields message_not_containing_repeated_fields;
+  message_not_containing_repeated_fields.set_field(123);
+  const FieldPath sink_path =
+      GetFieldPath<MessageNotContainingRepeatedFields>("field");
+
+  CopyField(source_path, message_containing_repeated_fields, sink_path,
+            message_not_containing_repeated_fields);
+
+  EXPECT_EQ(message_not_containing_repeated_fields.DebugString(),
+            MessageNotContainingRepeatedFields{}.DebugString());
+}
+
+TEST(FieldPathTest, CopyingNonRepeatedFieldToRepeatedFieldSetsFirstElement) {
+  // Copy non repeated `val2` to repeated `{val1, val2}` results in `{val2}`.
+  constexpr int64_t kValue1 = 123;
+  constexpr int64_t kValue2 = 234;
+  MessageNotContainingRepeatedFields message_not_containing_repeated_fields;
+  message_not_containing_repeated_fields.set_field(kValue2);
+  const FieldPath source_path =
+      GetFieldPath<MessageNotContainingRepeatedFields>("field");
+
+  MessageContainingRepeatedFields message_containing_repeated_fields;
+  message_containing_repeated_fields.add_field(kValue1);
+  message_containing_repeated_fields.add_field(kValue2);
+  const FieldPath sink_path =
+      GetFieldPath<MessageContainingRepeatedFields>("field");
+
+  CopyField(source_path, message_not_containing_repeated_fields, sink_path,
+            message_containing_repeated_fields);
+
+  MessageContainingRepeatedFields expected_result;
+  expected_result.add_field(kValue2);
+  EXPECT_EQ(message_containing_repeated_fields.DebugString(),
+            expected_result.DebugString());
+}
+
+TEST(FieldPathTest, CopyingRepeatedFieldHandlesFieldOfMessageTypeCorrectly) {
+  MessageContainingRepeatedFields source_message, sink_message;
+  source_message.add_repeated_msg_field()->set_field2(0x123);
+  source_message.add_repeated_msg_field()->set_field2(0x234);
+  const FieldPath field_path =
+      GetFieldPath<MessageContainingRepeatedFields>("repeated_msg_field");
+
+  CopyField(field_path, source_message, field_path, sink_message);
+
+  EXPECT_EQ(sink_message.DebugString(), source_message.DebugString());
+}
+
+TEST(FieldPathTest,
+     CopyingSingularSrcToRepeatedDstHandlesFieldOfMessageTypeCorrectly) {
+  MessageWithSingleInnerRepeated source_message;
+  source_message.mutable_repeated_msg_field()->set_field2(123);
+  MessageContainingRepeatedFields sink_message;
+  sink_message.add_repeated_msg_field()->set_field2(1);
+  sink_message.add_repeated_msg_field()->set_field2(1);
+  const FieldPath sink_field_path =
+      GetFieldPath<MessageContainingRepeatedFields>("repeated_msg_field");
+  const FieldPath source_field_path =
+      GetFieldPath<MessageWithSingleInnerRepeated>("repeated_msg_field");
+
+  CopyField(source_field_path, source_message, sink_field_path, sink_message);
+
+  MessageContainingRepeatedFields expected_result;
+  expected_result.add_repeated_msg_field()->set_field2(123);
+
+  EXPECT_EQ(sink_message.DebugString(), expected_result.DebugString());
+}
+
+TEST(FieldPathTest, CopyingRepeatedFieldToRepeatedFieldCopiesTheWholeVector) {
+  // Copy repeated `{val1, val2}` to repeated `{}` results in `{val1, val2}`.
+  constexpr int64_t kValue1 = 123;
+  constexpr int64_t kValue2 = 234;
+
+  MessageContainingRepeatedFields source, sink;
+  source.add_field(kValue1);
+  source.add_field(kValue2);
+
+  const FieldPath source_path =
+      GetFieldPath<MessageContainingRepeatedFields>("field");
+  const FieldPath sink_path =
+      GetFieldPath<MessageContainingRepeatedFields>("field");
+
+  CopyField(source_path, source, sink_path, sink);
+
+  MessageContainingRepeatedFields expected_result;
+  expected_result.add_field(kValue1);
+  expected_result.add_field(kValue2);
+  EXPECT_EQ(sink.DebugString(), expected_result.DebugString());
+}
+
+TEST(FieldPathTest, CopyFieldOfSameTypeAndDifferentNameInMessagesFails) {
+  // Test copying GetUserPostsOptions.order to
+  // GetUserPostsOptionsWithDifferentFieldNames.ordering.
+  GetUserPostsOptions options;
+  const FieldPath order_source_path =
+      GetFieldPath<GetUserPostsOptions>("order");
+  GetUserPostsOptionsWithDifferentFieldNames options_with_different_field_names;
+  const FieldPath ordering_sink_path =
+      GetFieldPath<GetUserPostsOptionsWithDifferentFieldNames>("ordering");
+
+  EXPECT_DEATH_IF_SUPPORTED(
+      CopyField(order_source_path, options, ordering_sink_path,
+                options_with_different_field_names),
+      "Fields of mismatch names cannot be copied!");
+}
+
+TEST(FieldPathTest, CopyFieldOfDifferentTypeInMessagesFails) {
+  // Test copying LogInUserResponse.session_id to
+  // GetUserPostsRequest.options.
+  const LogInUserResponse log_in_response;
+  const FieldPath session_id_source_path =
+      GetFieldPath<LogInUserResponse>("session_id");
+  GetUserPostsRequest get_user_posts_request;
+  const FieldPath session_id_sink_path =
+      GetFieldPath<GetUserPostsRequest>("options");
+
+  EXPECT_DEATH_IF_SUPPORTED(
+      CopyField(session_id_source_path, log_in_response, session_id_sink_path,
+                get_user_posts_request),
+      "Fields of mismatch types cannot be copied!");
+}
+
+TEST(FieldPathTest, CopyFieldHandlesEnumCorrectly) {
+  GetUserPostsOptions source_options, sink_options;
+  const FieldPath field_path = GetFieldPath<GetUserPostsOptions>("order");
+  source_options.set_order(SortingOrder::ASCENDING);
+
+  CopyField(field_path, source_options, field_path, sink_options);
+
+  EXPECT_EQ(sink_options.order(), SortingOrder::ASCENDING);
+}
+
+TEST(FieldPathTest, GetContainingMessageOfLastFieldReturnsContainingMessage) {
+  const LogOutUserRequest log_out_request;
+  const FieldPath session_id_sink_path =
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+
+  const LogOutInfo& log_out_info = log_out_request.log_out_info();
+  const SessionInfo& session_info = log_out_info.session_info();
+
+  const google::protobuf::Message* parent_msg =
+      session_id_sink_path.GetContainingMessageOfLastField(log_out_request);
+
+  EXPECT_EQ(parent_msg, &session_info);
+}
+
+TEST(FieldPathTest,
+     GetContainingMessageReturnsNonNullForUnsetMiddleSingularField) {
+  LogOutUserRequest log_out_request;
+  const FieldPath session_id_sink_path =
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+
+  EXPECT_THAT(
+      session_id_sink_path.GetContainingMessageOfLastField(log_out_request),
+      NotNull());
+}
+
+TEST(FieldPathTest,
+     GetContainingMessageReturnsNullForMissingMiddleRepeatedField) {
+  MessageContainingRepeatedFields message_containing_repeated_fields;
+  const FieldPath message_containing_repeated_fields_sink_path =
+      GetFieldPath<MessageContainingRepeatedFields>(
+          "repeated_msg_field.field2");
+  EXPECT_THAT(
+      message_containing_repeated_fields_sink_path
+          .GetContainingMessageOfLastField(message_containing_repeated_fields),
+      IsNull());
+}
+
+TEST(FieldPathTest,
+     MutableContainingMessageAddMissingRepeatedFieldAndReturnsNonNull) {
+  MessageContainingRepeatedFields message_containing_repeated_fields;
+  const FieldPath message_containing_repeated_fields_sink_path =
+      GetFieldPath<MessageContainingRepeatedFields>(
+          "repeated_msg_field.field2");
+  EXPECT_THAT(message_containing_repeated_fields_sink_path
+                  .MutableContainingMessageOfLastField(
+                      message_containing_repeated_fields),
+              NotNull());
+}
+
+TEST(FieldPathTest,
+     MutableContainingMessageAtNonEmptyRepeatedFieldWillNotAddField) {
+  MessageContainingRepeatedFields message_containing_repeated_fields;
+  message_containing_repeated_fields.add_repeated_msg_field()->set_field2(
+      0x123);
+  const FieldPath message_containing_repeated_fields_sink_path =
+      GetFieldPath<MessageContainingRepeatedFields>(
+          "repeated_msg_field.field2");
+  (void)message_containing_repeated_fields_sink_path
+      .MutableContainingMessageOfLastField(message_containing_repeated_fields);
+  EXPECT_EQ(
+      message_containing_repeated_fields.GetReflection()->FieldSize(
+          message_containing_repeated_fields,
+          message_containing_repeated_fields.GetDescriptor()->FindFieldByName(
+              "repeated_msg_field")),
+      1);
+}
+
+TEST(FieldPathTest, MutableContainingMessageAtEmptyRepeatedFieldWillAddField) {
+  MessageContainingRepeatedFields message_containing_repeated_fields;
+  const FieldPath message_containing_repeated_fields_sink_path =
+      GetFieldPath<MessageContainingRepeatedFields>(
+          "repeated_msg_field.field2");
+  (void)message_containing_repeated_fields_sink_path
+      .MutableContainingMessageOfLastField(message_containing_repeated_fields);
+  EXPECT_EQ(
+      message_containing_repeated_fields.GetReflection()->FieldSize(
+          message_containing_repeated_fields,
+          message_containing_repeated_fields.GetDescriptor()->FindFieldByName(
+              "repeated_msg_field")),
+      1);
+}
+
+TEST(FieldPathTest,
+     GetContainingMessageOfLastFieldReturnsFirstIndexOfRepeatedFields) {
+  MessageContainingRepeatedFields message;
+  message.add_repeated_msg_field();
+  message.add_repeated_msg_field();
+
+  FieldPath field_path = GetFieldPath<MessageContainingRepeatedFields>(
+      "repeated_msg_field.field2");
+
+  const google::protobuf::Message* parent_msg =
+      field_path.GetContainingMessageOfLastField(message);
+  EXPECT_EQ(parent_msg, &message.repeated_msg_field(0));
+}
+
+TEST(FieldPathTest, MutableContainingMessageOfLastFieldReturnsMutableMessage) {
+  LogOutUserRequest log_out_request;
+  const FieldPath session_id_sink_path =
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+
+  const LogOutInfo& log_out_info = log_out_request.log_out_info();
+  const SessionInfo& session_info = log_out_info.session_info();
+
+  google::protobuf::Message* parent_msg =
+      session_id_sink_path.MutableContainingMessageOfLastField(log_out_request);
+  // MutableContainingMessageOfLastField might create new message so the address
+  // will be different.
+  EXPECT_NE(parent_msg, &session_info);
+}
+
+TEST(FieldPathTest, GetLastFieldReturnsLastFieldOfThePath) {
+  const LogOutUserRequest log_out_request;
+  const FieldPath session_id_sink_path =
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+
+  EXPECT_EQ(&session_id_sink_path.GetLastField(),
+            SessionInfo::GetDescriptor()->FindFieldByName("session_id"));
+}
+
+TEST(FieldPathTest,
+     AppendFieldSuccessIfTheLastFieldIsTheAppendedFieldParentType) {
+  const FieldPath ground_truth =
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info");
+  FieldPath field_path = GetFieldPath<LogOutUserRequest>("log_out_info");
+  field_path.AppendField(
+      *LogOutInfo::GetDescriptor()->FindFieldByName("session_info"));
+  EXPECT_EQ(ground_truth, field_path);
+}
+
+TEST(FieldPathTest,
+     AppendFieldFailsIfTheLastFieldIsNotTheAppendedFieldParentType) {
+  FieldPath field_path = GetFieldPath<LogInUserResponse>("success");
+
+  EXPECT_DEATH_IF_SUPPORTED(
+      field_path.AppendField(field_path.GetLastField()),
+      "The current last field in the path must be a message "
+      "containing the appended field.");
+}
+
+TEST(FieldPathTest,
+     AppendPathSuccessIfLastFieldOfAppendedPathIsContainedByTheCurrentPath) {
+  const FieldPath ground_truth =
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+  FieldPath field_path1 = GetFieldPath<LogOutUserRequest>("log_out_info");
+  const FieldPath field_path2 =
+      GetFieldPath<LogOutInfo>("session_info.session_id");
+  field_path1.AppendPath(field_path2);
+  EXPECT_EQ(field_path1, ground_truth);
+}
+
+TEST(FieldPathTest,
+     AppendPathFailsIfLastFieldOfAppendedPathIsNotContainedByTheCurrentPath) {
+  FieldPath field_path1 = GetFieldPath<LogOutUserRequest>("log_out_info");
+  const FieldPath field_path2 = GetFieldPath<SessionInfo>("session_id");
+
+  EXPECT_DEATH_IF_SUPPORTED(
+      field_path1.AppendPath(field_path2),
+      "The current last field in the path must be a message "
+      "containing the first field in the other path.");
+}
+
+TEST(FieldPathTest, CollectAllFieldsReturnAllInnerFields) {
+  std::vector<FieldPath> all_fields =
+      CollectAllFields(*LogOutUserRequest::descriptor());
+
+  EXPECT_THAT(all_fields,
+              UnorderedElementsAre(
+                  FieldPathAsString("log_out_info"),
+                  FieldPathAsString("log_out_info.session_info"),
+                  FieldPathAsString("log_out_info.session_id"),
+                  FieldPathAsString("log_out_info.session_info.session_id")));
+}
+
+TEST(FieldPathTest, CollectAllFieldsSupportsGroupField) {
+  std::vector<FieldPath> all_fields =
+      CollectAllFields(*MessageWithGroup::descriptor());
+
+  EXPECT_THAT(all_fields,
+              UnorderedElementsAre(FieldPathAsString("groupfield"),
+                                   FieldPathAsString("groupfield.field1"),
+                                   FieldPathAsString("groupfield.field2")));
+}
+
+TEST(FieldPathTest, CopyFieldsSupportsGroupField) {
+  MessageWithGroup message_with_group1, message_with_group2;
+
+  message_with_group1.mutable_groupfield()->set_field1(123);
+  CopyField(GetFieldPath<MessageWithGroup>("groupfield"), message_with_group1,
+            GetFieldPath<MessageWithGroup>("groupfield"), message_with_group2);
+  EXPECT_EQ(message_with_group2.groupfield().field1(), 123);
+}
+
+TEST(FieldPathTest, CollectAllFieldsInRecursiveMessageSkipVisitedMessages) {
+  std::vector<FieldPath> all_fields =
+      CollectAllFields(*RecursiveNode::descriptor());
+  EXPECT_THAT(all_fields, ElementsAre(FieldPathAsString("value"),
+                                      FieldPathAsString("children")));
+}
+
+TEST(FieldPathTest, ToStringConcatAllFieldPartWithDot) {
+  const FieldPath field_path =
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id");
+  EXPECT_EQ(field_path.ToString(), "log_out_info.session_info.session_id");
+}
+
+TEST(FieldPathTest, AlternativeToReturnsTrueForFieldsWithinTheSameOneOf) {
+  EXPECT_TRUE(
+      AreOneOfAltearnatives(GetFieldPath<OneOfMessage>("oneof1_field1.v1"),
+                            GetFieldPath<OneOfMessage>("oneof1_field2")));
+  EXPECT_TRUE(
+      AreOneOfAltearnatives(GetFieldPath<OneOfMessage>("oneof1_field1.v2"),
+                            GetFieldPath<OneOfMessage>("oneof1_field2")));
+}
+
+TEST(FieldPathTest, AlternativeToReturnsFalseForFieldsNotWithinTheSameOneOf) {
+  EXPECT_FALSE(
+      AreOneOfAltearnatives(GetFieldPath<OneOfMessage>("oneof1_field1"),
+                            GetFieldPath<OneOfMessage>("oneof1_field1.v1")));
+  EXPECT_FALSE(
+      AreOneOfAltearnatives(GetFieldPath<OneOfMessage>("oneof1_field1.v1"),
+                            GetFieldPath<OneOfMessage>("oneof1_field1.v1")));
+  EXPECT_FALSE(
+      AreOneOfAltearnatives(GetFieldPath<OneOfMessage>("oneof1_field2"),
+                            GetFieldPath<OneOfMessage>("oneof2_field2")));
+}
+
+}  // namespace
+
+}  // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_executor.cc b/rpc_fuzzing/rpc_executor.cc new file mode 100644 index 0000000..c00dc2d --- /dev/null +++ b/rpc_fuzzing/rpc_executor.cc
@@ -0,0 +1,68 @@
+#include "./rpc_fuzzing/rpc_executor.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "./fuzztest/internal/logging.h"
+
+namespace fuzztest::internal {
+
+// Runs the nodes of `sequence` one after another, starting from an empty set
+// of responses. The first non-OK status aborts the run and is returned as-is;
+// an empty sequence or a fully successful run yields OK.
+absl::Status RpcExecutor::Execute(RpcSequence& sequence) {
+  responses_.clear();
+  responses_.reserve(sequence.size());
+  for (RpcNode& node : sequence) {
+    if (absl::Status node_status = ExecuteOne(node); !node_status.ok()) {
+      return node_status;
+    }
+  }
+  return absl::OkStatus();
+}
+
+// Looks up the response saved for `node_id` (an index into the responses
+// collected so far), check-failing when no such response exists.
+const google::protobuf::Message& RpcExecutor::GetRpcNodeResponse(RpcNodeID node_id) {
+  FUZZTEST_INTERNAL_CHECK(node_id < responses_.size(),
+                          "The dependency source doesn't exist!");
+  const std::unique_ptr<google::protobuf::Message>& saved = responses_[node_id];
+  return *saved;
+}
+
+// Exposes non-owning views of the saved responses, in execution order.
+std::vector<const google::protobuf::Message*> RpcExecutor::GetResponses() const {
+  std::vector<const google::protobuf::Message*> views;
+  views.reserve(responses_.size());
+  for (const auto& owned_response : responses_) {
+    views.push_back(owned_response.get());
+  }
+  return views;
+}
+
+// Resolves the data-flow dependencies of `node` by copying fields out of
+// earlier responses into its request, then issues the call and, on success,
+// appends the response.
+absl::Status RpcExecutor::ExecuteOne(RpcNode& node) {
+  for (const RpcDataFlowEdge& edge : node.dependencies()) {
+    const google::protobuf::Message& source_response =
+        GetRpcNodeResponse(edge.from_node_id);
+    CopyField(edge.from_field, source_response, edge.to_field, node.request());
+  }
+
+  absl::StatusOr<std::unique_ptr<google::protobuf::Message>> call_result =
+      stub_->CallMethod(node.method(), node.request());
+  if (!call_result.ok()) {
+    return std::move(call_result).status();
+  }
+  responses_.push_back(std::move(call_result).value());
+  return absl::OkStatus();
+}
+
+}  // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_executor.h b/rpc_fuzzing/rpc_executor.h new file mode 100644 index 0000000..31463ab --- /dev/null +++ b/rpc_fuzzing/rpc_executor.h
@@ -0,0 +1,46 @@
+#ifndef FUZZTEST_RPC_FUZZING_RPC_EXECUTOR_H_
+#define FUZZTEST_RPC_FUZZING_RPC_EXECUTOR_H_
+
+#include <memory>
+#include <vector>
+
+#include "google/protobuf/message.h"
+#include "absl/log/die_if_null.h"
+#include "absl/status/status.h"
+#include "./rpc_fuzzing/rpc_sequence.h"
+#include "./rpc_fuzzing/rpc_stub.h"
+
+namespace fuzztest::internal {
+
+// The RpcExecutor fills the dynamic values from the previous responses
+// according to the dependencies in the RpcSequence and sends the modified
+// requests to the fuzzed service.
+class RpcExecutor {
+ public:
+  // `stub` must be non-null; it is not owned and must outlive the executor.
+  explicit RpcExecutor(RpcStub* stub) : stub_(ABSL_DIE_IF_NULL(stub)) {}
+
+  // Sends the RPC calls specified in `sequence`, in order. The requests of
+  // `sequence` are modified according to the dependencies. Every time this
+  // method is called, the previous responses are cleared. Execution stops at
+  // the first failing call, whose status is returned.
+  absl::Status Execute(RpcSequence& sequence);
+  // Returns the responses of the last executed sequence. The pointers are owned
+  // by the executor and are invalidated by the next call to `Execute()`.
+  std::vector<const google::protobuf::Message*> GetResponses() const;
+
+ private:
+  // Fills the dynamic fields in the request of `node`, executes the RPC call,
+  // saves its response on success, and returns the status of the call.
+  absl::Status ExecuteOne(RpcNode& node);
+  // Returns the saved response at index `node_id`. Check-fails if `node_id` is
+  // not smaller than the number of responses saved so far.
+  const google::protobuf::Message& GetRpcNodeResponse(RpcNodeID node_id);
+
+  std::vector<std::unique_ptr<google::protobuf::Message>> responses_;
+  RpcStub* stub_;
+};
+
+}  // namespace fuzztest::internal
+
+#endif  // FUZZTEST_RPC_FUZZING_RPC_EXECUTOR_H_
diff --git a/rpc_fuzzing/rpc_executor_test.cc b/rpc_fuzzing/rpc_executor_test.cc new file mode 100644 index 0000000..178daa2 --- /dev/null +++ b/rpc_fuzzing/rpc_executor_test.cc
@@ -0,0 +1,88 @@
+#include "./rpc_fuzzing/rpc_executor.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/log/check.h"
+#include "absl/strings/string_view.h"
+#include "./domain_tests/domain_testing.h"
+#include "./rpc_fuzzing/grpc_stub.h"
+#include "./rpc_fuzzing/proto_field_path.h"
+#include "./rpc_fuzzing/rpc_sequence.h"
+#include "./rpc_fuzzing/testdata/grpc/mini_blogger_service.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.pb.h"
+#include "grpcpp/security/server_credentials.h"
+#include "grpcpp/server.h"
+#include "grpcpp/server_builder.h"
+#include "grpcpp/support/channel_arguments.h"
+
+namespace fuzztest::internal {
+
+namespace {
+
+using ::testing::AllOf;
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::IsTrue;
+using ::testing::Ne;
+using ::testing::NotNull;
+using ::testing::Pointee;
+using ::testing::ResultOf;
+
+// Returns the descriptor of the field named `field_name` in `message`.
+const google::protobuf::FieldDescriptor* GetField(const google::protobuf::Message& message,
+                                                  absl::string_view field_name) {
+  return message.GetDescriptor()->FindFieldByName(field_name);
+}
+
+class RpcExecutorGrpcTest : public ::testing::Test {
+ protected:
+  RpcExecutorGrpcTest()
+      : server_(grpc::ServerBuilder()
+                    .RegisterService(&mini_blogger_service_)
+                    .BuildAndStart()),
+        stub_(server_->InProcessChannel(grpc::ChannelArguments())) {}
+
+  ~RpcExecutorGrpcTest() override { server_->Shutdown(); }
+
+  // Builds a node that calls MiniBlogger.RegisterUser with the given fields.
+  RpcNode GetRegisterUserRpcNode(absl::string_view user_name,
+                                 absl::string_view email,
+                                 absl::string_view password) {
+    auto request = std::make_unique<RegisterUserRequest>();
+    request->set_user_name(user_name);
+    request->set_password(password);
+    request->set_email(email);
+    const google::protobuf::MethodDescriptor* method =
+        google::protobuf::DescriptorPool::generated_pool()->FindMethodByName(
+            "fuzztest.internal.MiniBlogger.RegisterUser");
+    CHECK(method != nullptr);
+    return RpcNode(*method, std::move(request));
+  }
+
+  MiniBloggerGrpcService mini_blogger_service_;
+  std::unique_ptr<grpc::Server> server_ = nullptr;
+  GrpcStub stub_;
+};
+
+TEST_F(RpcExecutorGrpcTest, ExecutorCorrectlyExecutesRpcCalls) {
+  RpcNode node = GetRegisterUserRpcNode("test", "test_email@gmail.com", "123");
+  RpcSequence sequence = {node};
+  RpcExecutor executor(&stub_);
+  ASSERT_OK(executor.Execute(sequence));
+  EXPECT_THAT(executor.GetResponses(),
+              ElementsAre(Pointee(ResultOf(
+                  [](const google::protobuf::Message& response) {
+                    return response.GetReflection()->GetBool(
+                        response, GetField(response, "success"));
+                  },
+                  IsTrue()))));
+}
+
+}  // namespace
+}  // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_potential_dfg.cc b/rpc_fuzzing/rpc_potential_dfg.cc new file mode 100644 index 0000000..ed9b04b --- /dev/null +++ b/rpc_fuzzing/rpc_potential_dfg.cc
@@ -0,0 +1,155 @@ +#include "./rpc_fuzzing/rpc_potential_dfg.h" + +#include <vector> + +#include "google/protobuf/descriptor.h" +#include "./fuzztest/internal/logging.h" +#include "./rpc_fuzzing/proto_field_path.h" + +namespace fuzztest::internal { + +namespace { + +// Get the type name of a field. If it is a message type, then return its +// concrete message type name. +std::string_view GetTypeNameForField(const google::protobuf::FieldDescriptor& field) { + if (field.type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { + return field.message_type()->name(); + } else { + return field.type_name(); + } +} + +// Describe a field path in a specific method's request or response. +struct DetailedFieldInfo { + // Use pointer to keep it copy assignable. + const google::protobuf::MethodDescriptor* method; + bool in_request; + FieldPath field_path; +}; + +// Collect either all the fields in the response/request of a method. +std::vector<DetailedFieldInfo> CollectFieldInfoInMethod( + const google::protobuf::MethodDescriptor& method, bool collect_request) { + const google::protobuf::Descriptor& message_descriptor = + collect_request ? 
*method.input_type() : *method.output_type(); + std::vector<FieldPath> all_field_paths = CollectAllFields(message_descriptor); + std::vector<DetailedFieldInfo> result; + result.reserve(all_field_paths.size()); + for (const auto& field_path : all_field_paths) { + result.push_back(DetailedFieldInfo{&method, collect_request, field_path}); + } + return result; +} + +std::vector<DetailedFieldInfo> CollectDefinitions( + const google::protobuf::MethodDescriptor& method) { + return CollectFieldInfoInMethod(method, /*collect_request=*/false); +} + +std::vector<DetailedFieldInfo> CollectUses( + const google::protobuf::MethodDescriptor& method) { + return CollectFieldInfoInMethod(method, /*collect_request=*/true); +} + +bool InSameMethod(const DetailedFieldInfo& a, const DetailedFieldInfo& b) { + return a.method->full_name() == b.method->full_name(); +} + +bool HasSameNameAndType(const DetailedFieldInfo& a, + const DetailedFieldInfo& b) { + const google::protobuf::FieldDescriptor& a_field = a.field_path.GetLastField(); + const google::protobuf::FieldDescriptor& b_field = b.field_path.GetLastField(); + return a_field.name() == b_field.name() && + GetTypeNameForField(a_field) == GetTypeNameForField(b_field); +} + +} // namespace + +void RpcPotentialDfgNode::AddDependency( + const google::protobuf::MethodDescriptor& source_method, + const FieldPath& source_field, const FieldPath& sink_field) { + dependencies_[sink_field].push_back( + PotentialDependencySource{&source_method, source_field}); +} + +void RpcPotentialDataFlowGraph::AddDependency( + const google::protobuf::MethodDescriptor& source_method, + const FieldPath& source_field, const google::protobuf::MethodDescriptor& sink_method, + const FieldPath& sink_field) { + FUZZTEST_INTERNAL_CHECK_PRECONDITION( + nodes_.find(&source_method) != nodes_.end(), "No such source method!"); + FUZZTEST_INTERNAL_CHECK_PRECONDITION( + nodes_.find(&sink_method) != nodes_.end(), "No such sink method!"); + RpcPotentialDfgNode& sink_node = 
nodes_.find(&sink_method)->second; + sink_node.AddDependency(source_method, source_field, sink_field); +} + +RpcPotentialDataFlowGraph::RpcPotentialDataFlowGraph( + const google::protobuf::ServiceDescriptor& service) { + for (int i = 0; i < service.method_count(); ++i) { + const google::protobuf::MethodDescriptor& method = *service.method(i); + nodes_.emplace(&method, RpcPotentialDfgNode(method)); + } +} + +RpcPotentialDataFlowGraph RpcPotentialDataFlowGraph::Create( + const google::protobuf::ServiceDescriptor& service_descriptor) { + // Collect all the definitions and uses. + std::vector<DetailedFieldInfo> all_definitions, all_uses; + for (int i = 0; i < service_descriptor.method_count(); ++i) { + const google::protobuf::MethodDescriptor& method = *service_descriptor.method(i); + std::vector<DetailedFieldInfo> definitions = CollectDefinitions(method); + std::vector<DetailedFieldInfo> uses = CollectUses(method); + all_definitions.insert(all_definitions.end(), definitions.begin(), + definitions.end()); + all_uses.insert(all_uses.end(), uses.begin(), uses.end()); + } + + // For each pair of definition and use with same name and type, add a + // dependency (edge) in the graph. 
+ RpcPotentialDataFlowGraph graph(service_descriptor); + for (const DetailedFieldInfo& define_field_info : all_definitions) { + for (const DetailedFieldInfo& use_field_info : all_uses) { + if (!InSameMethod(define_field_info, use_field_info) && + HasSameNameAndType(define_field_info, use_field_info)) { + graph.AddDependency(*define_field_info.method, + define_field_info.field_path, + *use_field_info.method, use_field_info.field_path); + } + } + } + return graph; +} + +const std::vector<RpcPotentialDfgNode::PotentialDependencySource>& +RpcPotentialDfgNode::GetDependencies(const FieldPath& sink) const { + FUZZTEST_INTERNAL_CHECK_PRECONDITION( + dependencies_.find(sink) != dependencies_.end(), "No such sink!"); + return dependencies_.at(sink); +} + +const RpcPotentialDfgNode& RpcPotentialDataFlowGraph::GetNode( + const google::protobuf::MethodDescriptor& method) const { + FUZZTEST_INTERNAL_CHECK_PRECONDITION(nodes_.find(&method) != nodes_.end(), + "No such method!"); + return nodes_.at(&method); +} + +const google::protobuf::ServiceDescriptor& GetServiceDescriptorByServiceName( + absl::string_view service_name) { + const google::protobuf::ServiceDescriptor* service = + google::protobuf::DescriptorPool::generated_pool()->FindServiceByName( + service_name.data()); + FUZZTEST_INTERNAL_CHECK(service != nullptr, + std::string(service_name) + " Service not found!"); + return *service; +} + +RpcPotentialDataFlowGraph CreatePotentialDfgByServiceName( + absl::string_view service_name) { + return RpcPotentialDataFlowGraph::Create( + GetServiceDescriptorByServiceName(service_name)); +} + +} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_potential_dfg.h b/rpc_fuzzing/rpc_potential_dfg.h new file mode 100644 index 0000000..ace3f37 --- /dev/null +++ b/rpc_fuzzing/rpc_potential_dfg.h
@@ -0,0 +1,129 @@ +#ifndef FUZZTEST_RPC_FUZZING_RPC_POTENTIAL_DFG_H_ +#define FUZZTEST_RPC_FUZZING_RPC_POTENTIAL_DFG_H_ + +#include <type_traits> +#include <vector> + +#include "google/protobuf/descriptor.h" +#include "google/protobuf/service.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "./fuzztest/internal/logging.h" +#include "./rpc_fuzzing/proto_field_path.h" + +namespace fuzztest::internal { + +// A `RpcPotentialDfgNode` describes the potential dependencies for the +// fields in a request of a rpc method. +class RpcPotentialDfgNode { + public: + RpcPotentialDfgNode(const google::protobuf::MethodDescriptor& method) + : method_(method) {} + + // Establish a dependency from the `source_field` to the `sink_field` of + // current method. + void AddDependency(const google::protobuf::MethodDescriptor& source_method, + const FieldPath& source_field, + const FieldPath& sink_field); + + // Get the method representing by the node. + const google::protobuf::MethodDescriptor& GetMethod() const { return method_; } + + // Source of dependency, which should be in the response of a method. + struct PotentialDependencySource { + // `method` will never be null. Use a pointer to keep it copy assignable. + const google::protobuf::MethodDescriptor* method; + FieldPath field_path; + }; + + const absl::flat_hash_map<FieldPath, std::vector<PotentialDependencySource>>& + GetAllDependencies() const { + return dependencies_; + } + + const std::vector<PotentialDependencySource>& GetDependencies( + const FieldPath& sink) const; + bool HasDependency() const { return !dependencies_.empty(); } + + private: + const google::protobuf::MethodDescriptor& method_; + absl::flat_hash_map<FieldPath /*sink*/, + std::vector<PotentialDependencySource> /*source*/> + dependencies_; +}; + +// A `RpcPotentialDataFlowGraph` describes all the potential dependencies among +// methods in a service. Each method is represented as a node in the graph. 
+class RpcPotentialDataFlowGraph { + public: + // Get the node representing `method`. + const RpcPotentialDfgNode& GetNode( + const google::protobuf::MethodDescriptor& method) const; + + // Create an `RpcPotentialDfg` with each method in the service being + // a node, and add dependencies between the methods. + static RpcPotentialDataFlowGraph Create( + const google::protobuf::ServiceDescriptor& service); + + private: + RpcPotentialDataFlowGraph(const google::protobuf::ServiceDescriptor& service); + // Add a dependency on the field from the response of `source_method` to the + // request of `sink_method`. + void AddDependency(const google::protobuf::MethodDescriptor& source_method, + const FieldPath& source_field, + const google::protobuf::MethodDescriptor& sink_method, + const FieldPath& sink_field); + + absl::flat_hash_map<const google::protobuf::MethodDescriptor*, RpcPotentialDfgNode> + nodes_; +}; + +template <typename T, typename = void> +struct is_stubby_service : std::false_type {}; + +template <typename ServiceT> +struct is_stubby_service<ServiceT, std::void_t<decltype(ServiceT::descriptor)>> + : std::true_type {}; + +template <typename T, typename = void> +struct is_grpc_service : std::false_type {}; + +template <typename ServiceT> +struct is_grpc_service<ServiceT, + std::void_t<decltype(ServiceT::service_full_name)>> + : std::true_type {}; + +// Get the service descriptor based on the service type. Currently support +// protobuf services and stubby services. 
+template <typename ServiceT> +const google::protobuf::ServiceDescriptor& GetServiceDescriptor() { + if constexpr (std::is_same_v<ServiceT, google::protobuf::Service>) { + return *ServiceT::descriptor(); + } else if constexpr (is_stubby_service<ServiceT>::value) { + return *google::protobuf::DescriptorPool::generated_pool()->FindServiceByName( + ServiceT::descriptor.full_name()); + } else if constexpr (is_grpc_service<ServiceT>::value) { + return *google::protobuf::DescriptorPool::generated_pool()->FindServiceByName( + ServiceT::service_full_name()); + } else { + FUZZTEST_INTERNAL_CHECK(false, "Unsupported service type!"); + } +} + +// Get the service descriptor based on the fully qualified service name such as +// "package.ServiceName". +const google::protobuf::ServiceDescriptor& GetServiceDescriptorByServiceName( + absl::string_view service_name); + +// Create an potential data flow graph for a service type. +template <typename ServiceT> +RpcPotentialDataFlowGraph CreatePotentialDfg() { + return RpcPotentialDataFlowGraph::Create(GetServiceDescriptor<ServiceT>()); +} + +RpcPotentialDataFlowGraph CreatePotentialDfgByServiceName( + absl::string_view service_name); + +} // namespace fuzztest::internal + +#endif // FUZZTEST_RPC_FUZZING_RPC_POTENTIAL_DFG_H_
diff --git a/rpc_fuzzing/rpc_potential_dfg_test.cc b/rpc_fuzzing/rpc_potential_dfg_test.cc new file mode 100644 index 0000000..c5c54a1 --- /dev/null +++ b/rpc_fuzzing/rpc_potential_dfg_test.cc
@@ -0,0 +1,104 @@ +#include "./rpc_fuzzing/rpc_potential_dfg.h" + +#include <vector> + +#include "google/protobuf/descriptor.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "./rpc_fuzzing/proto_field_path.h" +#include "./rpc_fuzzing/testdata/mini_blogger.pb.h" +#include "./rpc_fuzzing/testdata/mini_blogger.grpc.pb.h" + +namespace fuzztest::internal { + +namespace { + +using ::testing::FieldsAre; +using ::testing::UnorderedElementsAre; + +class RpcPotentialDfgTest : public ::testing::Test { + protected: + void SetUp() override { + const google::protobuf::DescriptorPool* pool = + google::protobuf::DescriptorPool::generated_pool(); + mini_blogger_serivce_ = + pool->FindServiceByName("fuzztest.internal.MiniBlogger"); + ASSERT_TRUE(mini_blogger_serivce_ != nullptr); + log_out_user_method_ = + mini_blogger_serivce_->FindMethodByName("LogOutUser"); + ASSERT_TRUE(log_out_user_method_ != nullptr); + log_in_user_method_ = mini_blogger_serivce_->FindMethodByName("LogInUser"); + ASSERT_TRUE(log_in_user_method_ != nullptr); + register_user_method_ = + mini_blogger_serivce_->FindMethodByName("RegisterUser"); + ASSERT_TRUE(register_user_method_ != nullptr); + get_user_posts_method_ = + mini_blogger_serivce_->FindMethodByName("GetUserPosts"); + ASSERT_TRUE(get_user_posts_method_ != nullptr); + } + const google::protobuf::ServiceDescriptor* mini_blogger_serivce_; + const google::protobuf::MethodDescriptor* log_out_user_method_; + const google::protobuf::MethodDescriptor* log_in_user_method_; + const google::protobuf::MethodDescriptor* register_user_method_; + const google::protobuf::MethodDescriptor* get_user_posts_method_; +}; + +TEST_F(RpcPotentialDfgTest, + NodesNotDependingOnOthersHaveNoPotentialDependencies) { + RpcPotentialDataFlowGraph dfg = CreatePotentialDfg<MiniBlogger>(); + + const RpcPotentialDfgNode& register_user_node = + dfg.GetNode(*register_user_method_); + // RegisterUser doesn't depend on others. 
+ EXPECT_FALSE(register_user_node.HasDependency()); + + const RpcPotentialDfgNode& log_in_user_node = + dfg.GetNode(*log_in_user_method_); + // LogInUser doesn't depend on others. + EXPECT_FALSE(log_in_user_node.HasDependency()); +} + +TEST_F(RpcPotentialDfgTest, DependneciesSetUpBasedOnNameAndType) { + RpcPotentialDataFlowGraph dfg = CreatePotentialDfg<MiniBlogger>(); + const RpcPotentialDfgNode& get_user_posts_node = + dfg.GetNode(*get_user_posts_method_); + // GetUserPostsRequest.session_id depends on LogInUserResponse.session_id. + EXPECT_EQ(get_user_posts_node.GetAllDependencies().size(), 1); + + FieldPath get_user_posts_request_field = + GetFieldPath<GetUserPostsRequest>("session_id"); + std::vector<RpcPotentialDfgNode::PotentialDependencySource> all_dep_sources = + get_user_posts_node.GetDependencies(get_user_posts_request_field); + + // Depends on the session_id of the response in `LogInUser`. + EXPECT_THAT( + all_dep_sources, + UnorderedElementsAre(FieldsAre( + log_in_user_method_, GetFieldPath<LogInUserResponse>("session_id")))); +} + +TEST_F(RpcPotentialDfgTest, DependneciesSetUpOnInnerFields) { + RpcPotentialDataFlowGraph dfg = CreatePotentialDfg<MiniBlogger>(); + const RpcPotentialDfgNode& log_out_user_node = + dfg.GetNode(*log_out_user_method_); + // LogOutUserRequest.log_out_info.session_info.session_id and + // LogOutUserRequest.log_out_info.session_id depends on + // LogInUserResponse.session_id. + EXPECT_EQ(log_out_user_node.GetAllDependencies().size(), 2); + + FieldPath log_out_user_request_field = + GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id"); + std::vector<RpcPotentialDfgNode::PotentialDependencySource> all_dep_sources = + log_out_user_node.GetDependencies(log_out_user_request_field); + EXPECT_EQ(all_dep_sources.size(), 1); + + // Depends on the session_id of the response in `LogInUser`. 
+ EXPECT_THAT( + all_dep_sources, + UnorderedElementsAre(FieldsAre( + log_in_user_method_, GetFieldPath<LogInUserResponse>("session_id")))); +} + +} // namespace + +} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_sequence.cc b/rpc_fuzzing/rpc_sequence.cc new file mode 100644 index 0000000..cb69534 --- /dev/null +++ b/rpc_fuzzing/rpc_sequence.cc
@@ -0,0 +1,184 @@ +#include "./rpc_fuzzing/rpc_sequence.h" + +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/util/message_differencer.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/random/random.h" +#include "./fuzztest/internal/logging.h" + +namespace fuzztest::internal { + +namespace { + +std::vector<int> GetShuffledIndices(size_t size, absl::BitGenRef gen) { + std::vector<int> shuffled_indices(size); + std::iota(shuffled_indices.begin(), shuffled_indices.end(), 0); + std::shuffle(shuffled_indices.begin(), shuffled_indices.end(), gen); + return shuffled_indices; +} + +std::vector<RpcNodeID> GetShuffledNodeIdsInGraph(const RpcDataFlowGraph& graph, + absl::BitGenRef gen) { + std::vector<RpcNodeID> shuffled_node_ids; + shuffled_node_ids.reserve(graph.NodeNum()); + for (const auto& [node_id, unused_node] : graph.GetAllNodes()) { + shuffled_node_ids.push_back(node_id); + } + std::shuffle(shuffled_node_ids.begin(), shuffled_node_ids.end(), gen); + return shuffled_node_ids; +} + +void RemapRpcNodeIdToSequenceIndex( + RpcSequence& sequence, + const absl::flat_hash_map<RpcNodeID, RpcNodeID>& id_map) { + for (RpcNode& node : sequence) { + for (RpcDataFlowEdge& edge : node.dependencies()) { + edge.from_node_id = id_map.at(edge.from_node_id); + } + } +} + +void CollectTopologicalSortingOrder(const RpcDataFlowGraph& graph, + RpcNodeID current_id, absl::BitGenRef gen, + absl::flat_hash_set<RpcNodeID>& visited, + std::vector<RpcNodeID>& sorted_ids) { + if (visited.contains(current_id)) return; + + visited.insert(current_id); + const RpcNode& node = graph.GetNode(current_id); + std::vector<RpcNodeID> shuffled_edge_indices = + GetShuffledIndices(node.dependencies().size(), gen); + for (int edge_index : shuffled_edge_indices) { + CollectTopologicalSortingOrder(graph, + node.dependencies()[edge_index].from_node_id, + gen, visited, 
sorted_ids); + } + sorted_ids.push_back(current_id); +} + +std::vector<RpcNodeID> GetRandomTopologicalOrdering( + const RpcDataFlowGraph& graph) { + std::vector<RpcNodeID> sorted_node_ids; + sorted_node_ids.reserve(graph.NodeNum()); + absl::flat_hash_set<RpcNodeID> visited; + absl::BitGen gen; + std::vector<RpcNodeID> shuffled_node_ids = + GetShuffledNodeIdsInGraph(graph, gen); + for (RpcNodeID node_id : shuffled_node_ids) { + CollectTopologicalSortingOrder(graph, node_id, gen, visited, + sorted_node_ids); + } + FUZZTEST_INTERNAL_CHECK( + sorted_node_ids.size() == graph.NodeNum(), + "Topological sort results in different number of nodes!"); + return sorted_node_ids; +} + +RpcSequence SequentializeGraphByOrder( + const RpcDataFlowGraph& graph, const std::vector<RpcNodeID>& node_order) { + RpcSequence result; + result.reserve(graph.NodeNum()); + absl::flat_hash_map<RpcNodeID /* original node id */, + RpcNodeID /* sequence index */> + id_map; + for (size_t seq_idx = 0; seq_idx < node_order.size(); ++seq_idx) { + RpcNodeID original_node_id = node_order[seq_idx]; + id_map[original_node_id] = seq_idx; + result.push_back(graph.GetNode(original_node_id)); + } + RemapRpcNodeIdToSequenceIndex(result, id_map); + return result; +} + +} // namespace + +RpcNode::RpcNode(const RpcNode& other) + : method_(other.method_), + request_(absl::WrapUnique(other.request_->New())), + dependencies_(other.dependencies_) { + request_->CopyFrom(*other.request_); +} + +RpcNode& RpcNode::operator=(const RpcNode& other) { + method_ = other.method_; + request_ = absl::WrapUnique(other.request_->New()); + request_->CopyFrom(*other.request_); + dependencies_ = other.dependencies_; + return *this; +} + +bool operator==(const RpcDataFlowEdge& lhs, const RpcDataFlowEdge& rhs) { + return lhs.from_node_id == rhs.from_node_id && + lhs.from_field == rhs.from_field && lhs.to_field == rhs.to_field; +} + +bool operator!=(const RpcDataFlowEdge& lhs, const RpcDataFlowEdge& rhs) { + return !(lhs == rhs); +} + 
+bool operator==(const RpcNode& lhs, const RpcNode& rhs) { + return lhs.method_->full_name() == rhs.method_->full_name() && + google::protobuf::util::MessageDifferencer::Equals(*lhs.request_, + *rhs.request_) && + lhs.dependencies_ == rhs.dependencies_; +} + +bool operator!=(const RpcNode& lhs, const RpcNode& rhs) { + return !(lhs == rhs); +} + +void RpcDataFlowGraph::AddNode(RpcNodeID id, RpcNode rpc_node) { + InvalidateNodeOrder(); + rpc_nodes_.emplace(id, rpc_node); +} + +void RpcDataFlowGraph::RemoveNode(RpcNodeID id) { + InvalidateNodeOrder(); + rpc_nodes_.erase(id); +} + +const RpcNode& RpcDataFlowGraph::GetNode(RpcNodeID id) const { + FUZZTEST_INTERNAL_CHECK(rpc_nodes_.contains(id), "Invalid id!"); + return rpc_nodes_.at(id); +} + +size_t RpcDataFlowGraph::NodeNum() const { return rpc_nodes_.size(); } + +RpcDataFlowGraph RpcDataFlowGraph::FromSequence(const RpcSequence& sequence) { + RpcDataFlowGraph graph; + for (int node_idx = 0; node_idx < sequence.size(); ++node_idx) { + graph.AddNode(node_idx, sequence[node_idx]); + } + graph.node_order_ = std::vector<RpcNodeID>(sequence.size()); + std::iota(graph.node_order_->begin(), graph.node_order_->end(), 0); + return graph; +} + +RpcSequence RpcDataFlowGraph::GetSequence() const { + if (!node_order_.has_value()) { + node_order_ = GetRandomTopologicalOrdering(*this); + } + return SequentializeGraphByOrder(*this, *node_order_); +} + +std::vector<RpcNodeID>& RpcDataFlowGraph::GetOrderedNodeIds() const { + if (!node_order_.has_value()) { + node_order_ = GetRandomTopologicalOrdering(*this); + } + return *node_order_; +} + +bool operator==(const RpcDataFlowGraph& lhs, const RpcDataFlowGraph& rhs) { + return lhs.GetSequence() == rhs.GetSequence(); +} + +bool operator!=(const RpcDataFlowGraph& lhs, const RpcDataFlowGraph& rhs) { + return !(lhs == rhs); +} + +void RpcDataFlowGraph::RandomizeTopologicalOrdering() { InvalidateNodeOrder(); } + +} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_sequence.h b/rpc_fuzzing/rpc_sequence.h new file mode 100644 index 0000000..b014022 --- /dev/null +++ b/rpc_fuzzing/rpc_sequence.h
@@ -0,0 +1,116 @@ +#ifndef FUZZTEST_RPC_FUZZING_RPC_SEQUENCE_H_ +#define FUZZTEST_RPC_FUZZING_RPC_SEQUENCE_H_ + +#include <memory> +#include <optional> +#include <vector> + +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "absl/container/flat_hash_map.h" +#include "./rpc_fuzzing/proto_field_path.h" + +namespace fuzztest::internal { + +// An RPC call is identified by an id because we might have multiple calls to +// the same rpc method. In the sequence, the id is also the index of the node. +using RpcNodeID = std::int32_t; + +// Represents a data-flow edge between two RPC calls. +struct RpcDataFlowEdge { + RpcNodeID from_node_id = -1; + // The message field in the response. + FieldPath from_field; + // The dynamic field in the request. + FieldPath to_field; +}; + +bool operator==(const RpcDataFlowEdge& lhs, const RpcDataFlowEdge& rhs); +bool operator!=(const RpcDataFlowEdge& lhs, const RpcDataFlowEdge& rhs); + +// Represents a remote procedure call (a request). Parts of the request that are +// coming from a previous response are represented with a data flow edge +// (`RpcDataFlowEdge`). 
+class RpcNode { + public: + RpcNode(const google::protobuf::MethodDescriptor& method, + std::unique_ptr<google::protobuf::Message> request) + : method_(&method), request_(std::move(request)) {} + RpcNode(const RpcNode&); + RpcNode(RpcNode&&) noexcept; + RpcNode& operator=(const RpcNode&); + RpcNode& operator=(RpcNode&&) = default; + const google::protobuf::MethodDescriptor& method() const { return *method_; } + google::protobuf::Message& request() { return *request_; } + const google::protobuf::Message& request() const { return *request_; } + void AddDependency(RpcDataFlowEdge dep) { dependencies_.push_back(dep); } + const std::vector<RpcDataFlowEdge>& dependencies() const { + return dependencies_; + } + std::vector<RpcDataFlowEdge>& dependencies() { return dependencies_; } + + friend bool operator==(const RpcNode& lhs, const RpcNode& rhs); + + private: + // The `method_` is ensured to be non-null. + const google::protobuf::MethodDescriptor* method_; + std::unique_ptr<google::protobuf::Message> request_; + std::vector<RpcDataFlowEdge> dependencies_; +}; + +bool operator==(const RpcNode& lhs, const RpcNode& rhs); +bool operator!=(const RpcNode& lhs, const RpcNode& rhs); + +inline RpcNode::RpcNode(RpcNode&&) noexcept = default; + +// The RpcSequence represents a topologically sorted RpcDataFlowGraph. +using RpcSequence = std::vector<RpcNode>; + +// The RpcDataFlowGraph represents a "remote procedure call session", i.e., a +// set of RPC method calls and the data-flow dependencies between them. +class RpcDataFlowGraph { + public: + void AddNode(RpcNodeID id, RpcNode rpc_node); + void RemoveNode(RpcNodeID id); + const RpcNode& GetNode(RpcNodeID id) const; + size_t NodeNum() const; + + absl::flat_hash_map<RpcNodeID, RpcNode>& GetAllNodes() { return rpc_nodes_; } + const absl::flat_hash_map<RpcNodeID, RpcNode>& GetAllNodes() const { + return rpc_nodes_; + } + + // Returns a topologically sorted sequence representation of the graph. 
Note + // that there are multiple possible topological orderings, from which it + // returns a random one. Multiple calls to the method return the same ordering + // (unless RandomizeTopologicalOrdering() is called). + RpcSequence GetSequence() const; + + // Converts a sequence to a graph with the sequence node ordering. + static RpcDataFlowGraph FromSequence(const RpcSequence& sequence); + + // Randomizes the topological ordering of the nodes. + void RandomizeTopologicalOrdering(); + + // Returns the ordering of the nodes corresponding to that returned by + // `GetSequence`. + std::vector<RpcNodeID>& GetOrderedNodeIds() const; + + private: + // Invalidates the node order so that if the order is needed again it will be + // re-computed on demand. This optimization avoids unnecessary recomputation + // every time the graph is changed because the recomputation can be costly. + void InvalidateNodeOrder() const { node_order_.reset(); } + absl::flat_hash_map<RpcNodeID, RpcNode> rpc_nodes_; + // `GetSequence` might change the order so we need `mutable` for `GetSequence` + // to be used in `FromValue` in the RpcSession domain. + mutable std::optional<std::vector<RpcNodeID>> node_order_; +}; + +// Returns true if the graphs have the same nodes and ordering. +bool operator==(const RpcDataFlowGraph& lhs, const RpcDataFlowGraph& rhs); +bool operator!=(const RpcDataFlowGraph& lhs, const RpcDataFlowGraph& rhs); + +} // namespace fuzztest::internal + +#endif // FUZZTEST_RPC_FUZZING_RPC_SEQUENCE_H_
diff --git a/rpc_fuzzing/rpc_sequence_test.cc b/rpc_fuzzing/rpc_sequence_test.cc new file mode 100644 index 0000000..768d630 --- /dev/null +++ b/rpc_fuzzing/rpc_sequence_test.cc
@@ -0,0 +1,280 @@ +#include "./rpc_fuzzing/rpc_sequence.h" + +#include "google/protobuf/descriptor.h" +#include "google/protobuf/util/message_differencer.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "./rpc_fuzzing/proto_field_path.h" +#include "./rpc_fuzzing/testdata/mini_blogger.pb.h" + +namespace fuzztest::internal { + +namespace { + +using ::testing::Address; +using ::testing::ElementsAre; +using ::testing::Field; +using ::testing::Property; +using ::testing::UnorderedElementsAreArray; + +TEST(RpcNodeTest, AssignOperator) { + constexpr RpcNodeID kFromNodeId = 1; + const google::protobuf::MethodDescriptor* method_descriptor = + google::protobuf::DescriptorPool::generated_pool()->FindMethodByName( + "fuzztest.internal.MiniBlogger.RegisterUser"); + CHECK(method_descriptor != nullptr); + RpcNode node1 = + RpcNode(*method_descriptor, std::make_unique<RegisterUserRequest>()); + node1.AddDependency(RpcDataFlowEdge{kFromNodeId}); + + RpcNode node2 = node1; + + // Check that the two objects are equal. + EXPECT_EQ(node1.method().full_name(), node2.method().full_name()); + EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals(node1.request(), + node2.request())); + EXPECT_THAT(node2.dependencies(), + ElementsAre(Field(&RpcDataFlowEdge::from_node_id, kFromNodeId))); +} + +TEST(RpcNodeTest, MoveAssignOperator) { + constexpr RpcNodeID kFromNodeId = 1; + const google::protobuf::MethodDescriptor* method_descriptor = + google::protobuf::DescriptorPool::generated_pool()->FindMethodByName( + "fuzztest.internal.MiniBlogger.RegisterUser"); + CHECK(method_descriptor != nullptr); + RpcNode node1 = + RpcNode(*method_descriptor, std::make_unique<RegisterUserRequest>()); + node1.AddDependency(RpcDataFlowEdge{kFromNodeId}); + + RpcNode clone = node1; + RpcNode node2 = std::move(clone); + + // Check that the two objects are equal. 
+ EXPECT_EQ(node1.method().full_name(), node2.method().full_name()); + EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals(node1.request(), + node2.request())); + EXPECT_THAT(node2.dependencies(), + ElementsAre(Field(&RpcDataFlowEdge::from_node_id, kFromNodeId))); +} + +TEST(RpcNodeTest, CopyConstructor) { + constexpr RpcNodeID kFromNodeId = 1; + const google::protobuf::MethodDescriptor* method_descriptor = + google::protobuf::DescriptorPool::generated_pool()->FindMethodByName( + "fuzztest.internal.MiniBlogger.RegisterUser"); + CHECK(method_descriptor != nullptr); + RpcNode node1 = + RpcNode(*method_descriptor, std::make_unique<RegisterUserRequest>()); + node1.AddDependency(RpcDataFlowEdge{kFromNodeId}); + + RpcNode node2(node1); + + // Check that the two objects are equal. + EXPECT_EQ(node1.method().full_name(), node2.method().full_name()); + EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals(node1.request(), + node2.request())); + EXPECT_THAT(node2.dependencies(), + ElementsAre(Field(&RpcDataFlowEdge::from_node_id, kFromNodeId))); +} + +TEST(RpcNodeTest, EqualityTest) { + constexpr RpcNodeID kFromNodeId = 1; + const google::protobuf::MethodDescriptor* method_descriptor = + google::protobuf::DescriptorPool::generated_pool()->FindMethodByName( + "fuzztest.internal.MiniBlogger.RegisterUser"); + ASSERT_NE(method_descriptor, nullptr); + RpcNode node1 = + RpcNode(*method_descriptor, std::make_unique<RegisterUserRequest>()); + node1.AddDependency(RpcDataFlowEdge{kFromNodeId}); + + RpcNode node2(node1); + + EXPECT_EQ(node1, node2); +} + +TEST(RpcNodeTest, InEqualityTest) { + constexpr RpcNodeID kFromNodeId = 1; + const google::protobuf::MethodDescriptor* method_descriptor = + google::protobuf::DescriptorPool::generated_pool()->FindMethodByName( + "fuzztest.internal.MiniBlogger.RegisterUser"); + ASSERT_NE(method_descriptor, nullptr); + RpcNode node1 = + RpcNode(*method_descriptor, std::make_unique<RegisterUserRequest>()); + 
node1.AddDependency(RpcDataFlowEdge{kFromNodeId}); + + RpcNode node2 = + RpcNode(*method_descriptor, std::make_unique<RegisterUserRequest>()); + + EXPECT_NE(node1, node2); +} + +TEST(RpcDataFlowEdgeTest, EqualityTest) { + RpcDataFlowEdge edge1 = + RpcDataFlowEdge{0, GetFieldPath<LogInUserResponse>("session_id"), + GetFieldPath<GetUserPostsRequest>("session_id")}; + RpcDataFlowEdge edge2 = + RpcDataFlowEdge{0, GetFieldPath<LogInUserResponse>("session_id"), + GetFieldPath<GetUserPostsRequest>("session_id")}; + + EXPECT_EQ(edge1, edge2); +} + +TEST(RpcDataFlowEdgeTest, InEqualityTest) { + RpcDataFlowEdge edge1 = + RpcDataFlowEdge{0, GetFieldPath<LogInUserResponse>("session_id"), + GetFieldPath<GetUserPostsRequest>("session_id")}; + RpcDataFlowEdge edge2 = RpcDataFlowEdge{ + 0, GetFieldPath<LogInUserResponse>("session_id"), + GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")}; + + EXPECT_NE(edge1, edge2); +} + +class RpcDataFlowGraphTest : public ::testing::Test { + protected: + void SetUp() override { + log_in_user_method_ = + google::protobuf::DescriptorPool::generated_pool()->FindMethodByName( + "fuzztest.internal.MiniBlogger.LogInUser"); + log_out_user_method_ = + google::protobuf::DescriptorPool::generated_pool()->FindMethodByName( + "fuzztest.internal.MiniBlogger.LogOutUser"); + register_user_method_ = + google::protobuf::DescriptorPool::generated_pool()->FindMethodByName( + "fuzztest.internal.MiniBlogger.RegisterUser"); + } + + RpcNode CreateLogInUserNode(RpcNodeID log_in_node_id) { + return RpcNode(*log_in_user_method_, std::make_unique<LogInUserRequest>()); + } + + RpcNode CreateRegisterUserNode(RpcNodeID log_out_node_id) { + return RpcNode(*register_user_method_, + std::make_unique<RegisterUserRequest>()); + } + + std::pair<RpcNode, RpcNode> CreateLogInOutUserNodePairWithDep( + RpcNodeID log_in_node_id) { + // Set up a node for LogInUser and LogOutUser. 
Establish a dependency from + // LogInUserResponse.session_id to + // LogOutUserRequest.log_out_info.session_info.session_id. + RpcNode log_in_user_node(*log_in_user_method_, + std::make_unique<LogInUserResponse>()); + RpcNode log_out_user_node(*log_out_user_method_, + std::make_unique<LogOutUserRequest>()); + log_out_user_node.AddDependency(RpcDataFlowEdge{ + log_in_node_id, GetFieldPath<LogInUserResponse>("session_id"), + GetFieldPath<LogOutUserRequest>( + "log_out_info.session_info.session_id")}); + return std::make_pair(log_in_user_node, log_out_user_node); + } + + const google::protobuf::MethodDescriptor *log_in_user_method_, *log_out_user_method_, + *register_user_method_; + RpcNodeID log_in_node_id_ = 0x123; + RpcNodeID log_out_node_id_ = 0x234; +}; + +TEST_F(RpcDataFlowGraphTest, GetSequenceReturnTologicalSortedSequence) { + RpcDataFlowGraph graph; + auto [log_in_user_node, log_out_user_node] = + CreateLogInOutUserNodePairWithDep(log_in_node_id_); + // Add LogOutUser first, so that the topological sort should change its order. 
+ graph.AddNode(log_out_node_id_, log_out_user_node); + graph.AddNode(log_in_node_id_, log_in_user_node); + + RpcSequence sequence = graph.GetSequence(); + EXPECT_THAT( + sequence, + ElementsAre(Property(&RpcNode::method, Address(log_in_user_method_)), + Property(&RpcNode::method, Address(log_out_user_method_)))); +} + +TEST_F(RpcDataFlowGraphTest, GetSequenceReMapAllTheRpcNodeIdToSequenceIndex) { + auto [log_in_user_node, log_out_user_node] = + CreateLogInOutUserNodePairWithDep(log_in_node_id_); + RpcDataFlowGraph graph; + graph.AddNode(log_in_node_id_, log_in_user_node); + graph.AddNode(log_out_node_id_, log_out_user_node); + + RpcSequence sequence = graph.GetSequence(); + EXPECT_EQ(&sequence[1].method(), log_out_user_method_); + EXPECT_THAT(sequence[1].dependencies(), + ElementsAre(Field(&RpcDataFlowEdge::from_node_id, + 0 /* remap `kNode1Id` to 0*/))); +} + +TEST_F(RpcDataFlowGraphTest, GetSequenceReturnRandomizedTopoSortedSequence) { + const std::string_view expected_methods_of_first_nodes[] = {"LogInUser", + "RegisterUser"}; + absl::flat_hash_set<std::string> methods_of_first_nodes; + for (int i = 0; i < 100; ++i) { + RpcDataFlowGraph graph; + graph.AddNode(0, CreateLogInUserNode(log_in_node_id_)); + graph.AddNode(1, CreateRegisterUserNode(log_out_node_id_)); + methods_of_first_nodes.insert(graph.GetSequence()[0].method().name()); + } + + EXPECT_THAT(methods_of_first_nodes, + UnorderedElementsAreArray(expected_methods_of_first_nodes)); +} + +TEST_F(RpcDataFlowGraphTest, FromSequenceRecoversNodesAndOrdering) { + RpcSequence sequence = {CreateLogInUserNode(log_in_node_id_), + CreateRegisterUserNode(log_out_node_id_)}; + RpcDataFlowGraph graph = RpcDataFlowGraph::FromSequence(sequence); + EXPECT_EQ(graph.GetSequence(), sequence); +} + +TEST_F(RpcDataFlowGraphTest, + GetSequenceReturnTheSameRandomizedTopoSortedSequence) { + RpcDataFlowGraph graph; + graph.AddNode(0, CreateLogInUserNode(log_in_node_id_)); + graph.AddNode(1, 
CreateRegisterUserNode(log_out_node_id_)); + absl::string_view method_name = graph.GetSequence()[0].method().name(); + for (int i = 0; i < 100; ++i) { + EXPECT_EQ(method_name, graph.GetSequence()[0].method().name()); + } +} + +TEST_F( + RpcDataFlowGraphTest, + GetSequenceReturnDifferentRandomizedTopoSortedSequenceAfterModification) { + const std::string_view expected_methods_of_first_nodes[] = {"LogInUser", + "RegisterUser"}; + absl::flat_hash_set<std::string> methods_of_first_nodes; + RpcDataFlowGraph graph; + graph.AddNode(0, CreateLogInUserNode(log_in_node_id_)); + graph.AddNode(1, CreateRegisterUserNode(log_out_node_id_)); + for (int i = 0; i < 100; ++i) { + graph.RemoveNode(1); + graph.AddNode(1, CreateRegisterUserNode(log_out_node_id_)); + methods_of_first_nodes.insert(graph.GetSequence()[0].method().name()); + } + EXPECT_THAT(methods_of_first_nodes, + UnorderedElementsAreArray(expected_methods_of_first_nodes)); +} + +TEST_F(RpcDataFlowGraphTest, + GetRandomSequenceAlwaysReturnRandomizedTopoSortedSequence) { + const std::string_view expected_methods_of_first_nodes[] = {"LogInUser", + "RegisterUser"}; + absl::flat_hash_set<std::string> methods_of_first_nodes; + RpcDataFlowGraph graph; + graph.AddNode(0, CreateLogInUserNode(log_in_node_id_)); + graph.AddNode(1, CreateRegisterUserNode(log_out_node_id_)); + for (int i = 0; i < 100; ++i) { + graph.RandomizeTopologicalOrdering(); + methods_of_first_nodes.insert(graph.GetSequence()[0].method().name()); + } + + EXPECT_THAT(methods_of_first_nodes, + UnorderedElementsAreArray(expected_methods_of_first_nodes)); +} + +} // namespace + +} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_session.h b/rpc_fuzzing/rpc_session.h new file mode 100644 index 0000000..588ace8 --- /dev/null +++ b/rpc_fuzzing/rpc_session.h
@@ -0,0 +1,608 @@ +#ifndef FUZZTEST_RPC_FUZZING_RPC_SESSION_H_ +#define FUZZTEST_RPC_FUZZING_RPC_SESSION_H_ + +#include <algorithm> +#include <memory> +#include <optional> +#include <type_traits> +#include <utility> +#include <vector> + +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "absl/container/flat_hash_map.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "./fuzztest/internal/domains/domain_base.h" +#include "./fuzztest/internal/domains/protobuf_domain_impl.h" +#include "./fuzztest/internal/domains/value_mutation_helpers.h" +#include "./fuzztest/internal/logging.h" +#include "./fuzztest/internal/serialization.h" +#include "./fuzztest/internal/table_of_recent_compares.h" +#include "./fuzztest/internal/type_support.h" +#include "./rpc_fuzzing/proto_field_path.h" +#include "./rpc_fuzzing/rpc_executor.h" +#include "./rpc_fuzzing/rpc_potential_dfg.h" +#include "./rpc_fuzzing/rpc_sequence.h" + +namespace fuzztest { + +namespace internal { + +// A helper domain that allows its inner domain to be lazily initialized. +template <typename DomainT, typename... Args> +class Lazy : public DomainBase<Lazy<DomainT, Args...>, value_type_t<DomainT>, + corpus_type_t<DomainT>> { + public: + using typename Lazy::DomainBase::corpus_type; + using typename Lazy::DomainBase::value_type; + + Lazy(Args&&... args) : args_(std::forward<Args>(args)...) 
{} + + Lazy(const Lazy& other) { + if (other.inner_ != nullptr) { + inner_ = std::make_unique<DomainT>(*other.inner_); + } + args_ = other.args_; + } + + Lazy(Lazy&& other) noexcept = default; + Lazy& operator=(Lazy&& other) = default; + corpus_type Init(absl::BitGenRef prng) { return GetInnerDomain().Init(prng); } + + void Mutate(corpus_type& corpus_value, absl::BitGenRef prng, + bool shrink_only) { + GetInnerDomain().Mutate(corpus_value, prng, shrink_only); + } + + value_type GetValue(const corpus_type& corpus_value) const { + return GetInnerDomain().GetValue(corpus_value); + } + + std::optional<corpus_type> FromValue(const value_type& v) const { + return GetInnerDomain().FromValue(v); + } + + std::optional<corpus_type> ParseCorpus(const IRObject& obj) const { + return GetInnerDomain().ParseCorpus(obj); + } + + IRObject SerializeCorpus(const corpus_type& corpus_value) const { + return GetInnerDomain().SerializeCorpus(corpus_value); + } + + absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const { + return GetInnerDomain().ValidateCorpusValue(corpus_value); + } + + auto GetPrinter() const { return GetInnerDomain().GetPrinter(); } + + private: + template <std::size_t... 
Is> + std::unique_ptr<DomainT> CreateInner(std::tuple<Args...> tup, + std::index_sequence<Is...>) const { + return std::make_unique<DomainT>(std::move(std::get<Is>(tup))...); + } + + const DomainT& GetInnerDomain() const { + if (inner_ == nullptr) { + inner_ = CreateInner(args_, std::index_sequence_for<Args...>{}); + } + return *inner_; + } + + DomainT& GetInnerDomain() { + if (inner_ == nullptr) { + inner_ = CreateInner(args_, std::index_sequence_for<Args...>{}); + } + return *inner_; + } + mutable std::unique_ptr<DomainT> inner_ = nullptr; + mutable std::tuple<Args...> args_; +}; + +template <class T> +const T& PickRandomElement(const std::vector<T>& vec, absl::BitGenRef prng) { + FUZZTEST_INTERNAL_CHECK(!vec.empty(), + "Cannot pick elements from an empy vector."); + return vec[absl::Uniform(prng, size_t{0}, vec.size())]; +} + +// Rpc session domain. It generates RPC sequences with data-flows between +// requests and responses of calls. The provided RpcSequence can be consumed by +// the RpcExecutor. +template <typename ServiceT = void> +class RpcSessionImpl : public DomainBase<RpcSessionImpl<ServiceT>, RpcSequence, + RpcDataFlowGraph> { + public: + using typename RpcSessionImpl::DomainBase::corpus_type; + using typename RpcSessionImpl::DomainBase::value_type; + + // `service_factory` should return the fully qualified service name such as + // `package.Service`. 
+ RpcSessionImpl(std::function<absl::string_view()> service_factory) + : abstract_dfg_(CreatePotentialDfgByServiceName(service_factory())) { + const google::protobuf::ServiceDescriptor& desc = + GetServiceDescriptorByServiceName(service_factory()); + Initialize(desc); + } + + template <typename T = ServiceT, + typename = std::enable_if_t<is_stubby_service<T>::value || + is_grpc_service<T>::value || + std::is_same_v<T, google::protobuf::Service>>> + RpcSessionImpl() : abstract_dfg_(CreatePotentialDfg<ServiceT>()) { + const google::protobuf::ServiceDescriptor& desc = GetServiceDescriptor<ServiceT>(); + Initialize(desc); + } + + // Generate a sequence containing only a single rpc call. Such calls should be + // the "roots" in the data flow graph and depend on no other nodes. + corpus_type Init(absl::BitGenRef prng) { + const google::protobuf::MethodDescriptor& method = + *PickRandomElement(all_methods_, prng); + RpcDataFlowGraph result; + auto& request_domain = GetRequestDomain(method); + RpcNode call_node(method, + request_domain.GetValue(request_domain.Init(prng))); + result.AddNode(0 /* the first node */, std::move(call_node)); + return result; + } + + void Mutate(corpus_type& graph, absl::BitGenRef prng, bool only_shrink) { + if (only_shrink) { + RunOne( + prng, [&] { DeleteTailCall(graph, prng); }, + [&] { MutateStaticField(graph, prng, only_shrink); }); + } else { + RunOne( + prng, [&] { InsertTailCall(graph, prng); }, + [&] { DeleteTailCall(graph, prng); }, + [&] { MutateStaticField(graph, prng, only_shrink); }); + } + } + + value_type GetValue(const corpus_type& graph) const { + return graph.GetSequence(); + } + + std::optional<corpus_type> FromValue(const value_type& v) const { + corpus_type result; + for (size_t i = 0; i < v.size(); ++i) { + result.AddNode(i, v[i]); + } + return result; + } + + std::optional<corpus_type> ParseCorpus(const IRObject& obj) const { + auto subs = obj.Subs(); + if (!subs || subs->empty()) { + return std::nullopt; + } + RpcSequence 
sequence; + sequence.reserve(subs->size()); + for (const auto& node_obj : *subs) { + if (!ValidateRpcNodeObjFormat(node_obj)) return std::nullopt; + auto node = ParseRpcNodeMethodAndRequest(node_obj); + if (!node) return std::nullopt; + auto edges_sub = (*node_obj.Subs())[2].Subs(); + for (const auto& edge_obj : *edges_sub) { + std::optional<RpcDataFlowEdge> edge = + ParseRpcDataFlowEdge(edge_obj, sequence, node->method()); + if (!edge) return std::nullopt; + node->AddDependency(*edge); + } + sequence.push_back(*std::move(node)); + } + return RpcDataFlowGraph::FromSequence(sequence); + } + + // Serializes the graph in the topological-sorting order. + IRObject SerializeCorpus(const corpus_type& graph) const { + IRObject result; + std::vector<IRObject>& result_subs = result.MutableSubs(); + // We serialize the sequence instead to avoid storing the RpcNodeID. + for (const RpcNode& node : graph.GetSequence()) { + result_subs.push_back(SerializeRpcNode(node)); + } + return result; + } + + // TODO(changochen): Implemented in later CL. + auto GetPrinter() const { return MonostatePrinter{}; } + + // Check whether every RpcNode: + // 1. Only depend on previous nodes. (No cyclic dependencies) + // 2. Every dependency matches one of the potential dependencies. + absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const { + absl::flat_hash_set<RpcNodeID> previous_node_ids; + for (RpcNodeID node_id : corpus_value.GetOrderedNodeIds()) { + const RpcNode& node = corpus_value.GetNode(node_id); + const RpcPotentialDfgNode& dfg_node = + abstract_dfg_.GetNode(node.method()); + absl::flat_hash_set<FieldPath> all_sink_paths; + for (const RpcDataFlowEdge& edge : node.dependencies()) { + // The current node should only depend on previous nodes. 
+ if (!previous_node_ids.contains(edge.from_node_id)) + return absl::InvalidArgumentError( + "The dependencies should only come from previously executed " + "nodes."); + if (!dfg_node.GetAllDependencies().contains(edge.to_field)) { + return absl::InvalidArgumentError( + "The sink field should be defined in the potential data flow " + "graph."); + } + + if (all_sink_paths.contains(edge.to_field)) { + return absl::InvalidArgumentError( + "One sink field should have at most one concrete dependency!"); + } + all_sink_paths.insert(edge.to_field); + + // Check whether the dependency source from `edge` actually matches one + // of the potential dependency source. + if (!HasPotentialSourceForEdge(edge, corpus_value, dfg_node)) { + return absl::InvalidArgumentError( + "The dependency is not defined in the potential data flow " + "graph."); + } + } + + previous_node_ids.insert(node_id); + } + return absl::OkStatus(); + } + + private: + void Initialize(const google::protobuf::ServiceDescriptor& desc) { + for (int i = 0; i < desc.method_count(); ++i) { + const google::protobuf::MethodDescriptor* method = desc.method(i); + all_methods_.push_back(method); + + std::function<const google::protobuf::Message*()> get_request_prototype = + [method]() { + return google::protobuf::MessageFactory::generated_factory()->GetPrototype( + method->input_type()); + }; + request_domains_.emplace( + method, ProtobufDomainUntypedImpl<google::protobuf::Message>( + PrototypePtr<google::protobuf::Message>(get_request_prototype), + /*use_lazy_initialization=*/false)); + } + } + + bool AllPotentialDepsInDfgNodeCanBeSatisfied( + const absl::flat_hash_set<FieldPath>& satisfied_field_paths, + const RpcPotentialDfgNode& dfg_node) const { + for (const auto& [field_path, unused] : dfg_node.GetAllDependencies()) { + if (!satisfied_field_paths.contains(field_path)) { + bool can_be_satisfied = false; + for (const FieldPath& satisfied_field_path : satisfied_field_paths) { + if 
(AreOneOfAltearnatives(satisfied_field_path, field_path)) { + can_be_satisfied = true; + break; + } + } + if (!can_be_satisfied) return false; + } + } + return true; + } + // Check whether the source of the edge is defined in the potential data flow + // graph. + bool HasPotentialSourceForEdge(const RpcDataFlowEdge& edge, + const RpcDataFlowGraph& graph, + const RpcPotentialDfgNode& dfg_node) const { + for (const RpcPotentialDfgNode::PotentialDependencySource& + potential_source : dfg_node.GetDependencies(edge.to_field)) { + const RpcNode& from_node = graph.GetNode(edge.from_node_id); + if (potential_source.method == &from_node.method() && + potential_source.field_path == edge.from_field) { + return true; + } + } + return false; + } + + bool ValidateRpcNodeObjFormat(const IRObject& obj) const { + auto node_subs = obj.Subs(); + // Each node has a sub of size 3: method, request, dependencies. + if (!node_subs || node_subs->size() != 3) { + return false; + } + auto edge_subs = (*node_subs)[2].Subs(); + if (!edge_subs) { + return false; + } + // Each edge has a sub of size 3: from_id, from_field, to_field. 
+ for (const auto& edge_sub : *edge_subs) { + if (!edge_sub.Subs() || edge_sub.Subs()->size() != 3) { + return false; + } + } + return true; + } + + IRObject SerializeRpcDataFlowEdge(const RpcDataFlowEdge& edge) const { + IRObject edge_obj; + std::vector<IRObject>& edge_subs = edge_obj.MutableSubs(); + edge_subs.push_back(IRObject::FromCorpus(edge.from_node_id)); + edge_subs.push_back(IRObject::FromCorpus(edge.from_field.ToString())); + edge_subs.push_back(IRObject::FromCorpus(edge.to_field.ToString())); + return edge_obj; + } + + IRObject SerializeRpcNode(const RpcNode& node) const { + // The serialized object will be like + // [ method_name, request_proto_buf_string, [edges]] + IRObject node_obj; + std::vector<IRObject>& node_subs = node_obj.MutableSubs(); + node_subs.push_back(IRObject::FromCorpus(node.method().full_name())); + node_subs.push_back( + IRObject::FromCorpus(node.request().SerializeAsString())); + IRObject edges_obj; + if (!node.dependencies().empty()) { + std::vector<IRObject>& edges_subs = edges_obj.MutableSubs(); + for (const RpcDataFlowEdge& edge : node.dependencies()) { + edges_subs.push_back(SerializeRpcDataFlowEdge(edge)); + } + } + node_subs.push_back(std::move(edges_obj)); + return node_obj; + } + + std::optional<RpcDataFlowEdge> ParseRpcDataFlowEdge( + const IRObject& edge_obj, const RpcSequence& sequence, + const google::protobuf::MethodDescriptor& to_method) const { + auto edge_subs = edge_obj.Subs(); + FUZZTEST_INTERNAL_CHECK(edge_subs && edge_subs->size() == 3, + "Invalid edge format!"); + auto from_node_id = (*edge_subs)[0].ToCorpus<RpcNodeID>(); + auto from_field_str = (*edge_subs)[1].ToCorpus<std::string>(); + auto to_field_str = (*edge_subs)[2].ToCorpus<std::string>(); + if (!from_node_id || !from_field_str || !to_field_str) { + return std::nullopt; + } + if (*from_node_id >= sequence.size()) return std::nullopt; + FieldPath from_field_path = GetFieldPathWithDescriptor( + *sequence[*from_node_id].method().output_type(), 
*from_field_str); + FieldPath to_field_path = + GetFieldPathWithDescriptor(*to_method.input_type(), *to_field_str); + return RpcDataFlowEdge{*from_node_id, from_field_path, to_field_path}; + } + + std::optional<RpcNode> ParseRpcNodeMethodAndRequest( + const IRObject& obj) const { + const google::protobuf::DescriptorPool& pool = + *google::protobuf::DescriptorPool::generated_pool(); + auto node_subs = obj.Subs(); + std::optional<std::string> method_name = + (*node_subs)[0].ToCorpus<std::string>(); + if (!method_name) return std::nullopt; + const google::protobuf::MethodDescriptor* method = + pool.FindMethodByName(*method_name); + + if (!method) return std::nullopt; + + std::optional<std::string> request_str = + (*node_subs)[1].ToCorpus<std::string>(); + + std::unique_ptr<google::protobuf::Message> request = + absl::WrapUnique(google::protobuf::MessageFactory::generated_factory() + ->GetPrototype(method->input_type()) + ->New()); + if (!request->ParseFromString(*request_str)) return std::nullopt; + + return RpcNode(*method, std::move(request)); + } + + absl::flat_hash_set<FieldPath> CollectSatisfiableFields( + const RpcPotentialDfgNode& dfg_node, + const absl::flat_hash_map<const google::protobuf::MethodDescriptor*, + std::vector<RpcNodeID>>& existing_nodes) { + absl::flat_hash_set<FieldPath> satisfied_fields; + for (const auto& [field_path, source_vec] : dfg_node.GetAllDependencies()) { + for (const RpcPotentialDfgNode::PotentialDependencySource& source : + source_vec) { + if (existing_nodes.contains(source.method)) { + satisfied_fields.insert(field_path); + break; + } + } + } + return satisfied_fields; + } + + std::vector<FieldPath> RemoveAlternativeFieldPath( + const absl::flat_hash_set<FieldPath>& field_paths, absl::BitGenRef prng) { + std::vector<FieldPath> result; + std::vector<FieldPath> field_paths_vec(field_paths.begin(), + field_paths.end()); + std::shuffle(field_paths_vec.begin(), field_paths_vec.end(), prng); + for (const FieldPath& field_path : 
field_paths_vec) { + bool found_same_one_of_field = false; + for (const FieldPath& other_field_path : result) { + if (AreOneOfAltearnatives(field_path, other_field_path)) { + found_same_one_of_field = true; + break; + } + } + if (!found_same_one_of_field) { + result.push_back(field_path); + } + } + return result; + } + + // Checks whether all the dependencies of `dfg_node` can be satisfied by the + // nodes in `existing_nodes`. "Satisfy" means there is at least one + // potential dependency of any field in the node that has its source in the + // rpc data flow graph. + bool DependsOnAndCanBeSatisfiedBy( + const RpcPotentialDfgNode& dfg_node, + const absl::flat_hash_map<const google::protobuf::MethodDescriptor*, + std::vector<RpcNodeID>>& existing_nodes) { + if (!dfg_node.HasDependency()) return false; + absl::flat_hash_set<FieldPath> satisfied_fields = + CollectSatisfiableFields(dfg_node, existing_nodes); + // As long as we can satisfy any of the sink field, we consider it as + // satisfiable. + return !satisfied_fields.empty(); + } + + // Randomly selects a potential dependency for each dynamic field in `node` + // and establishes the dependency between the node and the dependency source. + // As we might have multiple calls to the same method in the existing graph + // that can serve as the source, we randomly select one of them. + void ConnectToExistingNodes( + const absl::flat_hash_map<const google::protobuf::MethodDescriptor*, + std::vector<RpcNodeID>> + existing_nodes, + RpcNode& node, absl::BitGenRef prng) { + const RpcPotentialDfgNode& dfg_node = abstract_dfg_.GetNode(node.method()); + + // If we have multiple satisfiable field paths that are alternative to each + // other (e.g., in the same `oneof`), we only need to pick one of them and + // establish the concrete dependency. 
+ std::vector<FieldPath> satisfied_fields = RemoveAlternativeFieldPath( + CollectSatisfiableFields(dfg_node, existing_nodes), prng); + FUZZTEST_INTERNAL_CHECK(!satisfied_fields.empty(), "Impossible"); + for (const FieldPath& sink_field : satisfied_fields) { + const auto& source_vec = dfg_node.GetDependencies(sink_field); + std::vector<const RpcPotentialDfgNode::PotentialDependencySource*> + satisfiable_sources; + for (const RpcPotentialDfgNode::PotentialDependencySource& source : + source_vec) { + if (existing_nodes.contains(source.method)) + satisfiable_sources.push_back(&source); + } + if (satisfiable_sources.empty()) continue; + FUZZTEST_INTERNAL_CHECK(!satisfiable_sources.empty(), + "Some dependencies are not satisified!"); + const RpcPotentialDfgNode::PotentialDependencySource& chosen_source = + *PickRandomElement(satisfiable_sources, prng); + const std::vector<RpcNodeID>& candidate_nodes = + existing_nodes.find(chosen_source.method)->second; + node.AddDependency( + RpcDataFlowEdge{PickRandomElement(candidate_nodes, prng), + chosen_source.field_path, sink_field}); + } + } + + // Inserts a tail call to the graph. Returns true if a call is inserted. + bool InsertTailCall(corpus_type& graph, absl::BitGenRef prng) { + absl::flat_hash_map<const google::protobuf::MethodDescriptor*, std::vector<RpcNodeID>> + existing_nodes; + RpcNodeID max_node_id = 0; + for (const auto& [node_id, node] : graph.GetAllNodes()) { + existing_nodes[&node.method()].push_back(node_id); + max_node_id = std::max(max_node_id, node_id); + } + RpcNodeID new_node_id = max_node_id + 1; + + // Collect all the methods that can be inserted: They depend on existing + // nodes and all their their dependencies can be satisfied by existing + // nodes. 
+ std::vector<const google::protobuf::MethodDescriptor*> candidate_methods; + for (auto& [method, unused_domain] : request_domains_) { + const RpcPotentialDfgNode& dfg_node = abstract_dfg_.GetNode(*method); + if (DependsOnAndCanBeSatisfiedBy(dfg_node, existing_nodes)) { + candidate_methods.push_back(method); + } + } + + if (candidate_methods.empty()) return false; + + const google::protobuf::MethodDescriptor& method_to_insert = + *PickRandomElement(candidate_methods, prng); + auto& request_domain = GetRequestDomain(method_to_insert); + RpcNode node_to_insert(method_to_insert, + request_domain.GetValue(request_domain.Init(prng))); + ConnectToExistingNodes(existing_nodes, node_to_insert, prng); + graph.AddNode(new_node_id, std::move(node_to_insert)); + return true; + } + + bool DeleteTailCall(corpus_type& graph, absl::BitGenRef prng) { + if (graph.NodeNum() == 1) return false; + // Non-tail nodes (nodes with dependents). + absl::flat_hash_set<RpcNodeID> non_tail_nodes; + for (const auto& [node_id, node] : graph.GetAllNodes()) { + for (const auto& edge : node.dependencies()) { + non_tail_nodes.insert(edge.from_node_id); + } + } + FUZZTEST_INTERNAL_CHECK(non_tail_nodes.size() != graph.NodeNum(), + "Loop dependence!"); + std::vector<RpcNodeID> tail_nodes; + for (const auto& [node_id, node] : graph.GetAllNodes()) { + if (!non_tail_nodes.contains(node_id)) tail_nodes.push_back(node_id); + } + + graph.RemoveNode(PickRandomElement(tail_nodes, prng)); + return true; + } + + bool MutateStaticField(corpus_type& graph, absl::BitGenRef prng, + bool only_shrink) { + RpcNode& selected_node = std::next(graph.GetAllNodes().begin(), + ChooseOffset(graph.NodeNum(), prng)) + ->second; + auto& request_domain = GetRequestDomain(selected_node.method()); + auto request_domain_corpus_value = + request_domain.FromValue(selected_node.request()); + FUZZTEST_INTERNAL_CHECK(request_domain_corpus_value.has_value(), + "Invalid proto!"); + request_domain.Mutate(*request_domain_corpus_value, prng, 
only_shrink); + selected_node.request().CopyFrom( + *request_domain.GetValue(*request_domain_corpus_value)); + return true; + } + + ProtobufDomainUntypedImpl<google::protobuf::Message>& GetRequestDomain( + const google::protobuf::MethodDescriptor& method) { + auto iter = request_domains_.find(&method); + FUZZTEST_INTERNAL_CHECK(iter != request_domains_.end(), "Invalid method!"); + return iter->second; + } + + // We keep a domain for the request type for each method in the service to + // generate the request message in the RpcNode. + // TODO(changochen): We are now directly storing the concrete requests in + // RpcNode, and serialize them as string. We should use the corpus_type of the + // proto domain in the future to allow proto customization. + absl::flat_hash_map<const google::protobuf::MethodDescriptor*, + ProtobufDomainUntypedImpl<google::protobuf::Message>> + request_domains_; + std::vector<const google::protobuf::MethodDescriptor*> all_methods_; + RpcPotentialDataFlowGraph abstract_dfg_; +}; +} // namespace internal + +using internal::RpcSequence; + +inline absl::Status ExecuteRpcSequence(RpcStub& stub, RpcSequence& sequence) { + return internal::RpcExecutor(&stub).Execute(sequence); +} + +template <typename ServiceT> +inline auto RpcSession() { + return internal::Lazy<internal::RpcSessionImpl<ServiceT>>(); +} + +// Creates an RpcSession with a service name factory. This is for when the +// service name is unknown at compile time (i.e., getting the name from the +// generated pool.). This factory function will be called after main(). +// If the service name/type is known at compile time, please use the template +// version above: RpcSession<ServiceT>(). 
+inline auto RpcSessionOf( + std::function<absl::string_view()> service_name_factory) { + FUZZTEST_INTERNAL_CHECK(service_name_factory, "Invalid service factory!"); + return internal::Lazy<internal::RpcSessionImpl<>, + std::function<absl::string_view()>>( + std::move(service_name_factory)); +} + +} // namespace fuzztest +#endif // FUZZTEST_RPC_FUZZING_RPC_SESSION_H_
diff --git a/rpc_fuzzing/rpc_session_test.cc b/rpc_fuzzing/rpc_session_test.cc new file mode 100644 index 0000000..3b07361 --- /dev/null +++ b/rpc_fuzzing/rpc_session_test.cc
@@ -0,0 +1,507 @@
+#include "./rpc_fuzzing/rpc_session.h"
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <type_traits>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "google/protobuf/descriptor.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/random/random.h"
+#include "./domain_tests/domain_testing.h"
+#include "./fuzztest/internal/serialization.h"
+#include "./rpc_fuzzing/proto_field_path.h"
+#include "./rpc_fuzzing/rpc_potential_dfg.h"
+#include "./rpc_fuzzing/rpc_sequence.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.pb.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.grpc.pb.h"
+
+namespace fuzztest::internal {
+
+namespace {
+
+using ::testing::AnyOf;
+using ::testing::Conditional;
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::FieldsAre;
+using ::testing::IsTrue;
+using ::testing::NanSensitiveDoubleEq;
+using ::testing::Optional;
+using ::testing::Pair;
+using ::testing::ResultOf;
+using ::testing::UnorderedElementsAre;
+using ::testing::UnorderedElementsAreArray;
+using ::testing::VariantWith;
+
+template <typename T>
+auto ValueIs(const T& v) {
+  if constexpr (std::is_same_v<T, double>) {
+    return FieldsAre(VariantWith<double>(NanSensitiveDoubleEq(v)));
+  } else {
+    return FieldsAre(VariantWith<T>(v));
+  }
+}
+
+template <typename... T>
+auto SubsAre(const T&... v) {
+  return FieldsAre(VariantWith<std::vector<IRObject>>(ElementsAre(v...)));
+}
+
+class RpcDomainTest : public ::testing::Test {
+ public:
+  RpcDomainTest()
+      : pool_(*ABSL_DIE_IF_NULL(google::protobuf::DescriptorPool::generated_pool())),
+        mini_blogger_service_(*ABSL_DIE_IF_NULL(
+            pool_.FindServiceByName("fuzztest.internal.MiniBlogger"))),
+        log_in_user_method_(*ABSL_DIE_IF_NULL(
+            mini_blogger_service_.FindMethodByName("LogInUser"))),
+        log_out_user_method_(*ABSL_DIE_IF_NULL(
+            mini_blogger_service_.FindMethodByName("LogOutUser"))),
+        get_user_posts_method_(*ABSL_DIE_IF_NULL(
+            mini_blogger_service_.FindMethodByName("GetUserPosts"))) {}
+
+ protected:
+  RpcNode GetLogInUserNode() const {
+    return RpcNode(log_in_user_method_, std::make_unique<LogInUserRequest>());
+  }
+
+  RpcNode GetLogOutUserNode() const {
+    return RpcNode(log_out_user_method_, std::make_unique<LogOutUserRequest>());
+  }
+
+  RpcNode GetGetUserPostsNode() const {
+    return RpcNode(get_user_posts_method_,
+                   std::make_unique<GetUserPostsRequest>());
+  }
+
+  RpcDataFlowGraph CreateRandomGraph(RpcSessionImpl<MiniBlogger>& rpc_domain) {
+    absl::BitGen gen;
+    RpcDataFlowGraph graph = rpc_domain.Init(gen);
+    for (int i = 0; i < 100; ++i) {
+      rpc_domain.Mutate(graph, gen, false);
+    }
+    // Generate an order.
+    (void)graph.GetOrderedNodeIds();
+    return graph;
+  }
+
+  const google::protobuf::DescriptorPool& pool_;
+  const google::protobuf::ServiceDescriptor& mini_blogger_service_;
+  const google::protobuf::MethodDescriptor& log_in_user_method_;
+  const google::protobuf::MethodDescriptor& log_out_user_method_;
+  const google::protobuf::MethodDescriptor& get_user_posts_method_;
+};
+
+TEST_F(RpcDomainTest, InitGeneratesARandomSingleCallWithoutDependencies) {
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  absl::BitGen bitgen;
+  for (int i = 0; i < 100; ++i) {
+    RpcDataFlowGraph value = rpc_domain.Init(bitgen);
+    EXPECT_THAT(value.GetAllNodes(),
+                ElementsAre(Pair(0, ResultOf(
+                                        [](const RpcNode& node) {
+                                          return node.dependencies().empty();
+                                        },
+                                        IsTrue()))));
+  }
+}
+
+TEST_F(RpcDomainTest, MutationOnlyInsertsNodesThatDependOnExistingNodes) {
+  constexpr std::string_view kMethodsDependingLogInUser[] = {"GetUserPosts",
+                                                             "LogOutUser"};
+
+  RpcDataFlowGraph graph;
+  graph.AddNode(0, GetLogInUserNode());
+  absl::flat_hash_set<std::string> inserted_methods;
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  absl::BitGen bitgen;
+  for (int j = 0; j < 100; ++j) {
+    RpcDataFlowGraph mutated_graph = graph;
+    rpc_domain.Mutate(mutated_graph, bitgen, false);
+    if (mutated_graph.NodeNum() != 2) continue;
+    // Mutated through insertion.
+    inserted_methods.insert(mutated_graph.GetSequence()[1].method().name());
+  }
+  EXPECT_THAT(inserted_methods,
+              UnorderedElementsAreArray(kMethodsDependingLogInUser));
+}
+
+TEST_F(RpcDomainTest, InsertedNodesHaveDependenciesOnExistingNodes) {
+  constexpr RpcNodeID kFromNodeID = 0;
+  RpcDataFlowGraph graph;
+  graph.AddNode(kFromNodeID, GetLogInUserNode());
+
+  bool insertion_triggered = false;
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  absl::BitGen bitgen;
+  for (int j = 0; j < 100; ++j) {
+    RpcDataFlowGraph mutated_graph = graph;
+    rpc_domain.Mutate(mutated_graph, bitgen, false);
+    if (mutated_graph.NodeNum() != 2) continue;
+    // Mutated through insertion.
+    insertion_triggered = true;
+
+    const RpcNode inserted_node = mutated_graph.GetSequence()[1];
+    EXPECT_THAT(
+        inserted_node.dependencies(),
+        ElementsAre(FieldsAre(
+            kFromNodeID, GetFieldPath<LogInUserResponse>("session_id"),
+            Conditional(inserted_node.method().name() == "LogOutUser",
+                        AnyOf(GetFieldPath<LogOutUserRequest>(
+                                  "log_out_info.session_info.session_id"),
+                              GetFieldPath<LogOutUserRequest>(
+                                  "log_out_info.session_id")),
+                        GetFieldPath<GetUserPostsRequest>("session_id")))));
+  }
+
+  EXPECT_TRUE(insertion_triggered);
+}
+
+TEST_F(RpcDomainTest, InsertedNodeSelectsRandomAlternativeDepWithinSameOneOf) {
+  constexpr RpcNodeID kFromNodeID = 0;
+  RpcDataFlowGraph graph;
+  graph.AddNode(kFromNodeID, GetLogInUserNode());
+
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  absl::flat_hash_set<std::string> sink_fields;
+  absl::BitGen bitgen;
+  for (int j = 0; j < 200; ++j) {
+    RpcDataFlowGraph mutated_graph = graph;
+    rpc_domain.Mutate(mutated_graph, bitgen, false);
+    if (mutated_graph.NodeNum() != 2 ||
+        mutated_graph.GetSequence()[1].method().name() != "LogOutUser")
+      continue;
+    // Either `log_out_info.session_info.session_id` or
+    // `log_out_info.session_id` will be selected as the sink.
+    ASSERT_EQ(mutated_graph.GetSequence()[1].dependencies().size(), 1);
+    sink_fields.insert(
+        mutated_graph.GetSequence()[1].dependencies()[0].to_field.ToString());
+  }
+  EXPECT_THAT(sink_fields,
+              UnorderedElementsAre("log_out_info.session_info.session_id",
+                                   "log_out_info.session_id"));
+}
+
+TEST_F(RpcDomainTest, MutationOnlyDeletesTailNodes) {
+  constexpr std::string_view kMethodsDependingLogInUser[] = {"GetUserPosts",
+                                                             "LogOutUser"};
+  // Node id 0.
+  RpcNode log_in_user_node = GetLogInUserNode();
+  // Node id 1.
+  RpcNode log_out_user_node = GetLogOutUserNode();
+  // Node id 2.
+  RpcNode get_user_posts_node = GetGetUserPostsNode();
+
+  log_out_user_node.AddDependency(RpcDataFlowEdge{
+      0, GetFieldPath<LogInUserResponse>("session_id"),
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")});
+
+  get_user_posts_node.AddDependency(
+      RpcDataFlowEdge{0, GetFieldPath<LogInUserResponse>("session_id"),
+                      GetFieldPath<GetUserPostsRequest>("session_id")});
+
+  RpcDataFlowGraph graph;
+  graph.AddNode(0, log_in_user_node);
+  graph.AddNode(1, log_out_user_node);
+  graph.AddNode(2, get_user_posts_node);
+
+  absl::flat_hash_set<std::string> deleted_methods;
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  absl::BitGen bitgen;
+
+  for (int j = 0; j < 100; ++j) {
+    RpcDataFlowGraph mutated_graph = graph;
+    rpc_domain.Mutate(mutated_graph, bitgen, false);
+    if (mutated_graph.NodeNum() != 2) continue;
+    // Mutated through deletion.
+    if (mutated_graph.GetSequence()[1].method().name() == "LogOutUser") {
+      deleted_methods.insert("GetUserPosts");
+    } else {
+      deleted_methods.insert("LogOutUser");
+    }
+  }
+  EXPECT_THAT(deleted_methods,
+              UnorderedElementsAreArray(kMethodsDependingLogInUser));
+}
+
+TEST_F(RpcDomainTest, OnlyShrinkMutationDecreasesNodeNumOrRequest) {
+  RpcNode log_in_user_node = GetLogInUserNode();
+  auto get_user_posts_request = std::make_unique<GetUserPostsRequest>();
+  get_user_posts_request->set_max_posts(100);
+  RpcNode get_user_posts_node(get_user_posts_method_,
+                              std::move(get_user_posts_request));
+  get_user_posts_node.AddDependency(
+      RpcDataFlowEdge{0, GetFieldPath<LogInUserResponse>("session_id"),
+                      GetFieldPath<GetUserPostsRequest>("session_id")});
+
+  RpcDataFlowGraph graph;
+  graph.AddNode(0, log_in_user_node);
+  graph.AddNode(1, get_user_posts_node);
+
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  absl::BitGen bitgen;
+  bool static_field_shrunk = false;
+  for (int i = 0; i < 100; ++i) {
+    RpcDataFlowGraph mutated_graph = graph;
+    rpc_domain.Mutate(mutated_graph, bitgen, true);
+
+    if (mutated_graph.NodeNum() == 2) {
+      const RpcNode& node = mutated_graph.GetNode(1);
+      int new_max_posts = node.request().GetReflection()->GetInt32(
+          node.request(),
+          node.request().GetDescriptor()->FindFieldByName("max_posts"));
+      EXPECT_LE(new_max_posts, 100);
+      if (new_max_posts < 100) {
+        static_field_shrunk = true;
+      }
+    } else {
+      EXPECT_EQ(mutated_graph.NodeNum(), 1);
+    }
+  }
+  EXPECT_TRUE(static_field_shrunk);
+}
+
+TEST_F(RpcDomainTest, MutateEventuallyChangesStaticFieldInRequest) {
+  RpcDataFlowGraph graph;
+  graph.AddNode(0, GetLogInUserNode());
+
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  absl::BitGen bitgen;
+  bool static_field_changed = false;
+  for (int i = 0; i < 100; ++i) {
+    RpcDataFlowGraph mutated_graph = graph;
+    rpc_domain.Mutate(mutated_graph, bitgen, false);
+    if (mutated_graph.NodeNum() == graph.NodeNum() &&
+        mutated_graph.GetNode(0).method().name() ==
+            graph.GetNode(0).method().name()) {
+      // Mutated through static field mutation.
+      if (!google::protobuf::util::MessageDifferencer::Equals(
+              mutated_graph.GetNode(0).request(), graph.GetNode(0).request())) {
+        static_field_changed = true;
+        break;
+      }
+    }
+  }
+
+  EXPECT_TRUE(static_field_changed);
+}
+
+TEST_F(RpcDomainTest, FromValueTransformsSequenceToRpcGraph) {
+  RpcSequence sequence = {GetLogInUserNode(), GetLogOutUserNode(),
+                          GetGetUserPostsNode()};
+
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  std::optional<RpcDataFlowGraph> graph = rpc_domain.FromValue(sequence);
+
+  ASSERT_TRUE(graph.has_value());
+
+  ASSERT_EQ(graph->NodeNum(), 3);
+  std::vector<RpcNode> all_nodes{graph->GetNode(0), graph->GetNode(1),
+                                 graph->GetNode(2)};
+  EXPECT_EQ(all_nodes, sequence);
+}
+
+TEST_F(RpcDomainTest, ParseCorpusInClearTextFromReturnsRpcDataFlowGraph) {
+  constexpr absl::string_view kObjectText =
+      R"(FUZZTESTv1
+        sub {
+          sub { s: "fuzztest.internal.MiniBlogger.LogInUser" }
+          sub { s: "" }
+          sub { }
+        }
+        sub {
+          sub { s: "fuzztest.internal.MiniBlogger.LogOutUser" }
+          sub { s: "" }
+          sub {
+            sub {
+              sub { i: 0 }
+              sub { s: "session_id" }
+              sub { s: "log_out_info.session_info.session_id" }
+            }
+          }
+        })";
+  std::optional<IRObject> obj = IRObject::FromString(kObjectText);
+  ASSERT_TRUE(obj.has_value());
+
+  RpcNode log_in_user_node = GetLogInUserNode();
+  RpcNode log_out_user_node = GetLogOutUserNode();
+  RpcDataFlowEdge edge = {
+      0, GetFieldPath<LogInUserResponse>("session_id"),
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")};
+  log_out_user_node.AddDependency(edge);
+  RpcDataFlowGraph graph;
+  graph.AddNode(0, log_in_user_node);
+  graph.AddNode(1, log_out_user_node);
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  auto obj2 = rpc_domain.SerializeCorpus(graph);
+  EXPECT_EQ(obj->ToString(), obj2.ToString());
+}
+
+TEST_F(RpcDomainTest, ParseCorpusReturnsRpcDataFlowGraph) {
+  RpcNode log_in_user_node = GetLogInUserNode();
+  RpcNode log_out_user_node = GetLogOutUserNode();
+  RpcDataFlowEdge edge = {
+      0, GetFieldPath<LogInUserResponse>("session_id"),
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")};
+  log_out_user_node.AddDependency(edge);
+
+  /* Set up log_in_user_node */
+  IRObject log_in_user_node_obj;
+  auto& log_in_user_node_subs = log_in_user_node_obj.MutableSubs();
+  log_in_user_node_subs.push_back(
+      IRObject::FromCorpus(log_in_user_node.method().full_name()));
+  log_in_user_node_subs.push_back(
+      IRObject::FromCorpus(log_in_user_node.request().SerializeAsString()));
+  log_in_user_node_subs.push_back(IRObject{});
+
+  /* Set up log_out_user_node*/
+  IRObject log_out_user_node_obj;
+  auto& log_out_user_node_subs = log_out_user_node_obj.MutableSubs();
+  log_out_user_node_subs.push_back(
+      IRObject::FromCorpus(log_out_user_node.method().full_name()));
+  log_out_user_node_subs.push_back(
+      IRObject::FromCorpus(log_out_user_node.request().SerializeAsString()));
+  IRObject edge_obj;
+  auto& edge_subs = edge_obj.MutableSubs();
+  edge_subs.push_back(IRObject::FromCorpus(edge.from_node_id));
+  edge_subs.push_back(IRObject::FromCorpus(edge.from_field.ToString()));
+  edge_subs.push_back(IRObject::FromCorpus(edge.to_field.ToString()));
+  IRObject edges;
+  edges.MutableSubs().push_back(edge_obj);
+  log_out_user_node_subs.push_back(edges);
+
+  IRObject obj;
+  auto& subs = obj.MutableSubs();
+  subs.push_back(log_in_user_node_obj);
+  subs.push_back(log_out_user_node_obj);
+
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  auto graph = rpc_domain.ParseCorpus(obj);
+  ASSERT_TRUE(graph.has_value());
+  EXPECT_EQ(graph->NodeNum(), 2);
+  EXPECT_EQ(graph->GetNode(0), log_in_user_node);
+  EXPECT_EQ(graph->GetNode(1), log_out_user_node);
+}
+
+TEST_F(RpcDomainTest, SerializesCorpusReturnsIRObjectOfSpecificStructure) {
+  RpcNode log_in_user_node = GetLogInUserNode();
+  RpcNode log_out_user_node = GetLogOutUserNode();
+  log_out_user_node.AddDependency(RpcDataFlowEdge{
+      0, GetFieldPath<LogInUserResponse>("session_id"),
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")});
+  RpcDataFlowGraph graph;
+  graph.AddNode(0, log_in_user_node);
+  graph.AddNode(1, log_out_user_node);
+
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  IRObject obj = rpc_domain.SerializeCorpus(graph);
+
+  EXPECT_THAT(
+      obj,
+      SubsAre(
+          /*log_in_user_node*/
+          SubsAre(ValueIs<std::string>(log_in_user_method_.full_name()),
+                  ValueIs<std::string>(
+                      log_in_user_node.request().SerializeAsString()),
+                  ValueIs<std::monostate>({})),
+          /*log_out_user_node*/
+          SubsAre(
+              ValueIs<std::string>(log_out_user_method_.full_name()),
+              ValueIs<std::string>(
+                  log_out_user_node.request().SerializeAsString()),
+              /*dependencies*/
+              SubsAre(SubsAre(
+                  ValueIs<std::uint64_t>(0) /*from_node_id*/,
+                  ValueIs<std::string>("session_id") /*from_field*/,
+                  ValueIs<
+                      std::string>(/*to_field*/
+                                   "log_out_info.session_info.session_id"))))));
+}
+
+TEST_F(RpcDomainTest, SerializesCorpusAndParsesCorpusReturnTheSameObject) {
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  for (int i = 0; i < 100; ++i) {
+    RpcDataFlowGraph graph = CreateRandomGraph(rpc_domain);
+    EXPECT_THAT(rpc_domain.ParseCorpus(rpc_domain.SerializeCorpus(graph)),
+                Optional(Eq(graph)));
+  }
+}
+
+TEST_F(RpcDomainTest, ValidRpcNodeShouldOnlyDependOnPreviousNodes) {
+  RpcNode log_out_user_node = GetLogOutUserNode();
+  log_out_user_node.AddDependency(RpcDataFlowEdge{
+      0, GetFieldPath<LogInUserResponse>("session_id"),
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")});
+
+  RpcDataFlowGraph graph;
+  graph.AddNode(0, log_out_user_node);
+
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  EXPECT_THAT(
+      rpc_domain.ValidateCorpusValue(graph),
+      IsInvalid(
+          "The dependencies should only come from previously executed nodes."));
+}
+
+TEST_F(RpcDomainTest, ValidRpcNodeDependencyMatchesPotentialDependency) {
+  RpcNode log_out_user_node = GetLogOutUserNode();
+  log_out_user_node.AddDependency(RpcDataFlowEdge{
+      0, GetFieldPath<RegisterUserResponse>("success"),
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")});
+
+  RpcDataFlowGraph graph;
+  graph.AddNode(0, GetLogInUserNode());
+  graph.AddNode(1, log_out_user_node);
+
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  EXPECT_THAT(
+      rpc_domain.ValidateCorpusValue(graph),
+      IsInvalid(
+          "The dependency is not defined in the potential data flow graph."));
+}
+
+TEST_F(RpcDomainTest,
+       ValidRpcNodeShouldHaveAtMostOneDependencyForEachDynamicField) {
+  RpcNode log_in_user_node = GetLogInUserNode();
+  RpcNode log_out_user_node = GetLogOutUserNode();
+  log_out_user_node.AddDependency(RpcDataFlowEdge{
+      0, GetFieldPath<LogInUserResponse>("session_id"),
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")});
+  log_out_user_node.AddDependency(RpcDataFlowEdge{
+      0, GetFieldPath<LogInUserResponse>("session_id"),
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")});
+
+  RpcDataFlowGraph graph;
+  graph.AddNode(0, log_in_user_node);
+  graph.AddNode(1, log_out_user_node);
+
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  EXPECT_THAT(
+      rpc_domain.ValidateCorpusValue(graph),
+      IsInvalid("One sink field should have at most one concrete dependency!"));
+}
+
+TEST_F(RpcDomainTest, ValidRpcNodePassValidationTest) {
+  RpcNode log_in_user_node = GetLogInUserNode();
+  RpcNode log_out_user_node = GetLogOutUserNode();
+  log_out_user_node.AddDependency(RpcDataFlowEdge{
+      0, GetFieldPath<LogInUserResponse>("session_id"),
+      GetFieldPath<LogOutUserRequest>("log_out_info.session_info.session_id")});
+
+  RpcDataFlowGraph graph;
+  graph.AddNode(0, log_in_user_node);
+  graph.AddNode(1, log_out_user_node);
+
+  RpcSessionImpl<MiniBlogger> rpc_domain;
+  EXPECT_OK(rpc_domain.ValidateCorpusValue(graph));
+}
+}  // namespace
+
+}  // namespace fuzztest::internal
diff --git a/rpc_fuzzing/rpc_stub.h b/rpc_fuzzing/rpc_stub.h new file mode 100644 index 0000000..3969912 --- /dev/null +++ b/rpc_fuzzing/rpc_stub.h
@@ -0,0 +1,28 @@
+#ifndef FUZZTEST_RPC_FUZZING_RPC_STUB_H_
+#define FUZZTEST_RPC_FUZZING_RPC_STUB_H_
+
+#include <memory>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "absl/status/statusor.h"
+
+namespace fuzztest {
+
+// RPC service stub interface: invokes service methods dynamically, with the
+// method identified by its descriptor and the request passed as a generic
+// `google::protobuf::Message`.
+class RpcStub {
+ public:
+  virtual ~RpcStub() = default;
+  // Calls the RPC method described by `method_descriptor`, sending it `request`
+  // and returning the obtained response. On failure, a non-OK status is
+  // returned instead of a response.
+  virtual absl::StatusOr<std::unique_ptr<google::protobuf::Message>> CallMethod(
+      const google::protobuf::MethodDescriptor& method_descriptor,
+      const google::protobuf::Message& request) = 0;
+};
+
+}  // namespace fuzztest
+
+#endif  // FUZZTEST_RPC_FUZZING_RPC_STUB_H_
diff --git a/rpc_fuzzing/scaffolding_stub.cc b/rpc_fuzzing/scaffolding_stub.cc new file mode 100644 index 0000000..b008bb7 --- /dev/null +++ b/rpc_fuzzing/scaffolding_stub.cc
@@ -0,0 +1,62 @@
+#include "./rpc_fuzzing/scaffolding_stub.h"
+
+#include <memory>
+#include <string>
+
+#include "net/base/sslconstant.h"
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "net/rpc/anonymous-stub.h"
+#include "net/rpc2/rpc2.h"
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+
+namespace fuzztest {
+
+// Sends `request` to the method addressed as "/<Service>.<Method>" through the
+// wrapped `AnonymousStub` and waits for the call to complete.
+absl::StatusOr<std::unique_ptr<google::protobuf::Message>> ScaffoldingStub::CallMethod(
+    const google::protobuf::MethodDescriptor& method_descriptor,
+    const google::protobuf::Message& request) {
+  // The response is instantiated via the generated message factory, so only
+  // output types compiled into this binary are supported.
+  const google::protobuf::Descriptor& output_type =
+      *method_descriptor.output_type();
+  const google::protobuf::Message* prototype =
+      google::protobuf::MessageFactory::generated_factory()->GetPrototype(
+          &output_type);
+  if (prototype == nullptr) {
+    return absl::InternalError(
+        absl::StrCat("Cannot find prototype for ", output_type.full_name(),
+                     " in the generated proto MessageFactory"));
+  }
+  auto response = absl::WrapUnique(prototype->New());
+  if (response == nullptr) {
+    return absl::InternalError(
+        absl::StrCat("Cannot create a new instance of response type ",
+                     output_type.full_name()));
+  }
+
+  const std::string method_name =
+      absl::StrCat("/", method_descriptor.service()->name(), ".",
+                   method_descriptor.name());
+
+  RPC rpc;
+  // AnonymousStub doesn't have access to the method options, so we need to set
+  // this manually.
+  if (method_descriptor.options().has_security_level()) {
+    rpc.set_requested_security_level(net_base::SSLSecurityLevel(
+        method_descriptor.options().security_level()));
+  }
+  stub_->Send(/*package_name=*/"", method_name, &rpc, &request, response.get(),
+              /* done= */ nullptr);
+  rpc.WaitRespectingFiberCancellation();
+
+  absl::Status status = rpc.util_status();
+  if (!status.ok()) return status;
+  return response;
+}
+
+}  // namespace fuzztest
diff --git a/rpc_fuzzing/scaffolding_stub.h b/rpc_fuzzing/scaffolding_stub.h new file mode 100644 index 0000000..807c8d6 --- /dev/null +++ b/rpc_fuzzing/scaffolding_stub.h
@@ -0,0 +1,39 @@
+#ifndef FUZZTEST_RPC_FUZZING_SCAFFOLDING_STUB_H_
+#define FUZZTEST_RPC_FUZZING_SCAFFOLDING_STUB_H_
+
+#include <memory>
+#include <utility>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "net/rpc/anonymous-stub.h"
+#include "net/rpc2/rpc2.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "./rpc_fuzzing/rpc_stub.h"
+
+namespace fuzztest {
+
+// `RpcStub` implementation that forwards every call through an `AnonymousStub`.
+class ScaffoldingStub : public RpcStub {
+ public:
+  // Creates a stub whose channel is `rpc2::CreateClientChannel(socket)`.
+  // NOTE(review): `socket` is presumably a server address; confirm format.
+  explicit ScaffoldingStub(absl::string_view socket)
+      : stub_(std::make_unique<AnonymousStub>(
+            rpc2::CreateClientChannel(socket))) {}
+  // Wraps an existing `stub`, taking ownership of it.
+  explicit ScaffoldingStub(std::unique_ptr<AnonymousStub> stub)
+      : stub_(std::move(stub)) {}
+
+  absl::StatusOr<std::unique_ptr<google::protobuf::Message>> CallMethod(
+      const google::protobuf::MethodDescriptor& method_descriptor,
+      const google::protobuf::Message& request) override;
+
+ private:
+  std::unique_ptr<AnonymousStub> stub_;
+};
+
+}  // namespace fuzztest
+
+#endif  // FUZZTEST_RPC_FUZZING_SCAFFOLDING_STUB_H_
diff --git a/rpc_fuzzing/testdata/BUILD b/rpc_fuzzing/testdata/BUILD new file mode 100644 index 0000000..ba59853 --- /dev/null +++ b/rpc_fuzzing/testdata/BUILD
@@ -0,0 +1,61 @@
+load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library")
+
+package(default_visibility = [
+    "@com_google_fuzztest//e2e_tests/testdata:__pkg__",
+    "@com_google_fuzztest//fuzztest:__subpackages__",
+    "@com_google_fuzztest//rpc_fuzzing:__pkg__",
+])
+
+# Test-only MiniBlogger service definition. It is compatible only when the
+# `rpc_fuzzing_is_enabled` config setting matches; otherwise Bazel treats this
+# target, and every target depending on it, as incompatible and skips it.
+proto_library(
+    name = "mini_blogger",
+    srcs = ["mini_blogger.proto"],
+    target_compatible_with = select({
+        "@com_google_fuzztest//rpc_fuzzing:rpc_fuzzing_is_enabled": [],
+        "//conditions:default": ["@platforms//:incompatible"],
+    }),
+)
+
+cc_proto_library(
+    name = "mini_blogger_cc_proto",
+    deps = [":mini_blogger"],
+)
+
+# gRPC service code only (`grpc_only = True`); the message classes come from
+# :mini_blogger_cc_proto.
+cc_grpc_library(
+    name = "mini_blogger_cc_grpc_proto",
+    srcs = [":mini_blogger"],
+    grpc_only = True,
+    deps = [":mini_blogger_cc_proto"],
+)
+
+# gRPC server-side implementation of the MiniBlogger service.
+cc_library(
+    name = "mini_blogger_grpc_service",
+    srcs = ["grpc/mini_blogger_service.cc"],
+    hdrs = ["grpc/mini_blogger_service.h"],
+    deps = [
+        ":mini_blogger_cc_grpc_proto",
+        ":mini_blogger_cc_proto",
+        ":mini_blogger_context",
+        "@com_github_grpc_grpc//:grpc++",
+    ],
+)
+
+# MiniBlogger request-handling logic, independent of any RPC framework.
+cc_library(
+    name = "mini_blogger_context",
+    srcs = ["mini_blogger_context.cc"],
+    hdrs = ["mini_blogger_context.h"],
+    deps = [
+        ":mini_blogger_cc_proto",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/random",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
+    ],
+)
diff --git a/rpc_fuzzing/testdata/grpc/mini_blogger_server.cc b/rpc_fuzzing/testdata/grpc/mini_blogger_server.cc new file mode 100644 index 0000000..a49bd61 --- /dev/null +++ b/rpc_fuzzing/testdata/grpc/mini_blogger_server.cc
@@ -0,0 +1,43 @@
+// A gRPC server for the MiniBlogger test service that can run locally for
+// testing. It listens on all interfaces at the port given by --port.
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "base/init_google.h"
+#include "absl/base/log_severity.h"
+#include "absl/flags/flag.h"
+#include "absl/log/globals.h"
+#include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
+#include "./rpc_fuzzing/testdata/grpc/mini_blogger_service.h"
+#include "grpcpp/security/server_credentials.h"
+#include "grpcpp/server.h"
+#include "grpcpp/server_builder.h"
+
+ABSL_FLAG(int32_t, port, 5000, "port to listen on");
+
+int main(int argc, char** argv) {
+  InitGoogle(argv[0], &argc, &argv, /*remove_flags=*/true);
+
+  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);
+
+  const std::string server_address =
+      absl::StrCat("[::]:", absl::GetFlag(FLAGS_port));
+
+  fuzztest::internal::MiniBloggerGrpcService mini_blogger;
+  grpc::ServerBuilder builder;
+  builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
+  builder.RegisterService(&mini_blogger);
+  std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
+  // BuildAndStart() returns nullptr on failure, e.g. if the port is taken.
+  if (server == nullptr) {
+    LOG(ERROR) << "Failed to start server on " << server_address;
+    return 1;
+  }
+  LOG(INFO) << "Server listening on " << server_address;
+
+  server->Wait();
+
+  return 0;
+}
diff --git a/rpc_fuzzing/testdata/grpc/mini_blogger_service.cc b/rpc_fuzzing/testdata/grpc/mini_blogger_service.cc new file mode 100644 index 0000000..d6cc705 --- /dev/null +++ b/rpc_fuzzing/testdata/grpc/mini_blogger_service.cc
@@ -0,0 +1,40 @@
+#include "./rpc_fuzzing/testdata/grpc/mini_blogger_service.h"
+
+#include "./rpc_fuzzing/testdata/mini_blogger.pb.h"
+#include "grpcpp/server_context.h"
+#include "grpcpp/support/status.h"
+
+namespace fuzztest::internal {
+
+// Every handler forwards to the service's `MiniBloggerContext`, which fills in
+// `response`; the gRPC-level status returned is always OK.
+
+grpc::Status MiniBloggerGrpcService::RegisterUser(
+    grpc::ServerContext* /*context*/, const RegisterUserRequest* request,
+    RegisterUserResponse* response) {
+  mini_blogger_context_.RegisterUser(*request, *response);
+  return grpc::Status::OK;
+}
+
+grpc::Status MiniBloggerGrpcService::LogInUser(
+    grpc::ServerContext* /*context*/, const LogInUserRequest* request,
+    LogInUserResponse* response) {
+  mini_blogger_context_.LogInUser(*request, *response);
+  return grpc::Status::OK;
+}
+
+grpc::Status MiniBloggerGrpcService::GetUserPosts(
+    grpc::ServerContext* /*context*/, const GetUserPostsRequest* request,
+    GetUserPostsResponse* response) {
+  mini_blogger_context_.GetUserPosts(*request, *response);
+  return grpc::Status::OK;
+}
+
+grpc::Status MiniBloggerGrpcService::LogOutUser(
+    grpc::ServerContext* /*context*/, const LogOutUserRequest* request,
+    LogOutUserResponse* response) {
+  mini_blogger_context_.LogOutUser(*request, *response);
+  return grpc::Status::OK;
+}
+
+}  // namespace fuzztest::internal
diff --git a/rpc_fuzzing/testdata/grpc/mini_blogger_service.h b/rpc_fuzzing/testdata/grpc/mini_blogger_service.h new file mode 100644 index 0000000..d627b37 --- /dev/null +++ b/rpc_fuzzing/testdata/grpc/mini_blogger_service.h
@@ -0,0 +1,38 @@
+#ifndef FUZZTEST_RPC_FUZZING_TESTDATA_GRPC_MINI_BLOGGER_SERVICE_H_
+#define FUZZTEST_RPC_FUZZING_TESTDATA_GRPC_MINI_BLOGGER_SERVICE_H_
+
+#include "./rpc_fuzzing/testdata/mini_blogger.grpc.pb.h"
+#include "./rpc_fuzzing/testdata/mini_blogger.pb.h"
+#include "./rpc_fuzzing/testdata/mini_blogger_context.h"
+#include "grpcpp/server_context.h"
+#include "grpcpp/support/status.h"
+
+namespace fuzztest::internal {
+
+// The gRPC server implementation of MiniBlogger. Each method delegates to an
+// owned `MiniBloggerContext`, which reports the outcome in the response
+// message; the current implementation always returns `grpc::Status::OK`.
+class MiniBloggerGrpcService : public MiniBlogger::Service {
+ public:
+  grpc::Status RegisterUser(grpc::ServerContext*,
+                            const RegisterUserRequest* request,
+                            RegisterUserResponse* response) override;
+
+  grpc::Status LogInUser(grpc::ServerContext*, const LogInUserRequest* request,
+                         LogInUserResponse* response) override;
+
+  grpc::Status GetUserPosts(grpc::ServerContext*,
+                            const GetUserPostsRequest* request,
+                            GetUserPostsResponse* response) override;
+
+  grpc::Status LogOutUser(grpc::ServerContext*,
+                          const LogOutUserRequest* request,
+                          LogOutUserResponse* response) override;
+
+ private:
+  MiniBloggerContext mini_blogger_context_;
+};
+
+}  // namespace fuzztest::internal
+
+#endif  // FUZZTEST_RPC_FUZZING_TESTDATA_GRPC_MINI_BLOGGER_SERVICE_H_
diff --git a/rpc_fuzzing/testdata/mini_blogger.proto b/rpc_fuzzing/testdata/mini_blogger.proto new file mode 100644 index 0000000..50c2da0 --- /dev/null +++ b/rpc_fuzzing/testdata/mini_blogger.proto
@@ -0,0 +1,128 @@
+syntax = "proto3";
+
+package fuzztest.internal;
+
+service MiniBlogger {
+  // Register a user.
+  rpc RegisterUser(RegisterUserRequest) returns (RegisterUserResponse);
+  // Log in a user and get a session id.
+  rpc LogInUser(LogInUserRequest) returns (LogInUserResponse);
+  // Get a user's posts by session id.
+  rpc GetUserPosts(GetUserPostsRequest) returns (GetUserPostsResponse);
+  // Log out the user (invalidating the session id).
+  rpc LogOutUser(LogOutUserRequest) returns (LogOutUserResponse);
+}
+
+message RegisterUserRequest {
+  string user_name = 1;
+  string email = 2;
+  string password = 3;
+}
+
+message RegisterUserResponse {
+  bool success = 1;
+}
+
+message LogInUserRequest {
+  string name = 1;
+  string password = 2;
+}
+
+message LogInUserResponse {
+  bool success = 1;
+  // A unique id for the user's logged-in session.
+  int64 session_id = 2;
+}
+
+enum SortingOrder {
+  ASCENDING = 0;
+  DESCENDING = 1;
+}
+
+enum SortedBy {
+  DATE = 0;
+  POPULARITY = 1;
+}
+
+// For testing only.
+message GetUserPostsOptionsWithDifferentFieldNames {
+  SortingOrder ordering = 1;
+}
+
+message GetUserPostsOptions {
+  SortingOrder order = 1;
+  SortedBy sort_by = 2;
+}
+
+message GetUserPostsRequest {
+  int64 session_id = 1;
+  int32 max_posts = 2;  // Maximum number of posts returned.
+  GetUserPostsOptions options = 3;
+}
+
+message GetUserPostsResponse {
+  repeated string posts = 1;
+}
+
+// For testing dependencies established at inner (nested) fields: the session
+// id can be given either nested in `session_info` or directly as `session_id`.
+message LogOutInfo {
+  oneof log_out_info {
+    SessionInfo session_info = 1;
+    int64 session_id = 2;
+  }
+}
+
+message SessionInfo {
+  int64 session_id = 1;
+}
+
+message LogOutUserRequest {
+  LogOutInfo log_out_info = 1;
+}
+
+message LogOutUserResponse {
+  bool success = 1;
+}
+
+// For testing only: ensures the util that collects all fields of a message
+// doesn't break on recursive messages.
+message RecursiveNode {
+  int64 value = 1;
+  repeated RecursiveNode children = 2;
+}
+
+// For testing only.
+message OneOfMessage {
+  oneof oneof1 {
+    DummyMessage oneof1_field1 = 1;
+    int64 oneof1_field2 = 2;
+  }
+  oneof oneof2 {
+    DummyMessage oneof2_field1 = 3;
+    int64 oneof2_field2 = 4;
+  }
+}
+
+message DummyMessage {
+  int64 v1 = 1;
+  int64 v2 = 2;
+}
+
+message InnerRepeated {
+  int64 field2 = 1;
+}
+
+message MessageWithSingleInnerRepeated {
+  InnerRepeated repeated_msg_field = 1;
+}
+
+message MessageContainingRepeatedFields {
+  repeated int64 field = 1;
+  repeated InnerRepeated repeated_msg_field = 2;
+}
+
+message MessageNotContainingRepeatedFields {
+  int64 field = 1;
+  int64 field2 = 2;
+}
diff --git a/rpc_fuzzing/testdata/mini_blogger_context.cc b/rpc_fuzzing/testdata/mini_blogger_context.cc new file mode 100644 index 0000000..d46863b --- /dev/null +++ b/rpc_fuzzing/testdata/mini_blogger_context.cc
@@ -0,0 +1,102 @@ +#include "./rpc_fuzzing/testdata/mini_blogger_context.h" + +#include <cstdint> +#include <optional> +#include <string> +#include <vector> + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/random/random.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "./rpc_fuzzing/testdata/mini_blogger.pb.h" + +namespace fuzztest::internal { + +void MiniBloggerContext::RegisterUser(const RegisterUserRequest& request, + RegisterUserResponse& response) { + bool result = RegisterUserHandler(request.user_name(), request.email(), + request.password()); + response.set_success(result); +} + +void MiniBloggerContext::LogInUser(const LogInUserRequest& request, + LogInUserResponse& response) { + std::optional<std::int64_t> result = + LoginUserHandler(request.name(), request.password()); + if (!result) { + response.set_success(false); + } else { + response.set_success(true); + response.set_session_id(result.value()); + } +} + +void MiniBloggerContext::GetUserPosts(const GetUserPostsRequest& request, + GetUserPostsResponse& response) { + for (const std::string& post : GetUserPostsHandler(request.session_id())) { + response.add_posts(post); + } +} + +void MiniBloggerContext::LogOutUser(const LogOutUserRequest& request, + LogOutUserResponse& response) { + if (!request.has_log_out_info() || + !request.log_out_info().has_session_info()) { + response.set_success(false); + } else { + bool result = + LogOutUserHandler(request.log_out_info().session_info().session_id()); + response.set_success(result); + } +} +bool MiniBloggerContext::RegisterUserHandler(absl::string_view name, + absl::string_view email, + absl::string_view passwd) { + if (name.empty() || email.empty() || passwd.empty()) return false; + absl::WriterMutexLock l(&lock_); + if (users_.contains(name)) return false; + users_.emplace(std::string(name), + MiniBloggerUser{std::string(name), std::string(email), + 
std::string(passwd)}); + return true; +} + +std::optional<std::int64_t> MiniBloggerContext::LoginUserHandler( + absl::string_view name, absl::string_view passwd) { + absl::WriterMutexLock l(&lock_); + if (!name.empty()) { + if (!users_.contains(name)) return std::nullopt; + if (users_[name].passwd != passwd) return std::nullopt; + } + absl::BitGen gen; + std::int64_t sid = + absl::Uniform<int64_t>(gen, 0, std::numeric_limits<int64_t>::max()); + active_session_ids_.insert(sid); + return sid; +} + +bool MiniBloggerContext::LogOutUserHandler(std::int64_t sid) { + absl::WriterMutexLock l(&lock_); + if (!active_session_ids_.contains(sid)) return false; + inactive_session_ids_.insert(sid); + active_session_ids_.erase(sid); + return true; +} + +std::vector<std::string> MiniBloggerContext::GetUserPostsHandler( + std::int64_t sid) { + std::vector<std::string> posts; + absl::ReaderMutexLock l(&lock_); + if (inactive_session_ids_.contains(sid)) { + std::cerr << "Using an inactive session id!\n"; + std::abort(); + } + if (active_session_ids_.contains(sid)) { + posts.push_back("Random post"); + } + return posts; +} + +} // namespace fuzztest::internal
diff --git a/rpc_fuzzing/testdata/mini_blogger_context.h b/rpc_fuzzing/testdata/mini_blogger_context.h new file mode 100644 index 0000000..7bde0ac --- /dev/null +++ b/rpc_fuzzing/testdata/mini_blogger_context.h
@@ -0,0 +1,59 @@ +#ifndef THIRD_PARTY_GOOGLEFUZZTEST_rpc_fuzzing_RCP_TEST_SERVER_CONTEXT_H_ +#define THIRD_PARTY_GOOGLEFUZZTEST_rpc_fuzzing_RCP_TEST_SERVER_CONTEXT_H_ + +#include <cstdint> +#include <optional> +#include <string> +#include <vector> + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "./rpc_fuzzing/testdata/mini_blogger.pb.h" + +namespace fuzztest::internal { + +struct MiniBloggerUser { + std::string name; + std::string email; + std::string passwd; +}; + +// The real handling logics of MiniBlogger, shared by by gRPC and other +// frameworks. +class MiniBloggerContext { + public: + // Adds a new user into the database. + void RegisterUser(const RegisterUserRequest& request, + RegisterUserResponse& response); + + // Logs in a user if the user info exists in the database and generates a + // session id. + void LogInUser(const LogInUserRequest& request, LogInUserResponse& response); + + // Gets the posts for a user given the session id. + void GetUserPosts(const GetUserPostsRequest& request, + GetUserPostsResponse& response); + + // Logs out a user given the session id. The session id will be invalidated. + void LogOutUser(const LogOutUserRequest& request, + LogOutUserResponse& response); + + private: + bool RegisterUserHandler(absl::string_view name, absl::string_view email, + absl::string_view passwd); + std::optional<std::int64_t> LoginUserHandler(absl::string_view name, + absl::string_view passwd); + bool LogOutUserHandler(std::int64_t sid); + std::vector<std::string> GetUserPostsHandler(std::int64_t sid); + + absl::flat_hash_map<std::string /*name*/, MiniBloggerUser> users_; + absl::flat_hash_set<std::int64_t> active_session_ids_; + absl::flat_hash_set<std::int64_t> inactive_session_ids_; + absl::Mutex lock_; // Serialize operation. 
+}; + +} // namespace fuzztest::internal + +#endif // THIRD_PARTY_GOOGLEFUZZTEST_rpc_fuzzing_RCP_TEST_SERVER_CONTEXT_H_