pw_rpc: Support cancelling streaming RPCs
Introduce new PacketType::CANCEL. Clients send CANCEL to cancel an
ongoing streaming RPC. Servers send CANCEL packets to indicate or
confirm that an RPC was cancelled.
Change-Id: I72e7c5c4859f0c4262565a8c963946e00ec4dd37
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/12462
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn
index 62fbe3c..60a9c5e 100644
--- a/pw_rpc/BUILD.gn
+++ b/pw_rpc/BUILD.gn
@@ -146,9 +146,8 @@
":channel_test",
":packet_test",
":server_test",
- "nanopb:codegen_test",
- "nanopb:method_test",
]
+ group_deps = [ "nanopb:tests" ]
}
# RPC server for tests only. A mock method implementation is used.
diff --git a/pw_rpc/base_server_writer.cc b/pw_rpc/base_server_writer.cc
index 31da1d2..67abc6f 100644
--- a/pw_rpc/base_server_writer.cc
+++ b/pw_rpc/base_server_writer.cc
@@ -43,17 +43,23 @@
return *this;
}
+uint32_t BaseServerWriter::method_id() const { return call_.method().id(); }
+
void BaseServerWriter::Finish() {
if (!open()) {
return;
}
call_.server().RemoveWriter(*this);
-
- // TODO(hepler): Send a control packet indicating that the stream has
- // terminated.
-
state_ = kClosed;
+
+ // Send a control packet indicating that the stream has terminated.
+ auto response = call_.channel().AcquireBuffer();
+ call_.channel().Send(response,
+ Packet(PacketType::CANCEL,
+ call_.channel().id(),
+ call_.service().id(),
+ method().id()));
}
std::span<std::byte> BaseServerWriter::AcquirePayloadBuffer() {
@@ -62,7 +68,7 @@
}
response_ = call_.channel().AcquireBuffer();
- return response_.payload(packet());
+ return response_.payload(RpcPacket());
}
Status BaseServerWriter::ReleasePayloadBuffer(
@@ -70,10 +76,10 @@
if (!open()) {
return Status::FAILED_PRECONDITION;
}
- return call_.channel().Send(response_, packet(payload));
+ return call_.channel().Send(response_, RpcPacket(payload));
}
-Packet BaseServerWriter::packet(std::span<const std::byte> payload) const {
+Packet BaseServerWriter::RpcPacket(std::span<const std::byte> payload) const {
return Packet(PacketType::RPC,
call_.channel().id(),
call_.service().id(),
diff --git a/pw_rpc/base_server_writer_test.cc b/pw_rpc/base_server_writer_test.cc
index b10620e..04d37db 100644
--- a/pw_rpc/base_server_writer_test.cc
+++ b/pw_rpc/base_server_writer_test.cc
@@ -106,6 +106,21 @@
EXPECT_TRUE(writers.empty());
}
+TEST(ServerWriter, Finish_SendsCancellationPacket) {
+ ServerContextForTest<TestService> context;
+ FakeServerWriter writer(context.get());
+
+ writer.Finish();
+
+ Packet packet = Packet::FromBuffer(context.output().sent_packet());
+ EXPECT_EQ(packet.type(), PacketType::CANCEL);
+ EXPECT_EQ(packet.channel_id(), context.kChannelId);
+ EXPECT_EQ(packet.service_id(), context.kServiceId);
+ EXPECT_EQ(packet.method_id(), context.get().method().id());
+ EXPECT_TRUE(packet.payload().empty());
+ EXPECT_EQ(packet.status(), Status::OK);
+}
+
TEST(ServerWriter, Close) {
ServerContextForTest<TestService> context;
FakeServerWriter writer(context.get());
diff --git a/pw_rpc/nanopb/BUILD.gn b/pw_rpc/nanopb/BUILD.gn
index f5b4689..75ed96f 100644
--- a/pw_rpc/nanopb/BUILD.gn
+++ b/pw_rpc/nanopb/BUILD.gn
@@ -36,14 +36,11 @@
visibility = [ "../*" ]
}
-pw_test("method_test") {
- deps = [
- "..:nanopb_server",
- "..:test_protos_nanopb",
- "..:test_utils_nanopb_server",
+pw_test_group("tests") {
+ tests = [
+ ":codegen_test",
+ ":method_test",
]
- sources = [ "method_test.cc" ]
- enable_if = dir_pw_third_party_nanopb != ""
}
pw_test("codegen_test") {
@@ -54,3 +51,13 @@
sources = [ "codegen_test.cc" ]
enable_if = dir_pw_third_party_nanopb != ""
}
+
+pw_test("method_test") {
+ deps = [
+ "..:nanopb_server",
+ "..:test_protos_nanopb",
+ "..:test_utils_nanopb_server",
+ ]
+ sources = [ "method_test.cc" ]
+ enable_if = dir_pw_third_party_nanopb != ""
+}
diff --git a/pw_rpc/public/pw_rpc/channel.h b/pw_rpc/public/pw_rpc/channel.h
index ecae939..1b95082 100644
--- a/pw_rpc/public/pw_rpc/channel.h
+++ b/pw_rpc/public/pw_rpc/channel.h
@@ -20,11 +20,6 @@
#include "pw_status/status.h"
namespace pw::rpc {
-namespace internal {
-
-class BaseServerWriter;
-
-} // namespace internal
class ChannelOutput {
public:
diff --git a/pw_rpc/public/pw_rpc/internal/base_server_writer.h b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
index 47c6858..745b615 100644
--- a/pw_rpc/public/pw_rpc/internal/base_server_writer.h
+++ b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
@@ -20,6 +20,7 @@
#include "pw_containers/intrusive_list.h"
#include "pw_rpc/internal/call.h"
#include "pw_rpc/internal/channel.h"
+#include "pw_rpc/internal/service.h"
namespace pw::rpc::internal {
@@ -46,6 +47,10 @@
// True if the ServerWriter is active and ready to send responses.
bool open() const { return state_ == kOpen; }
+ uint32_t channel_id() const { return call_.channel().id(); }
+ uint32_t service_id() const { return call_.service().id(); }
+ uint32_t method_id() const;
+
// Closes the ServerWriter, if it is open.
void Finish();
@@ -59,7 +64,7 @@
Status ReleasePayloadBuffer(std::span<const std::byte> payload);
private:
- Packet packet(std::span<const std::byte> payload = {}) const;
+ Packet RpcPacket(std::span<const std::byte> payload = {}) const;
ServerCall call_;
Channel::OutputBuffer response_;
diff --git a/pw_rpc/public/pw_rpc/internal/service.h b/pw_rpc/public/pw_rpc/internal/service.h
index 9a82b57..957de33 100644
--- a/pw_rpc/public/pw_rpc/internal/service.h
+++ b/pw_rpc/public/pw_rpc/internal/service.h
@@ -18,10 +18,11 @@
#include <utility>
#include "pw_containers/intrusive_list.h"
-#include "pw_rpc/internal/method.h"
namespace pw::rpc::internal {
+class Method;
+
// Base class for all RPC services. This cannot be instantiated directly; use a
// generated subclass instead.
class Service : public IntrusiveList<Service>::Item {
diff --git a/pw_rpc/public/pw_rpc/server.h b/pw_rpc/public/pw_rpc/server.h
index 5b439a1..9cd5695 100644
--- a/pw_rpc/public/pw_rpc/server.h
+++ b/pw_rpc/public/pw_rpc/server.h
@@ -48,6 +48,11 @@
IntrusiveList<internal::BaseServerWriter>& writers() { return writers_; }
private:
+ void HandleRpcPacket(const internal::Packet& request,
+ internal::Channel& channel);
+
+ void HandleCancelPacket(const internal::Packet& request);
+
void InvokeMethod(const internal::Packet& request,
Channel& channel,
internal::Packet& response,
diff --git a/pw_rpc/pw_rpc_protos/packet.proto b/pw_rpc/pw_rpc_protos/packet.proto
index fa04692..cd03194 100644
--- a/pw_rpc/pw_rpc_protos/packet.proto
+++ b/pw_rpc/pw_rpc_protos/packet.proto
@@ -15,7 +15,16 @@
package pw.rpc.internal;
-enum PacketType { RPC = 0; };
+enum PacketType {
+ // RPC packets correspond with an RPC request or response.
+ RPC = 0;
+
+ // CANCEL packets indicate cancellation of an ongoing streaming RPC. They are
+ // sent by the client to request cancellation and by the server to signal that
+ // an RPC was cancelled. The server sends a CANCEL packet when an RPC is
+ // cancelled server-side and in response to a client's request.
+ CANCEL = 1;
+}
message RpcPacket {
// The type of packet. Either a general RPC packet or a specific control
@@ -37,4 +46,4 @@
// RPC response status code.
uint32 status = 6;
-};
+}
diff --git a/pw_rpc/py/pw_rpc/codegen_nanopb.py b/pw_rpc/py/pw_rpc/codegen_nanopb.py
index f9f3f05..3ea7d4d 100644
--- a/pw_rpc/py/pw_rpc/codegen_nanopb.py
+++ b/pw_rpc/py/pw_rpc/codegen_nanopb.py
@@ -164,6 +164,8 @@
output.write_line('#pragma once\n')
output.write_line('#include <cstddef>')
output.write_line('#include <cstdint>\n')
+ output.write_line('#include "pw_rpc/server_context.h"')
+ output.write_line('#include "pw_rpc/internal/method.h"')
output.write_line('#include "pw_rpc/internal/service.h"')
# Include the corresponding nanopb header file for this proto file, in which
diff --git a/pw_rpc/server.cc b/pw_rpc/server.cc
index 20afc77..7c5866b 100644
--- a/pw_rpc/server.cc
+++ b/pw_rpc/server.cc
@@ -39,22 +39,16 @@
void Server::ProcessPacket(std::span<const byte> data,
ChannelOutput& interface) {
+ // TODO(hepler): Update the packet parsing code to report when decoding fails.
Packet packet = Packet::FromBuffer(data);
- if (packet.is_control()) {
- // TODO(frolv): Handle control packets.
- return;
- }
if (packet.channel_id() == Channel::kUnassignedChannelId ||
packet.service_id() == 0 || packet.method_id() == 0) {
// Malformed packet; don't even try to process it.
- PW_LOG_ERROR("Received incomplete RPC packet on interface %s",
- interface.name());
+ PW_LOG_WARN("Received incomplete packet on interface %s", interface.name());
return;
}
- Packet response(PacketType::RPC);
-
internal::Channel* channel = FindChannel(packet.channel_id());
if (channel == nullptr) {
// If the requested channel doesn't exist, try to dynamically assign one.
@@ -64,6 +58,10 @@
// the server cannot process the request. The channel_id in the response
// is not set, to allow clients to detect this error case.
internal::Channel temp_channel(packet.channel_id(), &interface);
+
+ // TODO(hepler): Add a new PacketType for errors like this, rather than
+ // using PacketType::RPC.
+ Packet response(PacketType::RPC);
response.set_status(Status::RESOURCE_EXHAUSTED);
auto response_buffer = temp_channel.AcquireBuffer();
temp_channel.Send(response_buffer, response);
@@ -71,12 +69,40 @@
}
}
- response.set_channel_id(channel->id());
- auto response_buffer = channel->AcquireBuffer();
+ switch (packet.type()) {
+ case PacketType::RPC:
+ HandleRpcPacket(packet, *channel);
+ break;
+ case PacketType::CANCEL:
+ HandleCancelPacket(packet);
+ break;
+ }
+}
+
+void Server::HandleRpcPacket(const internal::Packet& request,
+ internal::Channel& channel) {
+ Packet response(PacketType::RPC);
+
+ response.set_channel_id(channel.id());
+ auto response_buffer = channel.AcquireBuffer();
// Invoke the method with matching service and method IDs, if any.
- InvokeMethod(packet, *channel, response, response_buffer.payload(response));
- channel->Send(response_buffer, response);
+ InvokeMethod(request, channel, response, response_buffer.payload(response));
+ channel.Send(response_buffer, response);
+}
+
+void Server::HandleCancelPacket(const internal::Packet& request) {
+ auto writer = std::find_if(writers_.begin(), writers_.end(), [&](auto& w) {
+ return w.channel_id() == request.channel_id() &&
+ w.service_id() == request.service_id() &&
+ w.method_id() == request.method_id();
+ });
+
+ if (writer == writers_.end()) {
+ PW_LOG_WARN("Received CANCEL packet for unknown method");
+ } else {
+ writer->Finish();
+ }
}
void Server::InvokeMethod(const Packet& request,
diff --git a/pw_rpc/server_test.cc b/pw_rpc/server_test.cc
index 7082772..47b62e5 100644
--- a/pw_rpc/server_test.cc
+++ b/pw_rpc/server_test.cc
@@ -71,11 +71,6 @@
server_.RegisterService(service_);
}
- TestOutput<128> output_;
- std::array<Channel, 3> channels_;
- Server server_;
- TestService service_;
-
std::span<const byte> EncodeRequest(
PacketType type,
uint32_t channel_id,
@@ -88,6 +83,11 @@
return std::span(request_buffer_, sws.size());
}
+ TestOutput<128> output_;
+ std::array<Channel, 3> channels_;
+ Server server_;
+ TestService service_;
+
private:
byte request_buffer_[64];
};
@@ -108,6 +108,7 @@
server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 42, 100), output_);
Packet packet = Packet::FromBuffer(output_.sent_packet());
+ EXPECT_EQ(packet.type(), PacketType::RPC);
EXPECT_EQ(packet.channel_id(), 1u);
EXPECT_EQ(packet.service_id(), 42u);
EXPECT_EQ(packet.method_id(), 100u);
@@ -131,6 +132,15 @@
EXPECT_EQ(std::memcmp(packet.payload().data(), resp, sizeof(resp)), 0);
}
+TEST_F(BasicServer, ProcessPacket_IncompletePacket_NothingIsInvoked) {
+ server_.ProcessPacket(EncodeRequest(PacketType::RPC, 0, 42, 101), output_);
+ server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 0, 101), output_);
+ server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 42, 0), output_);
+
+ EXPECT_EQ(0u, service_.method(100).last_channel_id());
+ EXPECT_EQ(0u, service_.method(200).last_channel_id());
+}
+
TEST_F(BasicServer, ProcessPacket_InvalidMethod_NothingIsInvoked) {
server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 42, 101), output_);
@@ -142,6 +152,7 @@
server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 42, 27), output_);
Packet packet = Packet::FromBuffer(output_.sent_packet());
+ EXPECT_EQ(packet.type(), PacketType::RPC);
EXPECT_EQ(packet.channel_id(), 1u);
EXPECT_EQ(packet.service_id(), 42u);
EXPECT_EQ(packet.method_id(), 0u); // No method ID 27
@@ -185,5 +196,44 @@
EXPECT_EQ(packet.method_id(), 0u);
}
+TEST_F(BasicServer, ProcessPacket_Cancel_ClosesServerWriter) {
+ // Set up a fake ServerWriter representing an ongoing RPC.
+ internal::ServerCall call(static_cast<internal::Server&>(server_),
+ static_cast<internal::Channel&>(channels_[0]),
+ service_,
+ service_.method(100));
+ internal::BaseServerWriter writer(call);
+ ASSERT_TRUE(writer.open());
+
+ server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 1, 42, 100), output_);
+
+ EXPECT_FALSE(writer.open());
+
+ Packet packet = Packet::FromBuffer(output_.sent_packet());
+ EXPECT_EQ(packet.type(), PacketType::CANCEL);
+ EXPECT_EQ(packet.channel_id(), 1u);
+ EXPECT_EQ(packet.service_id(), 42u);
+ EXPECT_EQ(packet.method_id(), 100u);
+ EXPECT_TRUE(packet.payload().empty());
+ EXPECT_EQ(packet.status(), Status::OK);
+}
+
+TEST_F(BasicServer, ProcessPacket_Cancel_UnknownIdIsIgnored) {
+ internal::ServerCall call(static_cast<internal::Server&>(server_),
+ static_cast<internal::Channel&>(channels_[0]),
+ service_,
+ service_.method(100));
+ internal::BaseServerWriter writer(call);
+ ASSERT_TRUE(writer.open());
+
+ // Send packets with incorrect channel, service, and method ID.
+ server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 2, 42, 100), output_);
+ server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 1, 43, 100), output_);
+ server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 1, 42, 101), output_);
+
+ EXPECT_TRUE(writer.open());
+ EXPECT_TRUE(output_.sent_packet().empty());
+}
+
} // namespace
} // namespace pw::rpc
diff --git a/pw_rpc/service.cc b/pw_rpc/service.cc
index b0ddd23..3107c92 100644
--- a/pw_rpc/service.cc
+++ b/pw_rpc/service.cc
@@ -16,6 +16,8 @@
#include <type_traits>
+#include "pw_rpc/internal/method.h"
+
namespace pw::rpc::internal {
const Method* Service::FindMethod(uint32_t method_id) const {