pw_rpc: Client-side nanopb code generation
This updates the pw_rpc's nanopb codegen to output a service client
class for every service in a proto file, containing member functions to
invoke each of the service's methods.
Change-Id: I6bfbb76982e7a8be03eb2f13d50085b781e880a2
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/19260
Commit-Queue: Alexei Frolov <frolv@google.com>
Reviewed-by: Wyatt Hepler <hepler@google.com>
diff --git a/pw_rpc/nanopb/BUILD.gn b/pw_rpc/nanopb/BUILD.gn
index 532f641..637a078 100644
--- a/pw_rpc/nanopb/BUILD.gn
+++ b/pw_rpc/nanopb/BUILD.gn
@@ -127,9 +127,12 @@
pw_test("codegen_test") {
deps = [
+ ":client",
+ ":internal_test_utils",
":test_method_context",
"..:server",
"..:test_protos.nanopb_rpc",
+ "..:test_utils",
]
sources = [ "codegen_test.cc" ]
enable_if = dir_pw_third_party_nanopb != ""
diff --git a/pw_rpc/nanopb/codegen_test.cc b/pw_rpc/nanopb/codegen_test.cc
index d7e96b9..2ca01ab 100644
--- a/pw_rpc/nanopb/codegen_test.cc
+++ b/pw_rpc/nanopb/codegen_test.cc
@@ -15,6 +15,8 @@
#include "gtest/gtest.h"
#include "pw_rpc/internal/hash.h"
#include "pw_rpc/nanopb_test_method_context.h"
+#include "pw_rpc_nanopb_private/internal_test_utils.h"
+#include "pw_rpc_private/internal_test_utils.h"
#include "pw_rpc_test_protos/test.rpc.pb.h"
namespace pw::rpc {
@@ -50,7 +52,7 @@
EXPECT_STREQ(service.name(), "TestService");
}
-TEST(NanopbCodegen, InvokeUnaryRpc) {
+TEST(NanopbCodegen, Server_InvokeUnaryRpc) {
PW_NANOPB_TEST_METHOD_CONTEXT(test::TestService, TestRpc) context;
EXPECT_EQ(Status::Ok(),
@@ -64,7 +66,7 @@
EXPECT_EQ(1000, context.response().value);
}
-TEST(NanopbCodegen, InvokeStreamingRpc) {
+TEST(NanopbCodegen, Server_InvokeStreamingRpc) {
PW_NANOPB_TEST_METHOD_CONTEXT(test::TestService, TestStreamRpc) context;
context.call({.integer = 0, .status_code = Status::Aborted()});
@@ -86,7 +88,8 @@
EXPECT_EQ(Status::Ok(), context.status());
}
-TEST(NanopbCodegen, InvokeStreamingRpc_ContextKeepsFixedNumberOfResponses) {
+TEST(NanopbCodegen,
+ Server_InvokeStreamingRpc_ContextKeepsFixedNumberOfResponses) {
PW_NANOPB_TEST_METHOD_CONTEXT(test::TestService, TestStreamRpc, 3) context;
ASSERT_EQ(3u, context.responses().max_size());
@@ -101,7 +104,7 @@
EXPECT_EQ(context.responses()[2].number, 4u);
}
-TEST(NanopbCodegen, InvokeStreamingRpc_ManualWriting) {
+TEST(NanopbCodegen, Server_InvokeStreamingRpc_ManualWriting) {
PW_NANOPB_TEST_METHOD_CONTEXT(test::TestService, TestStreamRpc, 3) context;
ASSERT_EQ(3u, context.responses().max_size());
@@ -126,5 +129,62 @@
EXPECT_EQ(context.responses()[2].number, 9u);
}
+using TestServiceClient = test::nanopb::TestServiceClient;
+using internal::TestServerStreamingResponseHandler;
+using internal::TestUnaryResponseHandler;
+
+TEST(NanopbCodegen, Client_InvokesUnaryRpcWithCallback) {
+ constexpr uint32_t service_id = internal::Hash("pw.rpc.test.TestService");
+ constexpr uint32_t method_id = internal::Hash("TestRpc");
+
+ ClientContextForTest<128, 128, 99, service_id, method_id> context;
+ TestUnaryResponseHandler<pw_rpc_test_TestResponse> handler;
+
+ auto call = TestServiceClient::TestRpc(
+ context.channel(), {.integer = 123, .status_code = 0}, handler);
+ EXPECT_EQ(context.output().packet_count(), 1u);
+ auto packet = context.output().sent_packet();
+ EXPECT_EQ(packet.channel_id(), context.channel().id());
+ EXPECT_EQ(packet.service_id(), service_id);
+ EXPECT_EQ(packet.method_id(), method_id);
+ PW_DECODE_PB(pw_rpc_test_TestRequest, sent_proto, packet.payload());
+ EXPECT_EQ(sent_proto.integer, 123);
+
+ PW_ENCODE_PB(pw_rpc_test_TestResponse, response, .value = 42);
+ context.SendResponse(Status::Ok(), response);
+ ASSERT_EQ(handler.responses_received(), 1u);
+ EXPECT_EQ(handler.last_status(), Status::Ok());
+ EXPECT_EQ(handler.last_response().value, 42);
+}
+
+TEST(NanopbCodegen, Client_InvokesServerStreamingRpcWithCallback) {
+ constexpr uint32_t service_id = internal::Hash("pw.rpc.test.TestService");
+ constexpr uint32_t method_id = internal::Hash("TestStreamRpc");
+
+ ClientContextForTest<128, 128, 99, service_id, method_id> context;
+ TestServerStreamingResponseHandler<pw_rpc_test_TestStreamResponse> handler;
+
+ auto call = TestServiceClient::TestStreamRpc(
+ context.channel(), {.integer = 123, .status_code = 0}, handler);
+ EXPECT_EQ(context.output().packet_count(), 1u);
+ auto packet = context.output().sent_packet();
+ EXPECT_EQ(packet.channel_id(), context.channel().id());
+ EXPECT_EQ(packet.service_id(), service_id);
+ EXPECT_EQ(packet.method_id(), method_id);
+ PW_DECODE_PB(pw_rpc_test_TestRequest, sent_proto, packet.payload());
+ EXPECT_EQ(sent_proto.integer, 123);
+
+ PW_ENCODE_PB(
+ pw_rpc_test_TestStreamResponse, response, .chunk = {}, .number = 11u);
+ context.SendResponse(Status::Ok(), response);
+ ASSERT_EQ(handler.responses_received(), 1u);
+ EXPECT_EQ(handler.last_response().number, 11u);
+
+ context.SendPacket(internal::PacketType::SERVER_STREAM_END,
+ Status::NotFound());
+ EXPECT_FALSE(handler.active());
+ EXPECT_EQ(handler.status(), Status::NotFound());
+}
+
} // namespace
} // namespace pw::rpc
diff --git a/pw_rpc/nanopb/nanopb_client_call_test.cc b/pw_rpc/nanopb/nanopb_client_call_test.cc
index f8374b7..a39a510 100644
--- a/pw_rpc/nanopb/nanopb_client_call_test.cc
+++ b/pw_rpc/nanopb/nanopb_client_call_test.cc
@@ -59,52 +59,8 @@
}
};
-// TODO(frolv): Maybe extract these into a utils header as it could be useful
-// for other tests.
-template <typename Response>
-class TestUnaryResponseHandler : public UnaryResponseHandler<Response> {
- public:
- void ReceivedResponse(Status status, const Response& response) override {
- last_status_ = status;
- last_response_ = response;
- ++responses_received_;
- }
-
- constexpr Status last_status() const { return last_status_; }
- constexpr const Response& last_response() const& { return last_response_; }
- constexpr size_t responses_received() const { return responses_received_; }
-
- private:
- Status last_status_;
- Response last_response_;
- size_t responses_received_ = 0;
-};
-
-template <typename Response>
-class TestServerStreamingResponseHandler
- : public ServerStreamingResponseHandler<Response> {
- public:
- void ReceivedResponse(const Response& response) override {
- last_response_ = response;
- ++responses_received_;
- }
-
- void Complete(Status status) override {
- active_ = false;
- status_ = status;
- }
-
- constexpr bool active() const { return active_; }
- constexpr Status status() const { return status_; }
- constexpr const Response& last_response() const& { return last_response_; }
- constexpr size_t responses_received() const { return responses_received_; }
-
- private:
- Status status_;
- Response last_response_;
- size_t responses_received_ = 0;
- bool active_ = true;
-};
+using internal::TestServerStreamingResponseHandler;
+using internal::TestUnaryResponseHandler;
TEST(NanopbClientCall, Unary_SendsRequestPacket) {
ClientContextForTest context;
diff --git a/pw_rpc/nanopb/pw_rpc_nanopb_private/internal_test_utils.h b/pw_rpc/nanopb/pw_rpc_nanopb_private/internal_test_utils.h
index 5312463..391b9c3 100644
--- a/pw_rpc/nanopb/pw_rpc_nanopb_private/internal_test_utils.h
+++ b/pw_rpc/nanopb/pw_rpc_nanopb_private/internal_test_utils.h
@@ -17,6 +17,7 @@
#include "pb_decode.h"
#include "pb_encode.h"
+#include "pw_rpc/nanopb_client_call.h"
namespace pw::rpc::internal {
@@ -59,4 +60,53 @@
EXPECT_TRUE(pb_decode(&input, fields, &protobuf));
}
+// Client response handler for a unary RPC invocation which captures the
+// response it receives.
+template <typename Response>
+class TestUnaryResponseHandler : public UnaryResponseHandler<Response> {
+ public:
+ void ReceivedResponse(Status status, const Response& response) override {
+ last_status_ = status;
+ last_response_ = response;
+ ++responses_received_;
+ }
+
+ constexpr Status last_status() const { return last_status_; }
+ constexpr const Response& last_response() const& { return last_response_; }
+ constexpr size_t responses_received() const { return responses_received_; }
+
+ private:
+ Status last_status_;
+ Response last_response_;
+ size_t responses_received_ = 0;
+};
+
+// Client response handler for a unary RPC invocation which stores information
+// about the state of the stream.
+template <typename Response>
+class TestServerStreamingResponseHandler
+ : public ServerStreamingResponseHandler<Response> {
+ public:
+ void ReceivedResponse(const Response& response) override {
+ last_response_ = response;
+ ++responses_received_;
+ }
+
+ void Complete(Status status) override {
+ active_ = false;
+ status_ = status;
+ }
+
+ constexpr bool active() const { return active_; }
+ constexpr Status status() const { return status_; }
+ constexpr const Response& last_response() const& { return last_response_; }
+ constexpr size_t responses_received() const { return responses_received_; }
+
+ private:
+ Status status_;
+ Response last_response_;
+ size_t responses_received_ = 0;
+ bool active_ = true;
+};
+
} // namespace pw::rpc::internal
diff --git a/pw_rpc/py/pw_rpc/codegen_nanopb.py b/pw_rpc/py/pw_rpc/codegen_nanopb.py
index 1d986f6..433829a 100644
--- a/pw_rpc/py/pw_rpc/codegen_nanopb.py
+++ b/pw_rpc/py/pw_rpc/codegen_nanopb.py
@@ -87,6 +87,8 @@
output: OutputFile) -> None:
"""Generates a C++ derived class for a nanopb RPC service."""
+ output.write_line('namespace generated {')
+
base_class = f'{RPC_NAMESPACE}::Service'
output.write_line('\ntemplate <typename Implementation>')
output.write_line(
@@ -151,6 +153,82 @@
output.write_line('};')
+ output.write_line('\n} // namespace generated\n')
+
+
+def _generate_code_for_client_method(method: ProtoServiceMethod,
+ output: OutputFile) -> None:
+ """Outputs client code for a single RPC method."""
+
+ req = method.request_type().nanopb_name()
+ res = method.response_type().nanopb_name()
+ method_id = pw_rpc.ids.calculate(method.name())
+
+ if method.type() == ProtoServiceMethod.Type.UNARY:
+ callback = f'{RPC_NAMESPACE}::UnaryResponseHandler<{res}>'
+ elif method.type() == ProtoServiceMethod.Type.SERVER_STREAMING:
+ callback = f'{RPC_NAMESPACE}::ServerStreamingResponseHandler<{res}>'
+ else:
+ raise NotImplementedError(
+ 'Only unary and server streaming RPCs are currently supported')
+
+ output.write_line()
+ output.write_line(f'static NanopbClientCall<\n {callback}>')
+ output.write_line(f'{method.name()}({RPC_NAMESPACE}::Channel& channel,')
+ with output.indent(len(method.name()) + 1):
+ output.write_line(f'const {req}& request,')
+ output.write_line(f'{callback}& callback) {{')
+
+ with output.indent():
+ output.write_line(f'NanopbClientCall<{callback}>')
+ output.write_line(' call(&channel,')
+ with output.indent(9):
+ output.write_line('kServiceId,')
+ output.write_line(
+ f'0x{method_id:08x}, // Hash of "{method.name()}"')
+ output.write_line('callback,')
+ output.write_line(f'{req}_fields,')
+ output.write_line(f'{res}_fields);')
+ output.write_line('call.SendRequest(&request);')
+ output.write_line('return call;')
+
+ output.write_line('}')
+
+
+def _generate_code_for_client(service: ProtoService, root: ProtoNode,
+ output: OutputFile) -> None:
+ """Outputs client code for an RPC service."""
+
+ output.write_line('namespace nanopb {')
+
+ class_name = f'{service.cpp_namespace(root)}Client'
+ output.write_line(f'\nclass {class_name} {{')
+ output.write_line(' public:')
+
+ with output.indent():
+ output.write_line('template <typename T>')
+ output.write_line(
+ f'using NanopbClientCall = {RPC_NAMESPACE}::NanopbClientCall<T>;')
+
+ output.write_line('')
+ output.write_line(f'{class_name}() = delete;')
+
+ for method in service.methods():
+ _generate_code_for_client_method(method, output)
+
+ service_name_hash = pw_rpc.ids.calculate(service.proto_path())
+ output.write_line('\n private:')
+
+ with output.indent():
+ output.write_line(f'// Hash of "{service.proto_path()}".')
+ output.write_line(
+ f'static constexpr uint32_t kServiceId = 0x{service_name_hash:08x};'
+ )
+
+ output.write_line('};')
+
+ output.write_line('\n} // namespace nanopb\n')
+
def generate_code_for_package(file_descriptor_proto, package: ProtoNode,
output: OutputFile) -> None:
@@ -168,6 +246,7 @@
output.write_line('#include <cstdint>')
output.write_line('#include <type_traits>\n')
output.write_line('#include "pw_rpc/internal/nanopb_method_union.h"')
+ output.write_line('#include "pw_rpc/nanopb_client_call.h"')
output.write_line('#include "pw_rpc/server_context.h"')
output.write_line('#include "pw_rpc/service.h"')
@@ -190,14 +269,12 @@
output.write_line(f'namespace {file_namespace} {{')
- output.write_line('namespace generated {')
-
for node in package:
if node.type() == ProtoNode.Type.SERVICE:
_generate_code_for_service(cast(ProtoService, node), package,
output)
-
- output.write_line('\n} // namespace generated')
+ _generate_code_for_client(cast(ProtoService, node), package,
+ output)
if package.cpp_namespace():
output.write_line(f'}} // namespace {file_namespace}')