pw_rpc: BaseServerWriter class
BaseServerWriter handles streaming responses from a server. The
pw_rpc server implementation must extend BaseServerWriter to provide a
user-facing method for actually sending responses.
Change-Id: I80e73ff5847ff5c843e496ada08245db8049afb5
diff --git a/pw_rpc/BUILD b/pw_rpc/BUILD
index fdd6f81..14f09e3 100644
--- a/pw_rpc/BUILD
+++ b/pw_rpc/BUILD
@@ -25,6 +25,8 @@
pw_cc_library(
name = "pw_rpc",
srcs = [
+ "base_server_writer.cc",
+ "public/pw_rpc/internal/base_server_writer.h",
"public/pw_rpc/internal/service.h",
"public/pw_rpc/internal/service_registry.h",
"server.cc",
@@ -72,6 +74,27 @@
],
)
+pw_cc_library(
+ name = "test_utils",
+ hdrs = ["pw_rpc_private/test_utils.h"],
+ visibility = ["//visibility:private"],
+ deps = [
+ ":common",
+ "//pw_span",
+ ],
+)
+
+pw_cc_test(
+ name = "base_server_writer_test",
+ srcs = [
+ "base_server_writer_test.cc",
+ ],
+ deps = [
+ ":pw_rpc",
+ ":test_utils",
+ ],
+)
+
pw_cc_test(
name = "packet_test",
srcs = [
@@ -97,6 +120,7 @@
],
deps = [
":pw_rpc",
+ ":test_utils",
"//pw_assert",
],
)
diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn
index 37aeba1..99ef678 100644
--- a/pw_rpc/BUILD.gn
+++ b/pw_rpc/BUILD.gn
@@ -52,6 +52,8 @@
"public/pw_rpc/server_context.h",
]
sources = [
+ "base_server_writer.cc",
+ "public/pw_rpc/internal/base_server_writer.h",
"public/pw_rpc/internal/service.h",
"public/pw_rpc/internal/service_registry.h",
"server.cc",
@@ -83,6 +85,15 @@
visibility = [ ":*" ]
}
+source_set("test_utils") {
+ public = [ "pw_rpc_private/test_utils.h" ]
+ public_deps = [
+ ":common",
+ dir_pw_span,
+ ]
+ visibility = [ ":*" ]
+}
+
pw_proto_library("protos") {
sources = [ "pw_rpc_protos/packet.proto" ]
}
@@ -93,6 +104,7 @@
pw_test_group("tests") {
tests = [
+ ":base_server_writer_test",
":packet_test",
":server_test",
]
@@ -116,6 +128,14 @@
visibility = [ ":test_server" ]
}
+pw_test("base_server_writer_test") {
+ deps = [
+ ":test_server",
+ ":test_utils",
+ ]
+ sources = [ "base_server_writer_test.cc" ]
+}
+
pw_test("packet_test") {
deps = [
":common",
@@ -128,6 +148,7 @@
deps = [
":protos_pwpb",
":test_server",
+ ":test_utils",
dir_pw_assert,
]
sources = [ "server_test.cc" ]
diff --git a/pw_rpc/base_server_writer.cc b/pw_rpc/base_server_writer.cc
new file mode 100644
index 0000000..49b9a37
--- /dev/null
+++ b/pw_rpc/base_server_writer.cc
@@ -0,0 +1,80 @@
+// Copyright 2020 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+
+#include "pw_rpc/internal/base_server_writer.h"
+
+#include "pw_assert/assert.h"
+#include "pw_rpc/internal/method.h"
+#include "pw_rpc/internal/packet.h"
+#include "pw_rpc/server.h"
+
+namespace pw::rpc::internal {
+
+BaseServerWriter& BaseServerWriter::operator=(BaseServerWriter&& other) {
+ context_ = std::move(other.context_);
+ response_ = std::move(other.response_);
+ state_ = std::move(other.state_);
+ other.state_ = kClosed;
+ return *this;
+}
+
+void BaseServerWriter::close() {
+ if (open()) {
+ // TODO(hepler): Send a control packet indicating that the stream has
+ // terminated, and remove this ServerWriter from the Server's list.
+
+ state_ = kClosed;
+ }
+}
+
+span<std::byte> BaseServerWriter::AcquireBuffer() {
+ if (!open()) {
+ return {};
+ }
+
+ PW_DCHECK(response_.empty());
+ response_ = context_.channel_->AcquireBuffer();
+
+ // Reserve space for the RPC packet header.
+ return packet().PayloadUsableSpace(response_);
+}
+
+Status BaseServerWriter::SendAndReleaseBuffer(span<const std::byte> payload) {
+ if (!open()) {
+ return Status::FAILED_PRECONDITION;
+ }
+
+ Packet response_packet = packet();
+ response_packet.set_payload(payload);
+ StatusWithSize encoded = response_packet.Encode(response_);
+ response_ = {};
+
+ if (!encoded.ok()) {
+ context_.channel_->SendAndReleaseBuffer(0);
+ return Status::INTERNAL;
+ }
+
+ // TODO(hepler): Should Channel::SendAndReleaseBuffer return Status?
+ context_.channel_->SendAndReleaseBuffer(encoded.size());
+ return Status::OK;
+}
+
+Packet BaseServerWriter::packet() const {
+ return Packet(PacketType::RPC,
+ context_.channel_id(),
+ context_.service_->id(),
+ method().id());
+}
+
+} // namespace pw::rpc::internal
diff --git a/pw_rpc/base_server_writer_test.cc b/pw_rpc/base_server_writer_test.cc
new file mode 100644
index 0000000..6789695
--- /dev/null
+++ b/pw_rpc/base_server_writer_test.cc
@@ -0,0 +1,126 @@
+// Copyright 2020 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+
+#include "pw_rpc/internal/base_server_writer.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+
+#include "gtest/gtest.h"
+#include "pw_rpc/internal/packet.h"
+#include "pw_rpc/internal/service.h"
+#include "pw_rpc/server_context.h"
+#include "pw_rpc_private/test_utils.h"
+
+namespace pw::rpc {
+
+class TestService : public internal::Service {
+ public:
+ constexpr TestService(uint32_t id)
+ : Service(id, span(&method, 1)), method(8) {}
+
+ internal::Method method;
+};
+
+namespace internal {
+namespace {
+
+using std::byte;
+
+TEST(BaseServerWriter, ConstructWithContext_StartsOpen) {
+ ServerContextForTest<TestService> context;
+
+ BaseServerWriter writer(context.get());
+
+ EXPECT_TRUE(writer.open());
+}
+
+TEST(BaseServerWriter, Move_ClosesOriginal) {
+ ServerContextForTest<TestService> context;
+
+ BaseServerWriter moved(context.get());
+ BaseServerWriter writer(std::move(moved));
+
+ EXPECT_FALSE(moved.open());
+ EXPECT_TRUE(writer.open());
+}
+
+class FakeServerWriter : public BaseServerWriter {
+ public:
+ constexpr FakeServerWriter(ServerContext& context)
+ : BaseServerWriter(context) {}
+
+ constexpr FakeServerWriter() = default;
+
+ Status Write(span<const byte> response) {
+ span buffer = AcquireBuffer();
+ std::memcpy(buffer.data(),
+ response.data(),
+ std::min(buffer.size(), response.size()));
+ return SendAndReleaseBuffer(buffer.first(response.size()));
+ }
+};
+
+TEST(ServerWriter, DefaultConstruct_Closed) {
+ FakeServerWriter writer;
+
+ EXPECT_FALSE(writer.open());
+}
+
+TEST(ServerWriter, Close) {
+ ServerContextForTest<TestService> context;
+ FakeServerWriter writer(context.get());
+
+ ASSERT_TRUE(writer.open());
+ writer.close();
+ EXPECT_FALSE(writer.open());
+}
+
+TEST(ServerWriter, Open_SendsPacketWithPayload) {
+ ServerContextForTest<TestService> context;
+ FakeServerWriter writer(context.get());
+
+ constexpr byte data[] = {byte{0xf0}, byte{0x0d}};
+ ASSERT_EQ(Status::OK, writer.Write(data));
+
+ byte encoded[64];
+ auto sws = Packet(PacketType::RPC,
+ context.kChannelId,
+ context.kServiceId,
+ 8,
+ data,
+ Status::OK)
+ .Encode(encoded);
+ ASSERT_EQ(Status::OK, sws.status());
+
+ EXPECT_EQ(sws.size(), context.output().sent_packet().size());
+ EXPECT_EQ(
+ 0,
+ std::memcmp(encoded, context.output().sent_packet().data(), sws.size()));
+}
+
+TEST(ServerWriter, Closed_IgnoresPacket) {
+ ServerContextForTest<TestService> context;
+ FakeServerWriter writer(context.get());
+
+ writer.close();
+
+ constexpr byte data[] = {byte{0xf0}, byte{0x0d}};
+ EXPECT_EQ(Status::FAILED_PRECONDITION, writer.Write(data));
+}
+
+} // namespace
+} // namespace internal
+} // namespace pw::rpc
diff --git a/pw_rpc/public/pw_rpc/channel.h b/pw_rpc/public/pw_rpc/channel.h
index 133aa4f..f0f890f 100644
--- a/pw_rpc/public/pw_rpc/channel.h
+++ b/pw_rpc/public/pw_rpc/channel.h
@@ -20,6 +20,11 @@
#include "pw_status/status.h"
namespace pw::rpc {
+namespace internal {
+
+class BaseServerWriter;
+
+} // namespace internal
class ChannelOutput {
public:
@@ -60,6 +65,7 @@
private:
friend class Server;
+ friend class internal::BaseServerWriter;
span<std::byte> AcquireBuffer() const { return output_->AcquireBuffer(); }
void SendAndReleaseBuffer(size_t size) const {
diff --git a/pw_rpc/public/pw_rpc/internal/base_server_writer.h b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
new file mode 100644
index 0000000..f696ed9
--- /dev/null
+++ b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
@@ -0,0 +1,68 @@
+// Copyright 2020 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+#pragma once
+
+#include <cstddef>
+#include <utility>
+
+#include "pw_rpc/channel.h"
+#include "pw_rpc/server_context.h"
+#include "pw_span/span.h"
+
+namespace pw::rpc::internal {
+
+class Method;
+class Packet;
+
+// Internal ServerWriter base class. ServerWriters are used to stream responses.
+// Implementations must provide a derived class that provides the interface for
+// sending responses.
+//
+// TODO(hepler): ServerWriters need to be tracked by the server to support
+// cancelling / terminating ongoing streaming RPCs.
+class BaseServerWriter {
+ public:
+ constexpr BaseServerWriter(ServerContext& context)
+ : context_(context), state_{kOpen} {}
+
+ BaseServerWriter(BaseServerWriter&& other) { *this = std::move(other); }
+ BaseServerWriter& operator=(BaseServerWriter&& other);
+
+ BaseServerWriter(const BaseServerWriter&) = delete;
+ BaseServerWriter& operator=(const BaseServerWriter&) = delete;
+
+ // True if the ServerWriter is active and ready to send responses.
+ bool open() const { return state_ == kOpen; }
+
+ // Closes the ServerWriter, if it is open.
+ void close();
+
+ protected:
+ constexpr BaseServerWriter() : state_{kClosed} {}
+
+ const Method& method() const { return *context_.method_; }
+
+ span<std::byte> AcquireBuffer();
+
+ Status SendAndReleaseBuffer(span<const std::byte> payload);
+
+ private:
+ Packet packet() const;
+
+ ServerContext context_;
+ span<std::byte> response_;
+ enum { kClosed, kOpen } state_;
+};
+
+} // namespace pw::rpc::internal
diff --git a/pw_rpc/public/pw_rpc/server_context.h b/pw_rpc/public/pw_rpc/server_context.h
index 2dcc3e0..cebbf65 100644
--- a/pw_rpc/public/pw_rpc/server_context.h
+++ b/pw_rpc/public/pw_rpc/server_context.h
@@ -13,6 +13,8 @@
// the License.
#pragma once
+#include <cstdint>
+
#include "pw_assert/assert.h"
#include "pw_rpc/channel.h"
@@ -21,15 +23,13 @@
class Method;
class Service;
+class BaseServerWriter;
} // namespace internal
// The ServerContext collects context for an RPC being invoked on a server.
class ServerContext {
public:
- constexpr ServerContext()
- : channel_(nullptr), service_(nullptr), method_(nullptr) {}
-
uint32_t channel_id() const {
PW_DCHECK_NOTNULL(channel_);
return channel_->id();
@@ -37,15 +37,20 @@
private:
friend class Server;
+ friend class internal::BaseServerWriter;
+ template <typename, uint32_t, uint32_t>
+ friend class ServerContextForTest;
+
+ constexpr ServerContext()
+ : channel_(nullptr), service_(nullptr), method_(nullptr) {}
constexpr ServerContext(Channel& channel,
internal::Service& service,
const internal::Method& method)
- : channel_(&channel), service_(&service), method_(&method) {
- // Temporarily silence unused private field warnings.
- (void)service_;
- (void)method_;
- }
+ : channel_(&channel), service_(&service), method_(&method) {}
+
+ constexpr ServerContext(const ServerContext&) = default;
+ constexpr ServerContext& operator=(const ServerContext&) = default;
Channel* channel_;
internal::Service* service_;
diff --git a/pw_rpc/pw_rpc_private/test_utils.h b/pw_rpc/pw_rpc_private/test_utils.h
new file mode 100644
index 0000000..c4b1a66
--- /dev/null
+++ b/pw_rpc/pw_rpc_private/test_utils.h
@@ -0,0 +1,72 @@
+// Copyright 2020 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+#pragma once
+
+#include <cstddef>
+
+#include "pw_rpc/channel.h"
+#include "pw_span/span.h"
+
+namespace pw::rpc {
+
+template <size_t buffer_size>
+class TestOutput : public ChannelOutput {
+ public:
+ constexpr TestOutput(uint32_t id) : ChannelOutput(id), sent_packet_{} {}
+
+ span<std::byte> AcquireBuffer() override { return buffer_; }
+
+ void SendAndReleaseBuffer(size_t size) override {
+ sent_packet_ = {buffer_, size};
+ }
+
+ const span<const std::byte>& sent_packet() const { return sent_packet_; }
+
+ private:
+ std::byte buffer_[buffer_size];
+ span<const std::byte> sent_packet_;
+};
+
+namespace internal {
+
+class Method;
+
+} // namespace internal
+
+template <typename Service, uint32_t channel_id = 99, uint32_t service_id = 16>
+class ServerContextForTest {
+ public:
+ static constexpr uint32_t kChannelId = channel_id;
+ static constexpr uint32_t kServiceId = service_id;
+
+ ServerContextForTest(const internal::Method& method)
+ : output_(32),
+ channel_(Channel::Create<kChannelId>(&output_)),
+ service_(kServiceId),
+ context_(channel_, service_, method) {}
+
+ ServerContextForTest() : ServerContextForTest(service_.method) {}
+
+ ServerContext& get() { return context_; }
+ const auto& output() const { return output_; }
+
+ private:
+ TestOutput<128> output_;
+ Channel channel_;
+ Service service_;
+
+ ServerContext context_;
+};
+
+} // namespace pw::rpc
diff --git a/pw_rpc/server_test.cc b/pw_rpc/server_test.cc
index b242088..389c9ca 100644
--- a/pw_rpc/server_test.cc
+++ b/pw_rpc/server_test.cc
@@ -21,6 +21,7 @@
#include "pw_assert/assert.h"
#include "pw_rpc/internal/packet.h"
#include "pw_rpc/internal/service.h"
+#include "pw_rpc_private/test_utils.h"
namespace pw::rpc {
namespace {
@@ -31,24 +32,6 @@
using internal::Packet;
using internal::PacketType;
-template <size_t buffer_size>
-class TestOutput : public ChannelOutput {
- public:
- constexpr TestOutput(uint32_t id) : ChannelOutput(id), sent_packet_({}) {}
-
- span<byte> AcquireBuffer() override { return buffer_; }
-
- void SendAndReleaseBuffer(size_t size) override {
- sent_packet_ = {buffer_, size};
- }
-
- span<const byte> sent_packet() const { return sent_packet_; }
-
- private:
- byte buffer_[buffer_size];
- span<const byte> sent_packet_;
-};
-
class TestService : public internal::Service {
public:
TestService(uint32_t service_id)