pw_rpc: Support testing services with mixed protobuf libraries
- Replace the public internal method lookup generated functions with a
private array of method IDs. Create a MethodLookup class that is
friended by generated services to access this array.
- Use the generated method ID array to support method lookup in services
with multiple different protobuf implementations (Nanopb and raw).
- Ensure that the correct method implementation is selected when looking
up a method at compile time. This prevents accidentally accessing a
Nanopb method as a raw method or vice versa in a test method context.
Change-Id: I0ea334c02ae7c29e50099aa5bcfa975df625d15e
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/25740
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Keir Mierle <keir@google.com>
diff --git a/pw_rpc/BUILD b/pw_rpc/BUILD
index fe34d2d..cdd1084 100644
--- a/pw_rpc/BUILD
+++ b/pw_rpc/BUILD
@@ -45,6 +45,7 @@
"public/pw_rpc/internal/call.h",
"public/pw_rpc/internal/hash.h",
"public/pw_rpc/internal/method.h",
+ "public/pw_rpc/internal/method_lookup.h",
"public/pw_rpc/internal/method_union.h",
"public/pw_rpc/internal/server.h",
"server.cc",
@@ -103,6 +104,7 @@
srcs = [
"nanopb/codegen_test.cc",
"nanopb/echo_service_test.cc",
+ "nanopb/method_lookup_test.cc",
"nanopb/nanopb_client_call.cc",
"nanopb/nanopb_client_call_test.cc",
"nanopb/nanopb_common.cc",
@@ -116,9 +118,6 @@
"nanopb/public/pw_rpc/nanopb_client_call.h",
"nanopb/public/pw_rpc/nanopb_test_method_context.h",
"nanopb/pw_rpc_nanopb_private/internal_test_utils.h",
- "nanopb/test.pb.c",
- "nanopb/test.pb.h",
- "nanopb/test_rpc.pb.h",
],
)
diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn
index 1dee6ca..4b6d0ea 100644
--- a/pw_rpc/BUILD.gn
+++ b/pw_rpc/BUILD.gn
@@ -41,6 +41,7 @@
"public/pw_rpc/internal/call.h",
"public/pw_rpc/internal/hash.h",
"public/pw_rpc/internal/method.h",
+ "public/pw_rpc/internal/method_lookup.h",
"public/pw_rpc/internal/method_union.h",
"public/pw_rpc/internal/server.h",
"server.cc",
diff --git a/pw_rpc/nanopb/BUILD.gn b/pw_rpc/nanopb/BUILD.gn
index b3ee46a..baf2805 100644
--- a/pw_rpc/nanopb/BUILD.gn
+++ b/pw_rpc/nanopb/BUILD.gn
@@ -98,6 +98,7 @@
":client_call_test",
":codegen_test",
":echo_service_test",
+ ":method_lookup_test",
":nanopb_method_test",
":nanopb_method_union_test",
]
@@ -139,6 +140,18 @@
enable_if = dir_pw_third_party_nanopb != ""
}
+pw_test("method_lookup_test") {
+ deps = [
+ ":method",
+ ":test_method_context",
+ "..:test_protos.nanopb_rpc",
+ "..:test_utils",
+ "../raw:test_method_context",
+ ]
+ sources = [ "method_lookup_test.cc" ]
+ enable_if = dir_pw_third_party_nanopb != ""
+}
+
pw_test("nanopb_method_union_test") {
deps = [
":internal_test_utils",
diff --git a/pw_rpc/nanopb/method_lookup_test.cc b/pw_rpc/nanopb/method_lookup_test.cc
new file mode 100644
index 0000000..02137c8
--- /dev/null
+++ b/pw_rpc/nanopb/method_lookup_test.cc
@@ -0,0 +1,81 @@
+// Copyright 2020 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+
+#include "gtest/gtest.h"
+#include "pw_rpc/nanopb_test_method_context.h"
+#include "pw_rpc/raw_test_method_context.h"
+#include "pw_rpc_test_protos/test.rpc.pb.h"
+
+namespace pw::rpc {
+namespace {
+
+class MixedService1 : public test::generated::TestService<MixedService1> {
+ public:
+ StatusWithSize TestRpc(ServerContext&, ConstByteSpan, ByteSpan) {
+ return StatusWithSize(123);
+ }
+
+ void TestStreamRpc(ServerContext&,
+ const pw_rpc_test_TestRequest&,
+ ServerWriter<pw_rpc_test_TestStreamResponse>&) {
+ called_streaming_method = true;
+ }
+
+ bool called_streaming_method = false;
+};
+
+class MixedService2 : public test::generated::TestService<MixedService2> {
+ public:
+ Status TestRpc(ServerContext&,
+ const pw_rpc_test_TestRequest&,
+ pw_rpc_test_TestResponse&) {
+ return Status::Unauthenticated();
+ }
+
+ void TestStreamRpc(ServerContext&, ConstByteSpan, RawServerWriter&) {
+ called_streaming_method = true;
+ }
+
+ bool called_streaming_method = false;
+};
+
+TEST(MixedService1, CallRawMethod) {
+ PW_RAW_TEST_METHOD_CONTEXT(MixedService1, TestRpc) context;
+ StatusWithSize sws = context.call({});
+ EXPECT_TRUE(sws.ok());
+ EXPECT_EQ(123u, sws.size());
+}
+
+TEST(MixedService1, CallNanopbMethod) {
+ PW_NANOPB_TEST_METHOD_CONTEXT(MixedService1, TestStreamRpc) context;
+ ASSERT_FALSE(context.service().called_streaming_method);
+ context.call({});
+ EXPECT_TRUE(context.service().called_streaming_method);
+}
+
+TEST(MixedService2, CallNanopbMethod) {
+ PW_NANOPB_TEST_METHOD_CONTEXT(MixedService2, TestRpc) context;
+ Status status = context.call({});
+ EXPECT_EQ(Status::Unauthenticated(), status);
+}
+
+TEST(MixedService2, CallRawMethod) {
+ PW_RAW_TEST_METHOD_CONTEXT(MixedService2, TestStreamRpc) context;
+ ASSERT_FALSE(context.service().called_streaming_method);
+ context.call({});
+ EXPECT_TRUE(context.service().called_streaming_method);
+}
+
+} // namespace
+} // namespace pw::rpc
diff --git a/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h b/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h
index 0ee76fe..08572e6 100644
--- a/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h
+++ b/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h
@@ -21,6 +21,7 @@
#include "pw_preprocessor/arguments.h"
#include "pw_rpc/channel.h"
#include "pw_rpc/internal/hash.h"
+#include "pw_rpc/internal/method_lookup.h"
#include "pw_rpc/internal/nanopb_method.h"
#include "pw_rpc/internal/packet.h"
#include "pw_rpc/internal/server.h"
@@ -121,15 +122,6 @@
Status last_status_;
};
-template <typename Service, uint32_t method_id>
-constexpr const internal::NanopbMethod& GetNanopbMethod() {
- constexpr const internal::NanopbMethod* nanopb_method =
- GeneratedService<Service>::NanopbMethodFor(method_id);
- static_assert(nanopb_method != nullptr,
- "The selected function is not an RPC service method");
- return *nanopb_method;
-}
-
// Collects everything needed to invoke a particular RPC.
template <typename Service,
auto method,
@@ -142,14 +134,16 @@
template <typename... Args>
InvocationContext(Args&&... args)
- : output(GetNanopbMethod<Service, method_id>(), responses, buffer),
+ : output(MethodLookup::GetNanopbMethod<Service, method_id>(),
+ responses,
+ buffer),
channel(Channel::Create<123>(&output)),
server(std::span(&channel, 1)),
service(std::forward<Args>(args)...),
call(static_cast<internal::Server&>(server),
static_cast<internal::Channel&>(channel),
service,
- GetNanopbMethod<Service, method_id>()) {}
+ MethodLookup::GetNanopbMethod<Service, method_id>()) {}
MessageOutput<Response> output;
@@ -176,6 +170,8 @@
template <typename... Args>
UnaryContext(Args&&... args) : ctx_(std::forward<Args>(args)...) {}
+ Service& service() { return ctx_.service; }
+
// Invokes the RPC with the provided request. Returns the status.
Status call(const Request& request) {
ctx_.output.clear();
@@ -210,6 +206,8 @@
template <typename... Args>
ServerStreamingContext(Args&&... args) : ctx_(std::forward<Args>(args)...) {}
+ Service& service() { return ctx_.service; }
+
// Invokes the RPC with the provided request.
void call(const Request& request) {
ctx_.output.clear();
diff --git a/pw_rpc/public/pw_rpc/internal/method_lookup.h b/pw_rpc/public/pw_rpc/internal/method_lookup.h
new file mode 100644
index 0000000..cd7ddcb
--- /dev/null
+++ b/pw_rpc/public/pw_rpc/internal/method_lookup.h
@@ -0,0 +1,64 @@
+// Copyright 2020 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+#pragma once
+
+#include "pw_rpc/internal/method.h"
+
+namespace pw::rpc::internal {
+
+// Gets a Method object from a generated RPC service class. Getter functions are
+// provided for each supported method implementation. The
+//
+// To ensure the MethodUnion actually holds the requested method type, the
+// method ID is accessed in a static_assert. It is invalid to access an unset
+// union member in a constant expression, so this results in a compiler error.
+class MethodLookup {
+ public:
+ template <typename Service, uint32_t method_id>
+ static constexpr const auto& GetRawMethod() {
+ const auto& method = GetMethodUnion<Service, method_id>().raw_method();
+ static_assert(method.id() == method_id, "Incorrect method implementation");
+ return method;
+ }
+
+ template <typename Service, uint32_t method_id>
+ static constexpr const auto& GetNanopbMethod() {
+ const auto& method = GetMethodUnion<Service, method_id>().nanopb_method();
+ static_assert(method.id() == method_id, "Incorrect method implementation");
+ return method;
+ }
+
+ private:
+ template <typename Service, uint32_t method_id>
+ static constexpr const auto& GetMethodUnion() {
+ constexpr auto method = GetMethodUnionPointer<Service>(method_id);
+ static_assert(method != nullptr,
+ "The selected function is not an RPC service method");
+ return *method;
+ }
+
+ template <typename Service>
+ static constexpr
+ typename decltype(GeneratedService<Service>::kMethods)::const_pointer
+ GetMethodUnionPointer(uint32_t method_id) {
+ for (size_t i = 0; i < GeneratedService<Service>::kMethodIds.size(); ++i) {
+ if (GeneratedService<Service>::kMethodIds[i] == method_id) {
+ return &GeneratedService<Service>::kMethods[i];
+ }
+ }
+ return nullptr;
+ }
+};
+
+} // namespace pw::rpc::internal
diff --git a/pw_rpc/py/BUILD.gn b/pw_rpc/py/BUILD.gn
index 56ca19d..03f7b29 100644
--- a/pw_rpc/py/BUILD.gn
+++ b/pw_rpc/py/BUILD.gn
@@ -22,6 +22,7 @@
"pw_rpc/__init__.py",
"pw_rpc/callback_client.py",
"pw_rpc/client.py",
+ "pw_rpc/codegen.py",
"pw_rpc/codegen_nanopb.py",
"pw_rpc/codegen_raw.py",
"pw_rpc/descriptors.py",
diff --git a/pw_rpc/py/pw_rpc/codegen.py b/pw_rpc/py/pw_rpc/codegen.py
new file mode 100644
index 0000000..7dd5451
--- /dev/null
+++ b/pw_rpc/py/pw_rpc/codegen.py
@@ -0,0 +1,33 @@
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Common RPC codegen utilities."""
+
+from pw_protobuf.output_file import OutputFile
+from pw_protobuf.proto_tree import ProtoService
+
+import pw_rpc.ids
+
+
+def method_lookup_table(service: ProtoService, output: OutputFile) -> None:
+ """Generates array of method IDs for looking up methods at compile time."""
+ output.write_line('static constexpr std::array<uint32_t, '
+ f'{len(service.methods())}> kMethodIds = {{')
+
+ with output.indent(4):
+ for method in service.methods():
+ method_id = pw_rpc.ids.calculate(method.name())
+ output.write_line(
+ f'0x{method_id:08x}, // Hash of {method.name()}')
+
+ output.write_line('};\n')
diff --git a/pw_rpc/py/pw_rpc/codegen_nanopb.py b/pw_rpc/py/pw_rpc/codegen_nanopb.py
index 0e33498..2d0bd7d 100644
--- a/pw_rpc/py/pw_rpc/codegen_nanopb.py
+++ b/pw_rpc/py/pw_rpc/codegen_nanopb.py
@@ -21,6 +21,7 @@
from pw_protobuf.proto_tree import ProtoNode, ProtoService, ProtoServiceMethod
from pw_protobuf.proto_tree import build_node_tree
import pw_rpc.ids
+from pw_rpc import codegen
PLUGIN_NAME = 'pw_rpc_codegen'
PLUGIN_VERSION = '0.1.0'
@@ -61,29 +62,6 @@
output.write_line(f'{res_fields}),')
-def _generate_method_lookup_function(output: OutputFile):
- """Generates a function that gets a Method object from its ID."""
- nanopb_method = f'{RPC_NAMESPACE}::internal::NanopbMethod'
-
- output.write_line(
- f'static constexpr const {nanopb_method}* NanopbMethodFor(')
- output.write_line(' uint32_t id) {')
-
- with output.indent():
- output.write_line('for (auto& method : kMethods) {')
- with output.indent():
- output.write_line('if (method.nanopb_method().id() == id) {')
- output.write_line(
- f' return &static_cast<const {nanopb_method}&>(')
- output.write_line(' method.nanopb_method());')
- output.write_line('}')
- output.write_line('}')
-
- output.write_line('return nullptr;')
-
- output.write_line('}')
-
-
def _generate_code_for_service(service: ProtoService, root: ProtoNode,
output: OutputFile) -> None:
"""Generates a C++ derived class for a nanopb RPC service."""
@@ -123,13 +101,11 @@
output.write_line(
'constexpr void _PwRpcInternalGeneratedBase() const {}')
- output.write_line()
- _generate_method_lookup_function(output)
-
service_name_hash = pw_rpc.ids.calculate(service.proto_path())
output.write_line('\n private:')
with output.indent():
+ output.write_line('friend class ::pw::rpc::internal::MethodLookup;\n')
output.write_line(f'// Hash of "{service.proto_path()}".')
output.write_line(
f'static constexpr uint32_t kServiceId = 0x{service_name_hash:08x};'
@@ -146,7 +122,10 @@
for method in service.methods():
_generate_method_descriptor(method, output)
- output.write_line('};')
+ output.write_line('};\n')
+
+ # Generate the method lookup table
+ codegen.method_lookup_table(service, output)
output.write_line('};')
@@ -242,6 +221,7 @@
output.write_line('#include <cstddef>')
output.write_line('#include <cstdint>')
output.write_line('#include <type_traits>\n')
+ output.write_line('#include "pw_rpc/internal/method_lookup.h"')
output.write_line('#include "pw_rpc/internal/nanopb_method_union.h"')
output.write_line('#include "pw_rpc/nanopb_client_call.h"')
output.write_line('#include "pw_rpc/server_context.h"')
diff --git a/pw_rpc/py/pw_rpc/codegen_raw.py b/pw_rpc/py/pw_rpc/codegen_raw.py
index 6a271cf..9b25bfe 100644
--- a/pw_rpc/py/pw_rpc/codegen_raw.py
+++ b/pw_rpc/py/pw_rpc/codegen_raw.py
@@ -21,6 +21,7 @@
from pw_protobuf.proto_tree import ProtoNode, ProtoService, ProtoServiceMethod
from pw_protobuf.proto_tree import build_node_tree
import pw_rpc.ids
+from pw_rpc import codegen
PLUGIN_NAME = 'pw_rpc_codegen'
PLUGIN_VERSION = '0.1.0'
@@ -49,27 +50,6 @@
output.write_line(f' 0x{method_id:08x}), // Hash of "{method.name()}"')
-def _generate_method_lookup_function(output: OutputFile):
- """Generates a function that gets a Method object from its ID."""
- raw_method = f'{RPC_NAMESPACE}::internal::RawMethod'
-
- output.write_line(f'static constexpr const {raw_method}* RawMethodFor(')
- output.write_line(' uint32_t id) {')
-
- with output.indent():
- output.write_line('for (auto& method : kMethods) {')
- with output.indent():
- output.write_line('if (method.raw_method().id() == id) {')
- output.write_line(f' return &static_cast<const {raw_method}&>(')
- output.write_line(' method.raw_method());')
- output.write_line('}')
- output.write_line('}')
-
- output.write_line('return nullptr;')
-
- output.write_line('}')
-
-
def _generate_code_for_service(service: ProtoService, root: ProtoNode,
output: OutputFile) -> None:
"""Generates a C++ base class for a raw RPC service."""
@@ -106,13 +86,11 @@
output.write_line(
'constexpr void _PwRpcInternalGeneratedBase() const {}')
- output.write_line()
- _generate_method_lookup_function(output)
-
service_name_hash = pw_rpc.ids.calculate(service.proto_path())
output.write_line('\n private:')
with output.indent():
+ output.write_line('friend class ::pw::rpc::internal::MethodLookup;\n')
output.write_line(f'// Hash of "{service.proto_path()}".')
output.write_line(
f'static constexpr uint32_t kServiceId = 0x{service_name_hash:08x};'
@@ -131,6 +109,9 @@
output.write_line('};')
+ # Generate the method lookup table
+ codegen.method_lookup_table(service, output)
+
output.write_line('};')
@@ -147,6 +128,7 @@
output.write_line('#include <cstddef>')
output.write_line('#include <cstdint>')
output.write_line('#include <type_traits>\n')
+ output.write_line('#include "pw_rpc/internal/method_lookup.h"')
output.write_line('#include "pw_rpc/internal/raw_method_union.h"')
output.write_line('#include "pw_rpc/server_context.h"')
output.write_line('#include "pw_rpc/service.h"\n')
diff --git a/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h b/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h
index 371e0f5..5e13a95 100644
--- a/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h
+++ b/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h
@@ -20,6 +20,7 @@
#include "pw_containers/vector.h"
#include "pw_rpc/channel.h"
#include "pw_rpc/internal/hash.h"
+#include "pw_rpc/internal/method_lookup.h"
#include "pw_rpc/internal/packet.h"
#include "pw_rpc/internal/raw_method.h"
#include "pw_rpc/internal/server.h"
@@ -131,15 +132,6 @@
Status last_status_;
};
-template <typename Service, uint32_t method_id>
-constexpr const internal::RawMethod& GetRawMethod() {
- constexpr const internal::RawMethod* raw_method =
- GeneratedService<Service>::RawMethodFor(method_id);
- static_assert(raw_method != nullptr,
- "The selected function is not an RPC service method");
- return *raw_method;
-}
-
// Collects everything needed to invoke a particular RPC.
template <typename Service,
uint32_t method_id,
@@ -155,7 +147,7 @@
call(static_cast<internal::Server&>(server),
static_cast<internal::Channel&>(channel),
service,
- GetRawMethod<Service, method_id>()) {}
+ MethodLookup::GetRawMethod<Service, method_id>()) {}
using ResponseBuffer = std::array<std::byte, output_size>;