pw_rpc: Support cancelling streaming RPCs

Introduce new PacketType::CANCEL. Clients send CANCEL to cancel an
ongoing streaming RPC. Servers send CANCEL packets to indicate or
confirm that an RPC was cancelled.

Change-Id: I72e7c5c4859f0c4262565a8c963946e00ec4dd37
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/12462
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn
index 62fbe3c..60a9c5e 100644
--- a/pw_rpc/BUILD.gn
+++ b/pw_rpc/BUILD.gn
@@ -146,9 +146,8 @@
     ":channel_test",
     ":packet_test",
     ":server_test",
-    "nanopb:codegen_test",
-    "nanopb:method_test",
   ]
+  group_deps = [ "nanopb:tests" ]
 }
 
 # RPC server for tests only. A mock method implementation is used.
diff --git a/pw_rpc/base_server_writer.cc b/pw_rpc/base_server_writer.cc
index 31da1d2..67abc6f 100644
--- a/pw_rpc/base_server_writer.cc
+++ b/pw_rpc/base_server_writer.cc
@@ -43,17 +43,23 @@
   return *this;
 }
 
+uint32_t BaseServerWriter::method_id() const { return call_.method().id(); }
+
 void BaseServerWriter::Finish() {
   if (!open()) {
     return;
   }
 
   call_.server().RemoveWriter(*this);
-
-  // TODO(hepler): Send a control packet indicating that the stream has
-  // terminated.
-
   state_ = kClosed;
+
+  // Send a control packet indicating that the stream has terminated.
+  auto response = call_.channel().AcquireBuffer();
+  call_.channel().Send(response,
+                       Packet(PacketType::CANCEL,
+                              call_.channel().id(),
+                              call_.service().id(),
+                              method().id()));
 }
 
 std::span<std::byte> BaseServerWriter::AcquirePayloadBuffer() {
@@ -62,7 +68,7 @@
   }
 
   response_ = call_.channel().AcquireBuffer();
-  return response_.payload(packet());
+  return response_.payload(RpcPacket());
 }
 
 Status BaseServerWriter::ReleasePayloadBuffer(
@@ -70,10 +76,10 @@
   if (!open()) {
     return Status::FAILED_PRECONDITION;
   }
-  return call_.channel().Send(response_, packet(payload));
+  return call_.channel().Send(response_, RpcPacket(payload));
 }
 
