pw_rpc: Return BaseServerWriter::Finish() status
- Return the status of the Channel::Send() call in
BaseServerWriter::Finish().
- Assert if BaseServerWriter::AcquirePayloadBuffer() is called on a
closed server writer. This is an internal function and calling it on a
closed writer would be bug.
Change-Id: I8f10a3a400d3c3893057424e3675e8f4522cbda9
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/33126
Reviewed-by: Alexei Frolov <frolv@google.com>
Commit-Queue: Wyatt Hepler <hepler@google.com>
diff --git a/pw_rpc/base_server_writer.cc b/pw_rpc/base_server_writer.cc
index 478c67a..9d4423d 100644
--- a/pw_rpc/base_server_writer.cc
+++ b/pw_rpc/base_server_writer.cc
@@ -14,6 +14,7 @@
#include "pw_rpc/internal/base_server_writer.h"
+#include "pw_assert/assert.h"
#include "pw_rpc/internal/method.h"
#include "pw_rpc/internal/packet.h"
#include "pw_rpc/internal/server.h"
@@ -45,9 +46,9 @@
uint32_t BaseServerWriter::method_id() const { return call_.method().id(); }
-void BaseServerWriter::Finish(Status status) {
+Status BaseServerWriter::Finish(Status status) {
if (!open()) {
- return;
+ return Status::FailedPrecondition();
}
// If the ServerWriter implementer or user forgets to release an acquired
@@ -59,18 +60,16 @@
Close();
// Send a control packet indicating that the stream (and RPC) has terminated.
- call_.channel().Send(Packet(PacketType::SERVER_STREAM_END,
- call_.channel().id(),
- call_.service().id(),
- method().id(),
- {},
- status));
+ return call_.channel().Send(Packet(PacketType::SERVER_STREAM_END,
+ call_.channel().id(),
+ call_.service().id(),
+ method().id(),
+ {},
+ status));
}
std::span<std::byte> BaseServerWriter::AcquirePayloadBuffer() {
- if (!open()) {
- return {};
- }
+ PW_DCHECK(open());
// Only allow having one active buffer at a time.
if (response_.empty()) {
@@ -82,17 +81,12 @@
Status BaseServerWriter::ReleasePayloadBuffer(
std::span<const std::byte> payload) {
- if (!open()) {
- return Status::FailedPrecondition();
- }
+ PW_DCHECK(open());
return call_.channel().Send(response_, ResponsePacket(payload));
}
Status BaseServerWriter::ReleasePayloadBuffer() {
- if (!open()) {
- return Status::FailedPrecondition();
- }
-
+ PW_DCHECK(open());
call_.channel().Release(response_);
return OkStatus();
}
diff --git a/pw_rpc/base_server_writer_test.cc b/pw_rpc/base_server_writer_test.cc
index 4314cff..06632ae 100644
--- a/pw_rpc/base_server_writer_test.cc
+++ b/pw_rpc/base_server_writer_test.cc
@@ -104,7 +104,7 @@
ServerContextForTest<TestService> context(TestService::method.method());
FakeServerWriter writer(context.get());
- writer.Finish();
+ EXPECT_EQ(OkStatus(), writer.Finish());
auto& writers = context.server().writers();
EXPECT_TRUE(writers.empty());
@@ -114,7 +114,7 @@
ServerContextForTest<TestService> context(TestService::method.method());
FakeServerWriter writer(context.get());
- writer.Finish();
+ EXPECT_EQ(OkStatus(), writer.Finish());
const Packet& packet = context.output().sent_packet();
EXPECT_EQ(packet.type(), PacketType::SERVER_STREAM_END);
@@ -125,13 +125,22 @@
EXPECT_EQ(packet.status(), OkStatus());
}
+TEST(ServerWriter, Finish_ReturnsStatusFromChannelSend) {
+ ServerContextForTest<TestService> context(TestService::method.method());
+ FakeServerWriter writer(context.get());
+ context.output().set_send_status(Status::Unauthenticated());
+
+ EXPECT_EQ(Status::Unauthenticated(), writer.Finish());
+}
+
TEST(ServerWriter, Close) {
ServerContextForTest<TestService> context(TestService::method.method());
FakeServerWriter writer(context.get());
ASSERT_TRUE(writer.open());
- writer.Finish();
+ EXPECT_EQ(OkStatus(), writer.Finish());
EXPECT_FALSE(writer.open());
+ EXPECT_EQ(Status::FailedPrecondition(), writer.Finish());
}
TEST(ServerWriter, Close_ReleasesBuffer) {
@@ -142,7 +151,7 @@
auto buffer = writer.PayloadBuffer();
buffer[0] = std::byte{0};
EXPECT_FALSE(writer.output_buffer().empty());
- writer.Finish();
+ EXPECT_EQ(OkStatus(), writer.Finish());
EXPECT_FALSE(writer.open());
EXPECT_TRUE(writer.output_buffer().empty());
}
@@ -165,14 +174,12 @@
encoded, context.output().sent_data().data(), result.value().size()));
}
-TEST(ServerWriter, Closed_IgnoresPacket) {
+TEST(ServerWriter, Closed_IgnoresFinish) {
ServerContextForTest<TestService> context(TestService::method.method());
FakeServerWriter writer(context.get());
- writer.Finish();
-
- constexpr byte data[] = {byte{0xf0}, byte{0x0d}};
- EXPECT_EQ(Status::FailedPrecondition(), writer.Write(data));
+ EXPECT_EQ(OkStatus(), writer.Finish());
+ EXPECT_EQ(Status::FailedPrecondition(), writer.Finish());
}
} // namespace
diff --git a/pw_rpc/nanopb/nanopb_method_test.cc b/pw_rpc/nanopb/nanopb_method_test.cc
index e0e4a76..a9055bb 100644
--- a/pw_rpc/nanopb/nanopb_method_test.cc
+++ b/pw_rpc/nanopb/nanopb_method_test.cc
@@ -206,10 +206,26 @@
method.Invoke(context.get(), context.packet({}));
- last_writer.Finish();
+ EXPECT_EQ(OkStatus(), last_writer.Finish());
EXPECT_TRUE(last_writer.Write({.value = 100}).IsFailedPrecondition());
}
+TEST(NanopbMethod, ServerWriter_WriteAfterMoved_ReturnsFailedPrecondition) {
+ const NanopbMethod& method =
+ std::get<2>(FakeService::kMethods).nanopb_method();
+ ServerContextForTest<FakeService> context(method);
+
+ method.Invoke(context.get(), context.packet({}));
+ ServerWriter<pw_rpc_test_TestResponse> new_writer = std::move(last_writer);
+
+ EXPECT_EQ(OkStatus(), new_writer.Write({.value = 100}));
+
+ EXPECT_EQ(Status::FailedPrecondition(), last_writer.Write({.value = 100}));
+ EXPECT_EQ(Status::FailedPrecondition(), last_writer.Finish());
+
+ EXPECT_EQ(OkStatus(), new_writer.Finish());
+}
+
TEST(NanopbMethod,
ServerStreamingRpc_ServerWriterBufferTooSmall_InternalError) {
const NanopbMethod& method =
diff --git a/pw_rpc/nanopb/nanopb_method_union_test.cc b/pw_rpc/nanopb/nanopb_method_union_test.cc
index 5df140d..ab754e4 100644
--- a/pw_rpc/nanopb/nanopb_method_union_test.cc
+++ b/pw_rpc/nanopb/nanopb_method_union_test.cc
@@ -99,7 +99,7 @@
method.Invoke(context.get(), context.packet(request));
EXPECT_TRUE(last_raw_writer.open());
- last_raw_writer.Finish();
+ EXPECT_EQ(OkStatus(), last_raw_writer.Finish());
EXPECT_EQ(context.output().sent_packet().type(),
PacketType::SERVER_STREAM_END);
}
@@ -141,7 +141,7 @@
EXPECT_EQ(555, last_request.integer);
EXPECT_TRUE(last_writer.open());
- last_writer.Finish();
+ EXPECT_EQ(OkStatus(), last_writer.Finish());
EXPECT_EQ(context.output().sent_packet().type(),
PacketType::SERVER_STREAM_END);
}
diff --git a/pw_rpc/public/pw_rpc/internal/base_server_writer.h b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
index 67bf7aa..895c1b8 100644
--- a/pw_rpc/public/pw_rpc/internal/base_server_writer.h
+++ b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
@@ -59,7 +59,7 @@
uint32_t method_id() const;
// Closes the ServerWriter, if it is open.
- void Finish(Status status = OkStatus());
+ Status Finish(Status status = OkStatus());
protected:
constexpr BaseServerWriter() : state_{kClosed} {}
@@ -70,9 +70,12 @@
constexpr const Channel::OutputBuffer& buffer() const { return response_; }
+ // Acquires a buffer into which to write a payload. The BaseServerWriter MUST
+ // be open when this is called!
std::span<std::byte> AcquirePayloadBuffer();
- // Releases the buffer, sending a packet with the specified payload.
+ // Releases the buffer, sending a packet with the specified payload. The
+ // BaseServerWriter MUST be open when this is called!
Status ReleasePayloadBuffer(std::span<const std::byte> payload);
// Releases the buffer without sending a packet.
diff --git a/pw_rpc/pw_rpc_private/internal_test_utils.h b/pw_rpc/pw_rpc_private/internal_test_utils.h
index f9b227d..ff7290f 100644
--- a/pw_rpc/pw_rpc_private/internal_test_utils.h
+++ b/pw_rpc/pw_rpc_private/internal_test_utils.h
@@ -112,7 +112,7 @@
}
internal::ServerCall& get() { return context_; }
- const auto& output() const { return output_; }
+ auto& output() { return output_; }
TestServer& server() { return static_cast<TestServer&>(server_); }
private:
diff --git a/pw_rpc/raw/raw_method_test.cc b/pw_rpc/raw/raw_method_test.cc
index c72dcfe..f3f0d4e 100644
--- a/pw_rpc/raw/raw_method_test.cc
+++ b/pw_rpc/raw/raw_method_test.cc
@@ -72,6 +72,7 @@
int64_t integer;
uint32_t status_code;
} last_request;
+
RawServerWriter last_writer;
void DecodeRawTestRequest(ConstByteSpan request) {
@@ -163,7 +164,7 @@
EXPECT_EQ(777, last_request.integer);
EXPECT_EQ(2u, last_request.status_code);
EXPECT_TRUE(last_writer.open());
- last_writer.Finish();
+ EXPECT_EQ(OkStatus(), last_writer.Finish());
}
TEST(RawServerWriter, Write_SendsPreviouslyAcquiredBuffer) {
@@ -208,11 +209,11 @@
TEST(RawServerWriter, Write_Closed_ReturnsFailedPrecondition) {
const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method();
- ServerContextForTest<FakeService, 16> context(method);
+ ServerContextForTest<FakeService> context(method);
method.Invoke(context.get(), context.packet({}));
- last_writer.Finish();
+ EXPECT_EQ(OkStatus(), last_writer.Finish());
constexpr auto data = bytes::Array<0x0d, 0x06, 0xf0, 0x0d>();
EXPECT_EQ(last_writer.Write(data), Status::FailedPrecondition());
}
diff --git a/pw_rpc/raw/raw_method_union_test.cc b/pw_rpc/raw/raw_method_union_test.cc
index 0a67d65..f048366 100644
--- a/pw_rpc/raw/raw_method_union_test.cc
+++ b/pw_rpc/raw/raw_method_union_test.cc
@@ -139,7 +139,7 @@
EXPECT_EQ(777, last_request.integer);
EXPECT_EQ(2u, last_request.status_code);
EXPECT_TRUE(last_writer.open());
- last_writer.Finish();
+ EXPECT_EQ(OkStatus(), last_writer.Finish());
}
} // namespace