pw_rpc: Track ServerWriters in the Server
Register BaseServerWriters when created and unregister them when
Finish() is called. This will allow control packets to close
ServerWriters.
Change-Id: Ic3817b2da4a33f35144351267d7863cbb2bdb4a0
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/12361
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
diff --git a/pw_rpc/base_server_writer.cc b/pw_rpc/base_server_writer.cc
index bfe3fe7..4fa1f1f 100644
--- a/pw_rpc/base_server_writer.cc
+++ b/pw_rpc/base_server_writer.cc
@@ -16,10 +16,15 @@
#include "pw_rpc/internal/method.h"
#include "pw_rpc/internal/packet.h"
-#include "pw_rpc/server.h"
+#include "pw_rpc/internal/server.h"
namespace pw::rpc::internal {
+BaseServerWriter::BaseServerWriter(ServerCall& call)
+ : call_(call), state_(kOpen) {
+ call_.server().RegisterWriter(*this);
+}
+
BaseServerWriter& BaseServerWriter::operator=(BaseServerWriter&& other) {
call_ = std::move(other.call_);
response_ = std::move(other.response_);
@@ -34,6 +39,8 @@
return;
}
+ call_.server().RemoveWriter(*this);
+
// TODO(hepler): Send a control packet indicating that the stream has
// terminated.
diff --git a/pw_rpc/base_server_writer_test.cc b/pw_rpc/base_server_writer_test.cc
index c7b0480..8c1b914 100644
--- a/pw_rpc/base_server_writer_test.cc
+++ b/pw_rpc/base_server_writer_test.cc
@@ -58,7 +58,7 @@
class FakeServerWriter : public BaseServerWriter {
public:
- constexpr FakeServerWriter(ServerCall& context) : BaseServerWriter(context) {}
+ FakeServerWriter(ServerCall& context) : BaseServerWriter(context) {}
constexpr FakeServerWriter() = default;
@@ -77,6 +77,35 @@
EXPECT_FALSE(writer.open());
}
+TEST(ServerWriter, Construct_RegistersWithServer) {
+ ServerContextForTest<TestService> context;
+ FakeServerWriter writer(context.get());
+
+ auto& writers = context.server().writers();
+ EXPECT_FALSE(writers.empty());
+ auto it = std::find_if(
+ writers.begin(), writers.end(), [&](auto& w) { return &w == &writer; });
+ ASSERT_NE(it, writers.end());
+}
+
+TEST(ServerWriter, Destruct_RemovesFromServer) {
+ ServerContextForTest<TestService> context;
+ { FakeServerWriter writer(context.get()); }
+
+ auto& writers = context.server().writers();
+ EXPECT_TRUE(writers.empty());
+}
+
+TEST(ServerWriter, Finish_RemovesFromServer) {
+ ServerContextForTest<TestService> context;
+ FakeServerWriter writer(context.get());
+
+ writer.Finish();
+
+ auto& writers = context.server().writers();
+ EXPECT_TRUE(writers.empty());
+}
+
TEST(ServerWriter, Close) {
ServerContextForTest<TestService> context;
FakeServerWriter writer(context.get());
diff --git a/pw_rpc/public/pw_rpc/internal/base_server_writer.h b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
index 30d0883..413293c 100644
--- a/pw_rpc/public/pw_rpc/internal/base_server_writer.h
+++ b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
@@ -16,6 +16,7 @@
#include <cstddef>
#include <utility>
+#include "pw_containers/intrusive_list.h"
#include "pw_rpc/internal/call.h"
#include "pw_rpc/internal/channel.h"
#include "pw_span/span.h"
@@ -28,17 +29,16 @@
// Internal ServerWriter base class. ServerWriters are used to stream responses.
// Implementations must provide a derived class that provides the interface for
// sending responses.
-//
-// TODO(hepler): ServerWriters need to be tracked by the server to support
-// cancelling / terminating ongoing streaming RPCs.
-class BaseServerWriter {
+class BaseServerWriter : public IntrusiveList<BaseServerWriter>::Item {
public:
- constexpr BaseServerWriter(ServerCall& call) : call_(call), state_(kOpen) {}
+ BaseServerWriter(ServerCall& call);
BaseServerWriter(const BaseServerWriter&) = delete;
BaseServerWriter(BaseServerWriter&& other) { *this = std::move(other); }
+ ~BaseServerWriter() { Finish(); }
+
BaseServerWriter& operator=(const BaseServerWriter&) = delete;
BaseServerWriter& operator=(BaseServerWriter&& other);
diff --git a/pw_rpc/public/pw_rpc/internal/server.h b/pw_rpc/public/pw_rpc/internal/server.h
index 53cdee0..e260636 100644
--- a/pw_rpc/public/pw_rpc/internal/server.h
+++ b/pw_rpc/public/pw_rpc/internal/server.h
@@ -20,6 +20,14 @@
class Server : public rpc::Server {
public:
Server() = delete;
+
+ void RegisterWriter(BaseServerWriter& writer) {
+ writers().push_front(writer);
+ }
+
+ void RemoveWriter(const BaseServerWriter& writer) {
+ writers().remove(writer);
+ }
};
} // namespace pw::rpc::internal
diff --git a/pw_rpc/public/pw_rpc/server.h b/pw_rpc/public/pw_rpc/server.h
index 409e93c..1232a75 100644
--- a/pw_rpc/public/pw_rpc/server.h
+++ b/pw_rpc/public/pw_rpc/server.h
@@ -40,6 +40,9 @@
constexpr size_t channel_count() const { return channels_.size(); }
+ protected:
+ IntrusiveList<internal::BaseServerWriter>& writers() { return writers_; }
+
private:
void InvokeMethod(const internal::Packet& request,
Channel& channel,
@@ -51,6 +54,7 @@
span<internal::Channel> channels_;
IntrusiveList<internal::Service> services_;
+ IntrusiveList<internal::BaseServerWriter> writers_;
};
} // namespace pw::rpc
diff --git a/pw_rpc/pw_rpc_private/test_utils.h b/pw_rpc/pw_rpc_private/test_utils.h
index 57767ca..6dcd8a6 100644
--- a/pw_rpc/pw_rpc_private/test_utils.h
+++ b/pw_rpc/pw_rpc_private/test_utils.h
@@ -46,6 +46,12 @@
span<const std::byte> sent_packet_;
};
+// Version of the internal::Server with extra methods exposed for testing.
+class TestServer : public internal::Server {
+ public:
+ using internal::Server::writers;
+};
+
template <typename Service,
size_t output_buffer_size = 128,
uint32_t channel_id = 99,
@@ -80,6 +86,7 @@
internal::ServerCall& get() { return context_; }
const auto& output() const { return output_; }
+ TestServer& server() { return static_cast<TestServer&>(server_); }
private:
TestOutput<output_buffer_size> output_;