-Packet BaseServerWriter::packet(std::span<const std::byte> payload) const {
+Packet BaseServerWriter::RpcPacket(std::span<const std::byte> payload) const {
   return Packet(PacketType::RPC,
                 call_.channel().id(),
                 call_.service().id(),
diff --git a/pw_rpc/base_server_writer_test.cc b/pw_rpc/base_server_writer_test.cc
index b10620e..04d37db 100644
--- a/pw_rpc/base_server_writer_test.cc
+++ b/pw_rpc/base_server_writer_test.cc
@@ -106,6 +106,21 @@
   EXPECT_TRUE(writers.empty());
 }
 
+TEST(ServerWriter, Finish_SendsCancellationPacket) {
+  ServerContextForTest<TestService> context;
+  FakeServerWriter writer(context.get());
+
+  writer.Finish();
+
+  Packet packet = Packet::FromBuffer(context.output().sent_packet());
+  EXPECT_EQ(packet.type(), PacketType::CANCEL);
+  EXPECT_EQ(packet.channel_id(), context.kChannelId);
+  EXPECT_EQ(packet.service_id(), context.kServiceId);
+  EXPECT_EQ(packet.method_id(), context.get().method().id());
+  EXPECT_TRUE(packet.payload().empty());
+  EXPECT_EQ(packet.status(), Status::OK);
+}
+
 TEST(ServerWriter, Close) {
   ServerContextForTest<TestService> context;
   FakeServerWriter writer(context.get());
diff --git a/pw_rpc/nanopb/BUILD.gn b/pw_rpc/nanopb/BUILD.gn
index f5b4689..75ed96f 100644
--- a/pw_rpc/nanopb/BUILD.gn
+++ b/pw_rpc/nanopb/BUILD.gn
@@ -36,14 +36,11 @@
   visibility = [ "../*" ]
 }
 
-pw_test("method_test") {
-  deps = [
-    "..:nanopb_server",
-    "..:test_protos_nanopb",
-    "..:test_utils_nanopb_server",
+pw_test_group("tests") {
+  tests = [
+    ":codegen_test",
+    ":method_test",
   ]
-  sources = [ "method_test.cc" ]
-  enable_if = dir_pw_third_party_nanopb != ""
 }
 
 pw_test("codegen_test") {
@@ -54,3 +51,13 @@
   sources = [ "codegen_test.cc" ]
   enable_if = dir_pw_third_party_nanopb != ""
 }
+
+pw_test("method_test") {
+  deps = [
+    "..:nanopb_server",
+    "..:test_protos_nanopb",
+    "..:test_utils_nanopb_server",
+  ]
+  sources = [ "method_test.cc" ]
+  enable_if = dir_pw_third_party_nanopb != ""
+}
diff --git a/pw_rpc/public/pw_rpc/channel.h b/pw_rpc/public/pw_rpc/channel.h
index ecae939..1b95082 100644
--- a/pw_rpc/public/pw_rpc/channel.h
+++ b/pw_rpc/public/pw_rpc/channel.h
@@ -20,11 +20,6 @@
 #include "pw_status/status.h"
 
 namespace pw::rpc {
-namespace internal {
-
-class BaseServerWriter;
-
-}  // namespace internal
 
 class ChannelOutput {
  public:
diff --git a/pw_rpc/public/pw_rpc/internal/base_server_writer.h b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
index 47c6858..745b615 100644
--- a/pw_rpc/public/pw_rpc/internal/base_server_writer.h
+++ b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
@@ -20,6 +20,7 @@
 #include "pw_containers/intrusive_list.h"
 #include "pw_rpc/internal/call.h"
 #include "pw_rpc/internal/channel.h"
+#include "pw_rpc/internal/service.h"
 
 namespace pw::rpc::internal {
 
@@ -46,6 +47,10 @@
   // True if the ServerWriter is active and ready to send responses.
   bool open() const { return state_ == kOpen; }
 
+  uint32_t channel_id() const { return call_.channel().id(); }
+  uint32_t service_id() const { return call_.service().id(); }
+  uint32_t method_id() const;
+
   // Closes the ServerWriter, if it is open.
   void Finish();
 
@@ -59,7 +64,7 @@
   Status ReleasePayloadBuffer(std::span<const std::byte> payload);
 
  private:
-  Packet packet(std::span<const std::byte> payload = {}) const;
+  Packet RpcPacket(std::span<const std::byte> payload = {}) const;
 
   ServerCall call_;
   Channel::OutputBuffer response_;
diff --git a/pw_rpc/public/pw_rpc/internal/service.h b/pw_rpc/public/pw_rpc/internal/service.h
index 9a82b57..957de33 100644
--- a/pw_rpc/public/pw_rpc/internal/service.h
+++ b/pw_rpc/public/pw_rpc/internal/service.h
@@ -18,10 +18,11 @@
 #include <utility>
 
 #include "pw_containers/intrusive_list.h"
-#include "pw_rpc/internal/method.h"
 
 namespace pw::rpc::internal {
 
+class Method;
+
 // Base class for all RPC services. This cannot be instantiated directly; use a
 // generated subclass instead.
 class Service : public IntrusiveList<Service>::Item {
diff --git a/pw_rpc/public/pw_rpc/server.h b/pw_rpc/public/pw_rpc/server.h
index 5b439a1..9cd5695 100644
--- a/pw_rpc/public/pw_rpc/server.h
+++ b/pw_rpc/public/pw_rpc/server.h
@@ -48,6 +48,11 @@
   IntrusiveList<internal::BaseServerWriter>& writers() { return writers_; }
 
  private:
+  void HandleRpcPacket(const internal::Packet& request,
+                       internal::Channel& channel);
+
+  void HandleCancelPacket(const internal::Packet& request);
+
   void InvokeMethod(const internal::Packet& request,
                     Channel& channel,
                     internal::Packet& response,
diff --git a/pw_rpc/pw_rpc_protos/packet.proto b/pw_rpc/pw_rpc_protos/packet.proto
index fa04692..cd03194 100644
--- a/pw_rpc/pw_rpc_protos/packet.proto
+++ b/pw_rpc/pw_rpc_protos/packet.proto
@@ -15,7 +15,16 @@
 
 package pw.rpc.internal;
 
-enum PacketType { RPC = 0; };
+enum PacketType {
+  // RPC packets correspond with an RPC request or response.
+  RPC = 0;
+
+  // CANCEL packets indicate cancellation of an ongoing streaming RPC. They are
+  // sent by the client to request cancellation and by the server to signal that
+  // an RPC was cancelled. The server sends a CANCEL packet when an RPC is
+  // cancelled server-side and in response to a client's request.
+  CANCEL = 1;
+}
 
 message RpcPacket {
   // The type of packet. Either a general RPC packet or a specific control
@@ -37,4 +46,4 @@
 
   // RPC response status code.
   uint32 status = 6;
-};
+}
diff --git a/pw_rpc/py/pw_rpc/codegen_nanopb.py b/pw_rpc/py/pw_rpc/codegen_nanopb.py
index f9f3f05..3ea7d4d 100644
--- a/pw_rpc/py/pw_rpc/codegen_nanopb.py
+++ b/pw_rpc/py/pw_rpc/codegen_nanopb.py
@@ -164,6 +164,8 @@
     output.write_line('#pragma once\n')
     output.write_line('#include <cstddef>')
     output.write_line('#include <cstdint>\n')
+    output.write_line('#include "pw_rpc/server_context.h"')
+    output.write_line('#include "pw_rpc/internal/method.h"')
     output.write_line('#include "pw_rpc/internal/service.h"')
 
     # Include the corresponding nanopb header file for this proto file, in which
diff --git a/pw_rpc/server.cc b/pw_rpc/server.cc
index 20afc77..7c5866b 100644
--- a/pw_rpc/server.cc
+++ b/pw_rpc/server.cc
@@ -39,22 +39,16 @@
 
 void Server::ProcessPacket(std::span<const byte> data,
                            ChannelOutput& interface) {
+  // TODO(hepler): Update the packet parsing code to report when decoding fails.
   Packet packet = Packet::FromBuffer(data);
-  if (packet.is_control()) {
-    // TODO(frolv): Handle control packets.
-    return;
-  }
 
   if (packet.channel_id() == Channel::kUnassignedChannelId ||
       packet.service_id() == 0 || packet.method_id() == 0) {
     // Malformed packet; don't even try to process it.
-    PW_LOG_ERROR("Received incomplete RPC packet on interface %s",
-                 interface.name());
+    PW_LOG_WARN("Received incomplete packet on interface %s", interface.name());
     return;
   }
 
-  Packet response(PacketType::RPC);
-
   internal::Channel* channel = FindChannel(packet.channel_id());
   if (channel == nullptr) {
     // If the requested channel doesn't exist, try to dynamically assign one.
@@ -64,6 +58,10 @@
       // the server cannot process the request. The channel_id in the response
       // is not set, to allow clients to detect this error case.
       internal::Channel temp_channel(packet.channel_id(), &interface);
+
+      // TODO(hepler): Add a new PacketType for errors like this, rather than
+      // using PacketType::RPC.
+      Packet response(PacketType::RPC);
       response.set_status(Status::RESOURCE_EXHAUSTED);
       auto response_buffer = temp_channel.AcquireBuffer();
       temp_channel.Send(response_buffer, response);
@@ -71,12 +69,40 @@
     }
   }
 
-  response.set_channel_id(channel->id());
-  auto response_buffer = channel->AcquireBuffer();
+  switch (packet.type()) {
+    case PacketType::RPC:
+      HandleRpcPacket(packet, *channel);
+      break;
+    case PacketType::CANCEL:
+      HandleCancelPacket(packet);
+      break;
+  }
+}
+
+void Server::HandleRpcPacket(const internal::Packet& request,
+                             internal::Channel& channel) {
+  Packet response(PacketType::RPC);
+
+  response.set_channel_id(channel.id());
+  auto response_buffer = channel.AcquireBuffer();
 
   // Invoke the method with matching service and method IDs, if any.
-  InvokeMethod(packet, *channel, response, response_buffer.payload(response));
-  channel->Send(response_buffer, response);
+  InvokeMethod(request, channel, response, response_buffer.payload(response));
+  channel.Send(response_buffer, response);
+}
+
+void Server::HandleCancelPacket(const internal::Packet& request) {
+  auto writer = std::find_if(writers_.begin(), writers_.end(), [&](auto& w) {
+    return w.channel_id() == request.channel_id() &&
+           w.service_id() == request.service_id() &&
+           w.method_id() == request.method_id();
+  });
+
+  if (writer == writers_.end()) {
+    PW_LOG_WARN("Received CANCEL packet for unknown method");
+  } else {
+    writer->Finish();
+  }
 }
 
 void Server::InvokeMethod(const Packet& request,
diff --git a/pw_rpc/server_test.cc b/pw_rpc/server_test.cc
index 7082772..47b62e5 100644
--- a/pw_rpc/server_test.cc
+++ b/pw_rpc/server_test.cc
@@ -71,11 +71,6 @@
     server_.RegisterService(service_);
   }
 
-  TestOutput<128> output_;
-  std::array<Channel, 3> channels_;
-  Server server_;
-  TestService service_;
-
   std::span<const byte> EncodeRequest(
       PacketType type,
       uint32_t channel_id,
@@ -88,6 +83,11 @@
     return std::span(request_buffer_, sws.size());
   }
 
+  TestOutput<128> output_;
+  std::array<Channel, 3> channels_;
+  Server server_;
+  TestService service_;
+
  private:
   byte request_buffer_[64];
 };
@@ -108,6 +108,7 @@
   server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 42, 100), output_);
 
   Packet packet = Packet::FromBuffer(output_.sent_packet());
+  EXPECT_EQ(packet.type(), PacketType::RPC);
   EXPECT_EQ(packet.channel_id(), 1u);
   EXPECT_EQ(packet.service_id(), 42u);
   EXPECT_EQ(packet.method_id(), 100u);
@@ -131,6 +132,15 @@
   EXPECT_EQ(std::memcmp(packet.payload().data(), resp, sizeof(resp)), 0);
 }
 
+TEST_F(BasicServer, ProcessPacket_IncompletePacket_NothingIsInvoked) {
+  server_.ProcessPacket(EncodeRequest(PacketType::RPC, 0, 42, 101), output_);
+  server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 0, 101), output_);
+  server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 42, 0), output_);
+
+  EXPECT_EQ(0u, service_.method(100).last_channel_id());
+  EXPECT_EQ(0u, service_.method(200).last_channel_id());
+}
+
 TEST_F(BasicServer, ProcessPacket_InvalidMethod_NothingIsInvoked) {
   server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 42, 101), output_);
 
