pw_rpc: Update C++ client API
This updates the generated service client API to work with an RPC client
and channel ID instead of a channel itself. This will allow removing the
client reference from the channel object.
As part of this, the generated client classes are made instantiable to
improve usability. A client and channel ID must be passed in when
creating a service client; beyond this, they are not needed to invoke
methods.
Change-Id: I0cb29169ccae648b75f6e1d523cca5fe49bd00a1
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/60061
Commit-Queue: Alexei Frolov <frolv@google.com>
Reviewed-by: Wyatt Hepler <hepler@google.com>
diff --git a/pw_rpc/BUILD.bazel b/pw_rpc/BUILD.bazel
index 69525da..13a76b0 100644
--- a/pw_rpc/BUILD.bazel
+++ b/pw_rpc/BUILD.bazel
@@ -36,6 +36,7 @@
deps = [
":common",
"//pw_containers:intrusive_list",
+ "//pw_result",
],
)
diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn
index d0a6346..12ff950 100644
--- a/pw_rpc/BUILD.gn
+++ b/pw_rpc/BUILD.gn
@@ -75,7 +75,10 @@
pw_source_set("client") {
public_configs = [ ":public_include_path" ]
- public_deps = [ ":common" ]
+ public_deps = [
+ ":common",
+ dir_pw_result,
+ ]
deps = [ dir_pw_log ]
public = [
"public/pw_rpc/client.h",
diff --git a/pw_rpc/CMakeLists.txt b/pw_rpc/CMakeLists.txt
index 9e10afd..2a164ca 100644
--- a/pw_rpc/CMakeLists.txt
+++ b/pw_rpc/CMakeLists.txt
@@ -40,6 +40,7 @@
client.cc
PUBLIC_DEPS
pw_rpc.common
+ pw_result
PRIVATE_DEPS
pw_log
)
diff --git a/pw_rpc/base_client_call.cc b/pw_rpc/base_client_call.cc
index 103ae03..4c654e9 100644
--- a/pw_rpc/base_client_call.cc
+++ b/pw_rpc/base_client_call.cc
@@ -24,7 +24,8 @@
Unregister();
active_ = other.active_;
- channel_ = other.channel_;
+ client_ = other.client_;
+ channel_id_ = other.channel_id_;
service_id_ = other.service_id_;
method_id_ = other.method_id_;
request_ = std::move(other.request_);
@@ -34,18 +35,26 @@
// If the call being assigned is active, replace it in the client's list
// with a reference to the current object.
other.Unregister();
- other.channel_->client()
- ->RegisterCall(*this)
- .IgnoreError(); // TODO(pwbug/387): Handle Status properly
+
+ // RegisterCall() only fails if there is already a call for the same
+ // (channel, service, method). As the existing call is unregistered, this
+ // error cannot happen.
+ other.client_->RegisterCall(*this).IgnoreError();
}
return *this;
}
void BaseClientCall::Cancel() {
- if (active()) {
- channel_->Send(NewPacket(PacketType::CANCEL))
- .IgnoreError(); // TODO(pwbug/387): Handle Status properly
+ if (!active()) {
+ return;
+ }
+
+ rpc::Channel* channel = client_->GetChannel(channel_id_);
+ if (channel != nullptr) {
+ static_cast<Channel*>(channel)
+ ->Send(NewPacket(PacketType::CANCEL))
+ .IgnoreError();
}
}
@@ -54,7 +63,12 @@
return {};
}
- request_ = channel_->AcquireBuffer();
+ rpc::Channel* channel = client_->GetChannel(channel_id_);
+ if (!channel) {
+ return {};
+ }
+
+ request_ = static_cast<Channel*>(channel)->AcquireBuffer();
return request_.payload(NewPacket(PacketType::REQUEST));
}
@@ -64,23 +78,30 @@
return Status::FailedPrecondition();
}
- return channel_->Send(request_, NewPacket(PacketType::REQUEST, payload));
+ rpc::Channel* channel = client_->GetChannel(channel_id_);
+ if (!channel) {
+ return Status::NotFound();
+ }
+
+ return static_cast<Channel*>(channel)->Send(
+ request_, NewPacket(PacketType::REQUEST, payload));
}
Packet BaseClientCall::NewPacket(PacketType type,
std::span<const std::byte> payload) const {
- return Packet(type, channel_->id(), service_id_, method_id_, payload);
+ return Packet(type, channel_id_, service_id_, method_id_, payload);
}
void BaseClientCall::Register() {
- channel_->client()
- ->RegisterCall(*this)
- .IgnoreError(); // TODO(pwbug/387): Handle Status properly
+ // TODO(frolv): This is broken. If you try to replace an exisitng call with a
+ // new call to the same method on the same channel, the new call will fail to
+ // register. Instead, the new call should replace the existing one.
+ client_->RegisterCall(*this).IgnoreError();
}
void BaseClientCall::Unregister() {
if (active()) {
- channel_->client()->RemoveCall(*this);
+ client_->RemoveCall(*this);
active_ = false;
}
}
diff --git a/pw_rpc/base_client_call_test.cc b/pw_rpc/base_client_call_test.cc
index e4f0dff..7583966 100644
--- a/pw_rpc/base_client_call_test.cc
+++ b/pw_rpc/base_client_call_test.cc
@@ -25,7 +25,8 @@
EXPECT_EQ(context.client().active_calls(), 0u);
{
- BaseClientCall call(&context.channel(),
+ BaseClientCall call(&context.client(),
+ context.channel().id(),
context.service_id(),
context.method_id(),
[](BaseClientCall&, const Packet&) {});
@@ -39,7 +40,8 @@
ClientContextForTest context;
EXPECT_EQ(context.client().active_calls(), 0u);
- BaseClientCall moved(&context.channel(),
+ BaseClientCall moved(&context.client(),
+ context.channel().id(),
context.service_id(),
context.method_id(),
[](BaseClientCall&, const Packet&) {});
@@ -76,11 +78,12 @@
class FakeClientCall : public BaseClientCall {
public:
- constexpr FakeClientCall(rpc::Channel* channel,
+ constexpr FakeClientCall(Client* client,
+ uint32_t channel_id,
uint32_t service_id,
uint32_t method_id,
ResponseHandler handler)
- : BaseClientCall(channel, service_id, method_id, handler) {}
+ : BaseClientCall(client, channel_id, service_id, method_id, handler) {}
Status SendPacket(std::span<const std::byte> payload) {
std::span buffer = AcquirePayloadBuffer();
@@ -91,7 +94,8 @@
TEST(BaseClientCall, SendsPacketWithPayload) {
ClientContextForTest context;
- FakeClientCall call(&context.channel(),
+ FakeClientCall call(&context.client(),
+ context.channel().id(),
context.service_id(),
context.method_id(),
[](BaseClientCall&, const Packet&) {});
diff --git a/pw_rpc/client.cc b/pw_rpc/client.cc
index f8c1b29..3679e33 100644
--- a/pw_rpc/client.cc
+++ b/pw_rpc/client.cc
@@ -46,7 +46,7 @@
}
auto call = std::find_if(calls_.begin(), calls_.end(), [&](auto& c) {
- return c.channel().id() == packet.channel_id() &&
+ return c.channel_id() == packet.channel_id() &&
c.service_id() == packet.service_id() &&
c.method_id() == packet.method_id();
});
@@ -86,9 +86,21 @@
return OkStatus();
}
+Channel* Client::GetChannel(uint32_t channel_id) const {
+ auto channel = std::find_if(channels_.begin(), channels_.end(), [&](auto& c) {
+ return c.id() == channel_id;
+ });
+
+ if (channel == channels_.end()) {
+ return nullptr;
+ }
+
+ return channel;
+}
+
Status Client::RegisterCall(BaseClientCall& call) {
auto existing_call = std::find_if(calls_.begin(), calls_.end(), [&](auto& c) {
- return c.channel().id() == call.channel().id() &&
+ return c.channel_id() == call.channel_id() &&
c.service_id() == call.service_id() &&
c.method_id() == call.method_id();
});
diff --git a/pw_rpc/docs.rst b/pw_rpc/docs.rst
index c34563c..5e1a512 100644
--- a/pw_rpc/docs.rst
+++ b/pw_rpc/docs.rst
@@ -708,6 +708,9 @@
// Generated clients are namespaced with their proto library.
using pw::rpc::nanopb::EchoServiceClient;
+ // RPC channel ID on which to make client calls.
+ constexpr uint32_t kDefaultChannelId = 1;
+
EchoServiceClient::EchoCall echo_call;
// Callback invoked when a response is received. This is called synchronously
@@ -724,15 +727,21 @@
} // namespace
-
void CallEcho(const char* message) {
+ // Create a client to call the EchoService.
+ EchoServiceClient echo_client(my_rpc_client, kDefaultChannelId);
+
pw_rpc_EchoMessage request = pw_rpc_EchoMessage_init_default;
pw::string::Copy(message, request.msg);
// By assigning the returned ClientCall to the global echo_call, the RPC
// call is kept alive until it completes. When a response is received, it
// will be logged by the handler function and the call will complete.
- echo_call = EchoServiceClient::Echo(my_channel, request, EchoResponse);
+ echo_call = echo_client.Echo(request, EchoResponse);
+ if (!echo_call.active()) {
+ // The RPC call was not sent. This could occur due to, for example, an
+ // invalid channel ID. Handle if necessary.
+ }
}
Client implementation details
diff --git a/pw_rpc/nanopb/client_call.cc b/pw_rpc/nanopb/client_call.cc
index e7b4939..bc1da47 100644
--- a/pw_rpc/nanopb/client_call.cc
+++ b/pw_rpc/nanopb/client_call.cc
@@ -19,14 +19,28 @@
Status BaseNanopbClientCall::SendRequest(const void* request_struct) {
std::span<std::byte> buffer = AcquirePayloadBuffer();
-
- StatusWithSize sws = serde_.EncodeRequest(request_struct, buffer);
- if (!sws.ok()) {
- ReleasePayloadBuffer({});
- return sws.status();
+ if (buffer.empty()) {
+ // Getting an empty buffer means that either the call is inactive or the
+ // channel does not exist.
+ Unregister();
+ return Status::Unavailable();
}
- return ReleasePayloadBuffer(buffer.first(sws.size()));
+ StatusWithSize sws = serde_.EncodeRequest(request_struct, buffer);
+ Status status = sws.status();
+
+ if (status.ok()) {
+ status = ReleasePayloadBuffer(buffer.first(sws.size()));
+ } else {
+ ReleasePayloadBuffer({});
+ }
+
+ if (!status.ok()) {
+ // Failing to send the initial request ends the RPC call.
+ Unregister();
+ }
+
+ return status;
}
} // namespace internal
diff --git a/pw_rpc/nanopb/codegen_test.cc b/pw_rpc/nanopb/codegen_test.cc
index a61394b..2623243 100644
--- a/pw_rpc/nanopb/codegen_test.cc
+++ b/pw_rpc/nanopb/codegen_test.cc
@@ -214,29 +214,32 @@
}
TEST(NanopbCodegen, Client_InvokesUnaryRpcWithCallback) {
- constexpr uint32_t service_id = internal::Hash("pw.rpc.test.TestService");
- constexpr uint32_t method_id = internal::Hash("TestUnaryRpc");
+ constexpr uint32_t kServiceId = internal::Hash("pw.rpc.test.TestService");
+ constexpr uint32_t kMethodId = internal::Hash("TestUnaryRpc");
- ClientContextForTest<128, 128, 99, service_id, method_id> context;
+ ClientContextForTest<128, 128, 99, kServiceId, kMethodId> context;
+
+ TestServiceClient test_client(context.client(), context.channel().id());
struct {
Status last_status = Status::Unknown();
int response_value = -1;
} result;
- auto call = TestServiceClient::TestUnaryRpc(
- context.channel(),
+ auto call = test_client.TestUnaryRpc(
{.integer = 123, .status_code = 0},
[&result](const pw_rpc_test_TestResponse& response, Status status) {
result.last_status = status;
result.response_value = response.value;
});
+ EXPECT_TRUE(call.active());
+
EXPECT_EQ(context.output().packet_count(), 1u);
auto packet = context.output().sent_packet();
EXPECT_EQ(packet.channel_id(), context.channel().id());
- EXPECT_EQ(packet.service_id(), service_id);
- EXPECT_EQ(packet.method_id(), method_id);
+ EXPECT_EQ(packet.service_id(), kServiceId);
+ EXPECT_EQ(packet.method_id(), kMethodId);
PW_DECODE_PB(pw_rpc_test_TestRequest, sent_proto, packet.payload());
EXPECT_EQ(sent_proto.integer, 123);
@@ -247,10 +250,12 @@
}
TEST(NanopbCodegen, Client_InvokesServerStreamingRpcWithCallback) {
- constexpr uint32_t service_id = internal::Hash("pw.rpc.test.TestService");
- constexpr uint32_t method_id = internal::Hash("TestServerStreamRpc");
+ constexpr uint32_t kServiceId = internal::Hash("pw.rpc.test.TestService");
+ constexpr uint32_t kMethodId = internal::Hash("TestServerStreamRpc");
- ClientContextForTest<128, 128, 99, service_id, method_id> context;
+ ClientContextForTest<128, 128, 99, kServiceId, kMethodId> context;
+
+ TestServiceClient test_client(context.client(), context.channel().id());
struct {
bool active = true;
@@ -258,8 +263,7 @@
int response_value = -1;
} result;
- auto call = TestServiceClient::TestServerStreamRpc(
- context.channel(),
+ auto call = test_client.TestServerStreamRpc(
{.integer = 123, .status_code = 0},
[&result](const pw_rpc_test_TestStreamResponse& response) {
result.active = true;
@@ -270,11 +274,13 @@
result.stream_status = status;
});
+ EXPECT_TRUE(call.active());
+
EXPECT_EQ(context.output().packet_count(), 1u);
auto packet = context.output().sent_packet();
EXPECT_EQ(packet.channel_id(), context.channel().id());
- EXPECT_EQ(packet.service_id(), service_id);
- EXPECT_EQ(packet.method_id(), method_id);
+ EXPECT_EQ(packet.service_id(), kServiceId);
+ EXPECT_EQ(packet.method_id(), kMethodId);
PW_DECODE_PB(pw_rpc_test_TestRequest, sent_proto, packet.payload());
EXPECT_EQ(sent_proto.integer, 123);
@@ -289,5 +295,119 @@
EXPECT_EQ(result.stream_status, Status::NotFound());
}
+TEST(NanopbCodegen, Client_StaticMethod_InvokesUnaryRpcWithCallback) {
+ constexpr uint32_t kServiceId = internal::Hash("pw.rpc.test.TestService");
+ constexpr uint32_t kMethodId = internal::Hash("TestUnaryRpc");
+
+ ClientContextForTest<128, 128, 99, kServiceId, kMethodId> context;
+
+ struct {
+ Status last_status = Status::Unknown();
+ int response_value = -1;
+ } result;
+
+ auto call = TestServiceClient::TestUnaryRpc(
+ context.client(),
+ context.channel().id(),
+ {.integer = 123, .status_code = 0},
+ [&result](const pw_rpc_test_TestResponse& response, Status status) {
+ result.last_status = status;
+ result.response_value = response.value;
+ });
+
+ EXPECT_TRUE(call.active());
+
+ EXPECT_EQ(context.output().packet_count(), 1u);
+ auto packet = context.output().sent_packet();
+ EXPECT_EQ(packet.channel_id(), context.channel().id());
+ EXPECT_EQ(packet.service_id(), kServiceId);
+ EXPECT_EQ(packet.method_id(), kMethodId);
+ PW_DECODE_PB(pw_rpc_test_TestRequest, sent_proto, packet.payload());
+ EXPECT_EQ(sent_proto.integer, 123);
+
+ PW_ENCODE_PB(pw_rpc_test_TestResponse, response, .value = 42);
+ context.SendResponse(OkStatus(), response);
+ EXPECT_EQ(result.last_status, OkStatus());
+ EXPECT_EQ(result.response_value, 42);
+}
+
+TEST(NanopbCodegen, Client_StaticMethod_InvokesServerStreamingRpcWithCallback) {
+ constexpr uint32_t kServiceId = internal::Hash("pw.rpc.test.TestService");
+ constexpr uint32_t kMethodId = internal::Hash("TestServerStreamRpc");
+
+ ClientContextForTest<128, 128, 99, kServiceId, kMethodId> context;
+
+ struct {
+ bool active = true;
+ Status stream_status = Status::Unknown();
+ int response_value = -1;
+ } result;
+
+ auto call = TestServiceClient::TestServerStreamRpc(
+ context.client(),
+ context.channel().id(),
+ {.integer = 123, .status_code = 0},
+ [&result](const pw_rpc_test_TestStreamResponse& response) {
+ result.active = true;
+ result.response_value = response.number;
+ },
+ [&result](Status status) {
+ result.active = false;
+ result.stream_status = status;
+ });
+
+ EXPECT_TRUE(call.active());
+
+ EXPECT_EQ(context.output().packet_count(), 1u);
+ auto packet = context.output().sent_packet();
+ EXPECT_EQ(packet.channel_id(), context.channel().id());
+ EXPECT_EQ(packet.service_id(), kServiceId);
+ EXPECT_EQ(packet.method_id(), kMethodId);
+ PW_DECODE_PB(pw_rpc_test_TestRequest, sent_proto, packet.payload());
+ EXPECT_EQ(sent_proto.integer, 123);
+
+ PW_ENCODE_PB(
+ pw_rpc_test_TestStreamResponse, response, .chunk = {}, .number = 11u);
+ context.SendServerStream(response);
+ EXPECT_TRUE(result.active);
+ EXPECT_EQ(result.response_value, 11);
+
+ context.SendResponse(Status::NotFound());
+ EXPECT_FALSE(result.active);
+ EXPECT_EQ(result.stream_status, Status::NotFound());
+}
+
+TEST(NanopbCodegen, Client_InvalidChannel_DoesntMakeCall) {
+ constexpr uint32_t kServiceId = internal::Hash("pw.rpc.test.TestService");
+ constexpr uint32_t kMethodId = internal::Hash("TestUnaryRpc");
+
+ constexpr uint32_t kBadChannelId = 97;
+
+ ClientContextForTest<128, 128, 99, kServiceId, kMethodId> context;
+ TestServiceClient test_client(context.client(), kBadChannelId);
+
+ struct {
+ Status last_status = Status::Unknown();
+ int response_value = -1;
+ } result;
+
+ Status error = Status::Unknown();
+
+ auto call = test_client.TestUnaryRpc(
+ {.integer = 123, .status_code = 0},
+ [&result](const pw_rpc_test_TestResponse& response, Status status) {
+ result.last_status = status;
+ result.response_value = response.value;
+ },
+ [&error](Status rpc_error) { error = rpc_error; });
+
+ EXPECT_FALSE(call.active());
+
+ EXPECT_EQ(context.output().packet_count(), 0u);
+ EXPECT_EQ(result.last_status, Status::Unknown());
+ EXPECT_EQ(result.response_value, -1);
+ EXPECT_EQ(error, Status::Unavailable());
+}
+
} // namespace
} // namespace pw::rpc
diff --git a/pw_rpc/nanopb/docs.rst b/pw_rpc/nanopb/docs.rst
index 859e466..876a95b 100644
--- a/pw_rpc/nanopb/docs.rst
+++ b/pw_rpc/nanopb/docs.rst
@@ -140,25 +140,35 @@
Client-side
-----------
A corresponding client class is generated for every service defined in the proto
-file. Like the service class, it is placed under the ``generated`` namespace.
-The class is named after the service, with a ``Client`` suffix. For example, the
-``ChatService`` would create a ``generated::ChatServiceClient``.
+file. To allow multiple types of clients to exist, it is placed under the
+``nanopb`` namespace. The class is named after the service, with a ``Client``
+suffix. For example, the ``ChatService`` would create a
+``nanopb::ChatServiceClient``.
-The client class contains static methods to call each of the service's methods.
-It is not meant to be instantiated.
+Service clients are instantiated with a reference to the RPC client through
+which they will send requests, and the channel ID they will use.
.. code-block:: c++
- static GetRoomInformationCall GetRoomInformation(
- Channel& channel,
- const RoomInfoRequest& request,
- ::pw::Function<void(Status, const RoomInfoResponse&)> on_response,
- ::pw::Function<void(Status)> on_rpc_error = nullptr);
+ class ChatServiceClient {
+ public:
+ ChatServiceClient(::pw::rpc::Client& client, uint32_t default_channel_id);
-The ``NanopbClientCall`` object returned by the RPC invocation stores the active
-RPC's context. For more information on ``ClientCall`` objects, refer to the
-:ref:`core RPC documentation <module-pw_rpc-making-calls>`. The type of the
-returned object is complex, so it is aliased using the method name.
+ GetRoomInformationCall GetRoomInformation(
+ const RoomInfoRequest& request,
+ ::pw::Function<void(Status, const RoomInfoResponse&)> on_response,
+ ::pw::Function<void(Status)> on_rpc_error = nullptr);
+
+ // ...and more (see below).
+ };
+
+The client class has member functions for each method defined within the
+service's protobuf descriptor. The arguments to these methods vary depending on
+the type of RPC. Each method returns a ``NanopbClientCall`` object which stores
+the context of the ongoing RPC call. For more information on ``ClientCall``
+objects, refer to the :ref:`core RPC docs <module-pw_rpc-making-calls>`. The
+type of the returned object is complex, so it is aliased using the method
+name.
.. admonition:: Callback invocation
@@ -166,8 +176,7 @@
Method APIs
^^^^^^^^^^^
-The first argument to each client call method is the channel through which to
-send the RPC. Following that, the arguments depend on the method type.
+The arguments provided when invoking a method depend on its type.
Unary RPC
~~~~~~~~~
@@ -179,8 +188,7 @@
.. code-block:: c++
- static GetRoomInformationCall GetRoomInformation(
- Channel& channel,
+ GetRoomInformationCall GetRoomInformation(
const RoomInfoRequest& request,
::pw::Function<void(const RoomInfoResponse&, Status)> on_response,
::pw::Function<void(Status)> on_rpc_error = nullptr);
@@ -195,8 +203,7 @@
.. code-block:: c++
- static ListUsersInRoomCall ListUsersInRoom(
- Channel& channel,
+ ListUsersInRoomCall ListUsersInRoom(
const ListUsersRequest& request,
::pw::Function<void(const ListUsersResponse&)> on_response,
::pw::Function<void(Status)> on_stream_end,
@@ -213,7 +220,7 @@
namespace {
MyChannelOutput output;
- pw::rpc::Channel channels[] = {pw::rpc::Channel::Create<0>(&output)};
+ pw::rpc::Channel channels[] = {pw::rpc::Channel::Create<1>(&output)};
pw::rpc::Client client(channels);
// Callback function for GetRoomInformation.
@@ -221,13 +228,22 @@
}
void InvokeSomeRpcs() {
- // The RPC will remain active as long as `call` is alive.
- auto call = ChatServiceClient::GetRoomInformation(channels[0],
- {.room = "pigweed"},
- LogRoomInformation);
+ // Instantiate a service client to call ChatService methods on channel 1.
+ ChatServiceClient chat_client(client, 1);
- // For simplicity, block here. An actual implementation would likely
- // std::move the call somewhere to keep it active while doing other work.
+ // The RPC will remain active as long as `call` is alive.
+ auto call = chat_client.GetRoomInformation(
+ {.room = "pigweed"}, LogRoomInformation);
+ if (!call.active()) {
+ // The invocation may fail. This could occur due to an invalid channel ID,
+ // for example. The failure status is forwarded to the to call's
+ // on_rpc_error callback.
+ return;
+ }
+
+ // For simplicity, block until the call completes. An actual implementation
+ // would likely std::move the call somewhere to keep it active while doing
+ // other work.
while (call.active()) {
Wait();
}
diff --git a/pw_rpc/nanopb/public/pw_rpc/nanopb/client_call.h b/pw_rpc/nanopb/public/pw_rpc/nanopb/client_call.h
index 1a20bb5..18ae943 100644
--- a/pw_rpc/nanopb/public/pw_rpc/nanopb/client_call.h
+++ b/pw_rpc/nanopb/public/pw_rpc/nanopb/client_call.h
@@ -31,6 +31,7 @@
Status SendRequest(const void* request_struct);
protected:
+ // TODO(frolv): Migrate everything to the new constructor and deprecate this.
constexpr BaseNanopbClientCall(rpc::Channel* channel,
uint32_t service_id,
uint32_t method_id,
@@ -40,6 +41,16 @@
: BaseClientCall(channel, service_id, method_id, handler),
serde_(request_fields, response_fields) {}
+ constexpr BaseNanopbClientCall(Client* client,
+ uint32_t client_id,
+ uint32_t service_id,
+ uint32_t method_id,
+ ResponseHandler handler,
+ NanopbMessageDescriptor request_fields,
+ NanopbMessageDescriptor response_fields)
+ : BaseClientCall(client, client_id, service_id, method_id, handler),
+ serde_(request_fields, response_fields) {}
+
constexpr BaseNanopbClientCall()
: BaseClientCall(), serde_(nullptr, nullptr) {}
@@ -60,7 +71,7 @@
ErrorCallbacks(Function<void(Status)> error) : rpc_error(std::move(error)) {}
- void InvokeRpcError(Status status) {
+ void InvokeRpcError(Status status) const {
if (rpc_error != nullptr) {
rpc_error(status);
}
@@ -106,6 +117,7 @@
template <typename Callbacks>
class NanopbClientCall : public internal::BaseNanopbClientCall {
public:
+ // TODO(frolv): Migrate everything to the new constructor and deprecate this.
constexpr NanopbClientCall(Channel* channel,
uint32_t service_id,
uint32_t method_id,
@@ -120,6 +132,22 @@
response_fields),
callbacks_(std::move(callbacks)) {}
+ constexpr NanopbClientCall(Client* client,
+ uint32_t client_id,
+ uint32_t service_id,
+ uint32_t method_id,
+ Callbacks callbacks,
+ internal::NanopbMessageDescriptor request_fields,
+ internal::NanopbMessageDescriptor response_fields)
+ : BaseNanopbClientCall(client,
+ client_id,
+ service_id,
+ method_id,
+ &ResponseHandler,
+ request_fields,
+ response_fields),
+ callbacks_(std::move(callbacks)) {}
+
constexpr NanopbClientCall() : BaseNanopbClientCall(), callbacks_({}) {}
NanopbClientCall(const NanopbClientCall&) = delete;
@@ -128,6 +156,8 @@
NanopbClientCall(NanopbClientCall&&) = default;
NanopbClientCall& operator=(NanopbClientCall&&) = default;
+ constexpr const Callbacks& callbacks() const { return callbacks_; }
+
private:
using Response = typename Callbacks::Response;
diff --git a/pw_rpc/public/pw_rpc/client.h b/pw_rpc/public/pw_rpc/client.h
index 6e7def1..cac5825 100644
--- a/pw_rpc/public/pw_rpc/client.h
+++ b/pw_rpc/public/pw_rpc/client.h
@@ -47,6 +47,8 @@
//
Status ProcessPacket(ConstByteSpan data);
+ Channel* GetChannel(uint32_t channel_id) const;
+
size_t active_calls() const { return calls_.size(); }
private:
diff --git a/pw_rpc/public/pw_rpc/internal/base_client_call.h b/pw_rpc/public/pw_rpc/internal/base_client_call.h
index 3b4e110..7dcaed8 100644
--- a/pw_rpc/public/pw_rpc/internal/base_client_call.h
+++ b/pw_rpc/public/pw_rpc/internal/base_client_call.h
@@ -28,21 +28,38 @@
public:
using ResponseHandler = void (*)(BaseClientCall&, const Packet&);
+ // TODO(frolv): Migrate everything to the new constructor and deprecate this.
constexpr BaseClientCall(rpc::Channel* channel,
uint32_t service_id,
uint32_t method_id,
ResponseHandler handler)
- : channel_(static_cast<Channel*>(channel)),
+ : client_(static_cast<Channel*>(channel)->client()),
+ channel_id_(channel->id()),
service_id_(service_id),
method_id_(method_id),
handler_(handler),
active_(true) {
- PW_ASSERT(channel_ != nullptr);
+ Register();
+ }
+
+ constexpr BaseClientCall(Client* client,
+ uint32_t channel_id,
+ uint32_t service_id,
+ uint32_t method_id,
+ ResponseHandler handler)
+ : client_(client),
+ channel_id_(channel_id),
+ service_id_(service_id),
+ method_id_(method_id),
+ handler_(handler),
+ active_(true) {
+ PW_ASSERT(client_ != nullptr);
Register();
}
constexpr BaseClientCall()
- : channel_(nullptr),
+ : client_(nullptr),
+ channel_id_(0),
service_id_(0),
method_id_(0),
handler_(nullptr),
@@ -63,7 +80,7 @@
void Cancel();
protected:
- constexpr Channel& channel() const { return *channel_; }
+ constexpr uint32_t channel_id() const { return channel_id_; }
constexpr uint32_t service_id() const { return service_id_; }
constexpr uint32_t method_id() const { return method_id_; }
@@ -82,7 +99,8 @@
Packet NewPacket(PacketType type,
std::span<const std::byte> payload = {}) const;
- Channel* channel_;
+ Client* client_;
+ uint32_t channel_id_;
uint32_t service_id_;
uint32_t method_id_;
Channel::OutputBuffer request_;
diff --git a/pw_rpc/py/pw_rpc/codegen_nanopb.py b/pw_rpc/py/pw_rpc/codegen_nanopb.py
index 7630293..ee8b4b4 100644
--- a/pw_rpc/py/pw_rpc/codegen_nanopb.py
+++ b/pw_rpc/py/pw_rpc/codegen_nanopb.py
@@ -93,7 +93,8 @@
return param
-def _generate_code_for_client_method(method: ProtoServiceMethod,
+def _generate_code_for_client_method(class_name: str,
+ method: ProtoServiceMethod,
output: OutputFile) -> None:
"""Outputs client code for a single RPC method."""
@@ -133,11 +134,17 @@
call_alias = f'{method.name()}Call'
+ moved_functions = list(f'std::move({function.name})'
+ for function in functions)
+
output.write_line()
output.write_line(
f'using {call_alias} = {RPC_NAMESPACE}::NanopbClientCall<')
output.write_line(f' {callbacks}<{res}>>;')
output.write_line()
+
+ # TODO(frolv): Deprecate this channel-based API.
+ # ======== Deprecated API ========
output.write_line(f'static {call_alias} {method.name()}(')
with output.indent(4):
output.write_line(f'{RPC_NAMESPACE}::Channel& channel,')
@@ -156,17 +163,78 @@
output.write_line('kServiceId,')
output.write_line(
f'0x{method_id:08x}, // Hash of "{method.name()}"')
-
- moved_functions = (f'std::move({function.name})'
- for function in functions)
output.write_line(f'{callbacks}({", ".join(moved_functions)}),')
-
output.write_line(f'{req}_fields,')
output.write_line(f'{res}_fields);')
output.write_line('call.SendRequest(&request);')
output.write_line('return call;')
output.write_line('}')
+ output.write_line()
+
+ # ======== End deprecated API ========
+
+ output.write_line(f'{call_alias} {method.name()}(')
+ with output.indent(4):
+ output.write_line(f'const {req}& request,')
+
+ # Write out each of the callback functions for the method type.
+ for i, function in enumerate(functions):
+ if i == len(functions) - 1:
+ output.write_line(f'{function}) {{')
+ else:
+ output.write_line(f'{function},')
+
+ with output.indent():
+ output.write_line()
+ output.write_line(f'{call_alias} call(client_,')
+ with output.indent(len(call_alias) + 6):
+ output.write_line('default_channel_id_,')
+ output.write_line('kServiceId,')
+ output.write_line(
+ f'0x{method_id:08x}, // Hash of "{method.name()}"')
+ output.write_line(f'{callbacks}({", ".join(moved_functions)}),')
+ output.write_line(f'{req}_fields,')
+ output.write_line(f'{res}_fields);')
+
+ # Unary and server streaming RPCs send the initial request immediately.
+ if method.type() in (ProtoServiceMethod.Type.UNARY,
+ ProtoServiceMethod.Type.SERVER_STREAMING):
+ output.write_line()
+ output.write_line('if (::pw::Status status = '
+ 'call.SendRequest(&request); !status.ok()) {')
+ with output.indent():
+ output.write_line('call.callbacks().InvokeRpcError(status);')
+ output.write_line('}')
+ output.write_line()
+
+ output.write_line('return call;')
+
+ output.write_line('}')
+ output.write_line()
+
+ # Generate a static wrapper method that instantiates a client.
+ output.write_line(f'static {call_alias} {method.name()}(')
+
+ with output.indent(4):
+ output.write_line(f'{RPC_NAMESPACE}::Client& client,')
+ output.write_line('uint32_t channel_id,')
+ output.write_line(f'const {req}& request,')
+
+ # Write out each of the callback functions for the method type.
+ for i, function in enumerate(functions):
+ if i == len(functions) - 1:
+ output.write_line(f'{function}) {{')
+ else:
+ output.write_line(f'{function},')
+
+ with output.indent():
+ output.write_line(
+ f'return {class_name}(client, channel_id).{method.name()}(')
+ output.write_line(f' request, {", ".join(moved_functions)});')
+
+ output.write_line('}')
+ output.write_line()
def _generate_code_for_client(service: ProtoService, root: ProtoNode,
@@ -180,10 +248,23 @@
output.write_line(' public:')
with output.indent():
- output.write_line(f'{class_name}() = delete;')
+ output.write_line(
+ f'constexpr {class_name}({RPC_NAMESPACE}::Client& client,'
+ ' uint32_t default_channel_id)')
+ output.write_line(
+ ' : client_(&client), default_channel_id_(default_channel_id) {}'
+ )
+
+ output.write_line()
+ output.write_line(f'{class_name}(const {class_name}&) = default;')
+ output.write_line(f'{class_name}({class_name}&&) = default;')
+ output.write_line(
+ f'{class_name}& operator=(const {class_name}&) = default;')
+ output.write_line(
+ f'{class_name}& operator=({class_name}&&) = default;')
for method in service.methods():
- _generate_code_for_client_method(method, output)
+ _generate_code_for_client_method(class_name, method, output)
service_name_hash = pw_rpc.ids.calculate(service.proto_path())
output.write_line('\n private:')
@@ -194,14 +275,19 @@
f'static constexpr uint32_t kServiceId = 0x{service_name_hash:08x};'
)
+ output.write_line()
+ output.write_line(f'{RPC_NAMESPACE}::Client* client_;')
+ output.write_line('uint32_t default_channel_id_;')
+
output.write_line('};')
output.write_line('\n} // namespace nanopb\n')
def includes(proto_file, unused_package: ProtoNode) -> Iterator[str]:
- yield '#include "pw_rpc/nanopb/internal/method_union.h"'
+ yield '#include "pw_rpc/client.h"'
yield '#include "pw_rpc/nanopb/client_call.h"'
+ yield '#include "pw_rpc/nanopb/internal/method_union.h"'
# Include the corresponding nanopb header file for this proto file, in which
# the file's messages and enums are generated. All other files imported from