@@ -142,6 +152,7 @@
   server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 42, 27), output_);
 
   Packet packet = Packet::FromBuffer(output_.sent_packet());
+  EXPECT_EQ(packet.type(), PacketType::RPC);
   EXPECT_EQ(packet.channel_id(), 1u);
   EXPECT_EQ(packet.service_id(), 42u);
   EXPECT_EQ(packet.method_id(), 0u);  // No method ID 27
@@ -185,5 +196,44 @@
   EXPECT_EQ(packet.method_id(), 0u);
 }
 
+TEST_F(BasicServer, ProcessPacket_Cancel_ClosesServerWriter) {
+  // Set up a fake ServerWriter representing an ongoing RPC.
+  internal::ServerCall call(static_cast<internal::Server&>(server_),
+                            static_cast<internal::Channel&>(channels_[0]),
+                            service_,
+                            service_.method(100));
+  internal::BaseServerWriter writer(call);
+  ASSERT_TRUE(writer.open());
+
+  server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 1, 42, 100), output_);
+
+  EXPECT_FALSE(writer.open());
+
+  Packet packet = Packet::FromBuffer(output_.sent_packet());
+  EXPECT_EQ(packet.type(), PacketType::CANCEL);
+  EXPECT_EQ(packet.channel_id(), 1u);
+  EXPECT_EQ(packet.service_id(), 42u);
+  EXPECT_EQ(packet.method_id(), 100u);
+  EXPECT_TRUE(packet.payload().empty());
+  EXPECT_EQ(packet.status(), Status::OK);
+}
+
+TEST_F(BasicServer, ProcessPacket_Cancel_UnknownIdIsIgnored) {
+  internal::ServerCall call(static_cast<internal::Server&>(server_),
+                            static_cast<internal::Channel&>(channels_[0]),
+                            service_,
+                            service_.method(100));
+  internal::BaseServerWriter writer(call);
+  ASSERT_TRUE(writer.open());
+
+  // Send packets with incorrect channel, service, and method ID.
+  server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 2, 42, 100), output_);
+  server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 1, 43, 100), output_);
+  server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 1, 42, 101), output_);
+
+  EXPECT_TRUE(writer.open());
+  EXPECT_TRUE(output_.sent_packet().empty());
+}
+
 }  // namespace
 }  // namespace pw::rpc
diff --git a/pw_rpc/service.cc b/pw_rpc/service.cc
index b0ddd23..3107c92 100644
--- a/pw_rpc/service.cc
+++ b/pw_rpc/service.cc
@@ -16,6 +16,8 @@
 
 #include <type_traits>
 
+#include "pw_rpc/internal/method.h"
+
 namespace pw::rpc::internal {
 
 const Method* Service::FindMethod(uint32_t method_id) const {