pw_rpc: Update protocol for server streams
Update the server-to-client RPC packet types so that the packet type
unambiguously indicates whether it is the first or last packet. This is
already the case for client-to-server RPC packet types.
- Have RESPONSE always be the last packet in the stream. For RPCs
without a server stream, it includes a payload. Remove
SERVER_STREAM_END.
- Introduce the SERVER_STREAM packet, to parallel the CLIENT_STREAM
packet.
- Update the server and client code and tests. Test that old-style
streaming RPCs still work correctly.
- Refactor the duplicate MessageOutput class into a FakeChannelOutput
used by both Nanopb and raw RPCs.
- In C++, don't encode default-valued payload and status fields.
Change-Id: I218772dad6c2981dda5f032f298ea43ee5e08b4d
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/49822
Commit-Queue: Wyatt Hepler <hepler@google.com>
Pigweed-Auto-Submit: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
diff --git a/pw_rpc/BUILD.bazel b/pw_rpc/BUILD.bazel
index 82d9230..6cdf448 100644
--- a/pw_rpc/BUILD.bazel
+++ b/pw_rpc/BUILD.bazel
@@ -44,14 +44,14 @@
pw_cc_library(
name = "server",
srcs = [
- "responder.cc",
- "public/pw_rpc/internal/responder.h",
"public/pw_rpc/internal/call.h",
"public/pw_rpc/internal/hash.h",
"public/pw_rpc/internal/method.h",
"public/pw_rpc/internal/method_lookup.h",
"public/pw_rpc/internal/method_union.h",
+ "public/pw_rpc/internal/responder.h",
"public/pw_rpc/internal/server.h",
+ "responder.cc",
"server.cc",
"service.cc",
],
@@ -115,8 +115,10 @@
pw_cc_library(
name = "internal_test_utils",
+ srcs = ["fake_channel_output.cc"],
hdrs = [
"public/pw_rpc/internal/test_method.h",
+ "pw_rpc_private/fake_channel_output.h",
"pw_rpc_private/internal_test_utils.h",
"pw_rpc_private/method_impl_tester.h",
],
@@ -128,6 +130,8 @@
deps = [
":client",
":server",
+ "//pw_assert",
+ "//pw_bytes",
"//pw_span",
],
)
diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn
index d5f4fd5..25953c4 100644
--- a/pw_rpc/BUILD.gn
+++ b/pw_rpc/BUILD.gn
@@ -127,13 +127,20 @@
pw_source_set("test_utils") {
public = [
"public/pw_rpc/internal/test_method.h",
+ "pw_rpc_private/fake_channel_output.h",
"pw_rpc_private/internal_test_utils.h",
"pw_rpc_private/method_impl_tester.h",
]
- public_configs = [ ":private_includes" ]
+ sources = [ "fake_channel_output.cc" ]
+ public_configs = [
+ ":public_include_path",
+ ":private_includes",
+ ]
public_deps = [
":client",
":server",
+ dir_pw_assert,
+ dir_pw_bytes,
]
visibility = [ "./*" ]
}
diff --git a/pw_rpc/CMakeLists.txt b/pw_rpc/CMakeLists.txt
index 4a34370..c1b9658 100644
--- a/pw_rpc/CMakeLists.txt
+++ b/pw_rpc/CMakeLists.txt
@@ -72,8 +72,16 @@
pw_sync.mutex
)
-add_library(pw_rpc.test_utils INTERFACE)
-target_include_directories(pw_rpc.test_utils INTERFACE .)
+pw_add_module_library(pw_rpc.test_utils
+ SOURCES
+ fake_channel_output.cc
+ PUBLIC_DEPS
+ pw_assert
+ pw_bytes
+ pw_rpc.client
+ pw_rpc.server
+)
+target_include_directories(pw_rpc.test_utils PUBLIC .)
pw_proto_library(pw_rpc.protos
SOURCES
diff --git a/pw_rpc/channel_test.cc b/pw_rpc/channel_test.cc
index 7687a2d..9dc0c20 100644
--- a/pw_rpc/channel_test.cc
+++ b/pw_rpc/channel_test.cc
@@ -37,10 +37,11 @@
EXPECT_EQ(nullptr, NameTester(nullptr).name());
}
-constexpr Packet kTestPacket(PacketType::RESPONSE, 1, 42, 100);
+constexpr Packet kTestPacket(
+ PacketType::RESPONSE, 1, 42, 100, {}, Status::NotFound());
const size_t kReservedSize = 2 /* type */ + 2 /* channel */ + 5 /* service */ +
5 /* method */ + 2 /* payload key */ +
- 2 /* status */;
+ 2 /* status (if not OK) */;
enum class ChannelId {
kOne = 1,
@@ -67,7 +68,7 @@
}
TEST(Channel, OutputBuffer_TooSmall) {
- TestOutput<kReservedSize - 1> output;
+ TestOutput<kReservedSize - 2 /* payload key & size */ - 1> output;
internal::Channel channel(100, &output);
Channel::OutputBuffer output_buffer = channel.AcquireBuffer();
diff --git a/pw_rpc/client.cc b/pw_rpc/client.cc
index b5f3f5d..b49a9a3 100644
--- a/pw_rpc/client.cc
+++ b/pw_rpc/client.cc
@@ -70,10 +70,10 @@
case PacketType::RESPONSE:
case PacketType::SERVER_ERROR:
call->HandleResponse(packet);
+ call->Unregister();
break;
- case PacketType::SERVER_STREAM_END:
+ case PacketType::SERVER_STREAM:
call->HandleResponse(packet);
- RemoveCall(*call);
break;
default:
return Status::Unimplemented();
diff --git a/pw_rpc/docs.rst b/pw_rpc/docs.rst
index 6ff8e43..1383159 100644
--- a/pw_rpc/docs.rst
+++ b/pw_rpc/docs.rst
@@ -374,6 +374,15 @@
| | - payload |
| | |
+-------------------+-------------------------------------+
+| CLIENT_STREAM_END | Client stream is complete |
+| | |
+| | .. code-block:: text |
+| | |
+| | - channel_id |
+| | - service_id |
+| | - method_id |
+| | |
++-------------------+-------------------------------------+
| CLIENT_ERROR | Received unexpected packet |
| | |
| | .. code-block:: text |
@@ -393,15 +402,6 @@
| | - method_id |
| | |
+-------------------+-------------------------------------+
-| CLIENT_STREAM_END | Client stream is complete |
-| | |
-| | .. code-block:: text |
-| | |
-| | - channel_id |
-| | - service_id |
-| | - method_id |
-| | |
-+-------------------+-------------------------------------+
**Errors**
@@ -421,7 +421,19 @@
+-------------------+-------------------------------------+
| packet type | description |
+===================+=====================================+
-| RESPONSE | RPC response |
+| RESPONSE | The RPC is complete |
+| | |
+| | .. code-block:: text |
+| | |
+| | - channel_id |
+| | - service_id |
+| | - method_id |
+| | - status |
+| | - payload |
+| | (unary & client streaming only) |
+| | |
++-------------------+-------------------------------------+
+| SERVER_STREAM | Message in a server stream |
| | |
| | .. code-block:: text |
| | |
@@ -429,18 +441,6 @@
| | - service_id |
| | - method_id |
| | - payload |
-| | - status |
-| | (unary & client streaming only) |
-| | |
-+-------------------+-------------------------------------+
-| SERVER_STREAM_END | Server stream and RPC finished |
-| | |
-| | .. code-block:: text |
-| | |
-| | - channel_id |
-| | - service_id |
-| | - method_id |
-| | - status |
| | |
+-------------------+-------------------------------------+
| SERVER_ERROR | Received unexpected packet |
@@ -474,6 +474,21 @@
the protocol for calling service methods of each type: unary, server streaming,
client streaming, and bidirectional streaming.
+The basic flow for all RPC invocations is as follows:
+
+ * Client sends a ``REQUEST`` packet. Includes a payload for unary & server
+ streaming RPCs.
+ * For client and bidirectional streaming RPCs, the client may send any number
+ of ``CLIENT_STREAM`` packets with payloads.
+ * For server and bidirectional streaming RPCs, the server may send any number
+ of ``SERVER_STREAM`` packets.
+ * The server sends a ``RESPONSE`` packet. Includes a payload for unary &
+ client streaming RPCs. The RPC is complete.
+
+The client may cancel an ongoing RPC at any time by sending a ``CANCEL`` packet.
+The server may finish an ongoing RPC at any time by sending the ``RESPONSE``
+packet.
+
Unary RPC
^^^^^^^^^
In a unary RPC, the client sends a single request and the server sends a single
@@ -522,7 +537,7 @@
Server streaming RPC
^^^^^^^^^^^^^^^^^^^^
In a server streaming RPC, the client sends a single request and the server
-sends any number of responses followed by a ``SERVER_STREAM_END`` packet.
+sends any number of ``SERVER_STREAM`` packets followed by a ``RESPONSE`` packet.
.. seqdiag::
:scale: 110
@@ -538,12 +553,12 @@
client <-- server [
noactivate,
label = "messages (zero or more)",
- rightnote = "PacketType.RESPONSE\nchannel ID\nservice ID\nmethod ID\npayload"
+ rightnote = "PacketType.SERVER_STREAM\nchannel ID\nservice ID\nmethod ID\npayload"
];
client <- server [
label = "done",
- rightnote = "PacketType.SERVER_STREAM_END\nchannel ID\nservice ID\nmethod ID\nstatus"
+ rightnote = "PacketType.RESPONSE\nchannel ID\nservice ID\nmethod ID\nstatus"
];
}
@@ -564,7 +579,7 @@
client <-- server [
noactivate,
label = "messages (zero or more)",
- rightnote = "PacketType.RESPONSE\nchannel ID\nservice ID\nmethod ID\npayload"
+ rightnote = "PacketType.SERVER_STREAM\nchannel ID\nservice ID\nmethod ID\npayload"
];
client -> server [
@@ -579,7 +594,7 @@
In a client streaming RPC, the client starts the RPC by sending a ``REQUEST``
packet with no payload. It then sends any number of messages in
``CLIENT_STREAM`` packets, followed by a ``CLIENT_STREAM_END``. The server sends
-a single response to finish the RPC.
+a single ``RESPONSE`` to finish the RPC.
.. seqdiag::
:scale: 110
@@ -643,8 +658,8 @@
In a bidirectional streaming RPC, the client sends any number of requests and
the server sends any number of responses. The client invokes the RPC by sending
a ``REQUEST`` with no payload. It sends a ``CLIENT_STREAM_END`` packet when it
-has finished sending requests. The server sends a ``SERVER_STREAM_END`` packet
-to finish the RPC.
+has finished sending requests. The server sends a ``RESPONSE`` packet to finish
+the RPC.
.. seqdiag::
:scale: 110
@@ -668,7 +683,7 @@
client <-- server [
noactivate,
label = "messages (zero or more)",
- rightnote = "PacketType.RESPONSE\nchannel ID\nservice ID\nmethod ID\npayload"
+ rightnote = "PacketType.SERVER_STREAM\nchannel ID\nservice ID\nmethod ID\npayload"
];
client -> server [
@@ -679,13 +694,13 @@
client <- server [
label = "done",
- rightnote = "PacketType.SERVER_STREAM_END\nchannel ID\nservice ID\nmethod ID\nstatus"
+ rightnote = "PacketType.RESPONSE\nchannel ID\nservice ID\nmethod ID\nstatus"
];
}
-The server may finish the RPC at any time by sending the ``SERVER_STREAM_END``
-packet, even if it has not received the ``CLIENT_STREAM_END`` packet. The client
-may terminate the RPC at any time by sending a ``CANCEL`` packet.
+The server may finish the RPC at any time by sending the ``RESPONSE`` packet,
+even if it has not received the ``CLIENT_STREAM_END`` packet. The client may
+terminate the RPC at any time by sending a ``CANCEL`` packet.
.. seqdiag::
:scale: 110
@@ -707,13 +722,14 @@
client <-- server [
noactivate,
label = "messages (zero or more)",
- rightnote = "PacketType.RESPONSE\nchannel ID\nservice ID\nmethod ID\npayload"
+ rightnote = "PacketType.SERVER_STREAM\nchannel ID\nservice ID\nmethod ID\npayload"
];
client -> server [
noactivate,
label = "cancel",
- leftnote = "PacketType.CANCEL\nchannel ID\nservice ID\nmethod ID" ];
+ leftnote = "PacketType.CANCEL\nchannel ID\nservice ID\nmethod ID"
+ ];
}
RPC server
diff --git a/pw_rpc/fake_channel_output.cc b/pw_rpc/fake_channel_output.cc
new file mode 100644
index 0000000..e68a879
--- /dev/null
+++ b/pw_rpc/fake_channel_output.cc
@@ -0,0 +1,62 @@
+// Copyright 2021 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+
+#include "pw_rpc_private/fake_channel_output.h"
+
+#include "pw_assert/check.h"
+#include "pw_result/result.h"
+#include "pw_rpc/internal/packet.h"
+
+namespace pw::rpc::internal::test {
+
+void FakeChannelOutput::clear() {
+ ClearResponses();
+ total_responses_ = 0;
+ last_status_ = Status::Unknown();
+ done_ = false;
+}
+
+Status FakeChannelOutput::SendAndReleaseBuffer(
+ std::span<const std::byte> buffer) {
+ PW_CHECK(!done_);
+ PW_CHECK_PTR_EQ(buffer.data(), packet_buffer_.data());
+
+ if (buffer.empty()) {
+ return OkStatus();
+ }
+
+ Result<Packet> result = Packet::FromBuffer(buffer);
+ PW_CHECK_OK(result.status());
+
+ last_status_ = result.value().status();
+
+ switch (result.value().type()) {
+ case PacketType::RESPONSE:
+ // Server streaming RPCs don't have a payload in their response packet.
+ if (!server_streaming_) {
+ ProcessResponse(result.value().payload());
+ }
+ done_ = true;
+ break;
+ case PacketType::SERVER_STREAM:
+ ProcessResponse(result.value().payload());
+ break;
+ default:
+ PW_CRASH("Unhandled PacketType %d",
+ static_cast<int>(result.value().type()));
+ }
+ return OkStatus();
+}
+
+} // namespace pw::rpc::internal::test
diff --git a/pw_rpc/internal/packet.proto b/pw_rpc/internal/packet.proto
index 4438247..ebcc93e 100644
--- a/pw_rpc/internal/packet.proto
+++ b/pw_rpc/internal/packet.proto
@@ -23,10 +23,11 @@
// Client-to-server packets
- // A request from a client for a service method.
+ // The client invokes an RPC. Always the first packet.
REQUEST = 0;
- // A message in a client stream.
+ // A message in a client stream. Always sent after a REQUEST and before a
+ // CLIENT_STREAM_END.
CLIENT_STREAM = 2;
// The client received a packet for an RPC it did not request.
@@ -40,14 +41,18 @@
// Server-to-client packets
- // A response from a server for a service method.
+ // The RPC has finished.
RESPONSE = 1;
- // A server streaming or bidirectional RPC has completed.
- SERVER_STREAM_END = 3;
+ // Deprecated, do not use. Formerly was used as the last packet in a server
+ // stream.
+ DEPRECATED_SERVER_STREAM_END = 3;
// The server was unable to process a request.
SERVER_ERROR = 5;
+
+ // A message in a server stream.
+ SERVER_STREAM = 7;
}
message RpcPacket {
diff --git a/pw_rpc/nanopb/BUILD.gn b/pw_rpc/nanopb/BUILD.gn
index 72cc611..17e11b9 100644
--- a/pw_rpc/nanopb/BUILD.gn
+++ b/pw_rpc/nanopb/BUILD.gn
@@ -72,6 +72,7 @@
public = [ "public/pw_rpc/nanopb_test_method_context.h" ]
public_deps = [
":method",
+ "..:test_utils",
dir_pw_assert,
dir_pw_containers,
]
diff --git a/pw_rpc/nanopb/codegen_test.cc b/pw_rpc/nanopb/codegen_test.cc
index 12ea706..36f0a4c 100644
--- a/pw_rpc/nanopb/codegen_test.cc
+++ b/pw_rpc/nanopb/codegen_test.cc
@@ -215,12 +215,11 @@
PW_ENCODE_PB(
pw_rpc_test_TestStreamResponse, response, .chunk = {}, .number = 11u);
- context.SendResponse(OkStatus(), response);
+ context.SendServerStream(response);
EXPECT_TRUE(result.active);
EXPECT_EQ(result.response_value, 11);
- context.SendPacket(internal::PacketType::SERVER_STREAM_END,
- Status::NotFound());
+ context.SendResponse(Status::NotFound());
EXPECT_FALSE(result.active);
EXPECT_EQ(result.stream_status, Status::NotFound());
}
diff --git a/pw_rpc/nanopb/nanopb_client_call_test.cc b/pw_rpc/nanopb/nanopb_client_call_test.cc
index ddac23b..c9d87d2 100644
--- a/pw_rpc/nanopb/nanopb_client_call_test.cc
+++ b/pw_rpc/nanopb/nanopb_client_call_test.cc
@@ -248,19 +248,19 @@
});
PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r1, .chunk = {}, .number = 11u);
- context.SendResponse(OkStatus(), r1);
+ context.SendServerStream(r1);
EXPECT_TRUE(active_);
EXPECT_EQ(responses_received_, 1);
EXPECT_EQ(last_response_number_, 11);
PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r2, .chunk = {}, .number = 22u);
- context.SendResponse(OkStatus(), r2);
+ context.SendServerStream(r2);
EXPECT_TRUE(active_);
EXPECT_EQ(responses_received_, 2);
EXPECT_EQ(last_response_number_, 22);
PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r3, .chunk = {}, .number = 33u);
- context.SendResponse(OkStatus(), r3);
+ context.SendServerStream(r3);
EXPECT_TRUE(active_);
EXPECT_EQ(responses_received_, 3);
EXPECT_EQ(last_response_number_, 33);
@@ -283,19 +283,18 @@
});
PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r1, .chunk = {}, .number = 11u);
- context.SendResponse(OkStatus(), r1);
+ context.SendServerStream(r1);
EXPECT_TRUE(active_);
PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r2, .chunk = {}, .number = 22u);
- context.SendResponse(OkStatus(), r2);
+ context.SendServerStream(r2);
EXPECT_TRUE(active_);
// Close the stream.
- context.SendPacket(internal::PacketType::SERVER_STREAM_END,
- Status::NotFound());
+ context.SendResponse(Status::NotFound());
PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r3, .chunk = {}, .number = 33u);
- context.SendResponse(OkStatus(), r3);
+ context.SendServerStream(r3);
EXPECT_FALSE(active_);
EXPECT_EQ(responses_received_, 2);
@@ -316,19 +315,19 @@
[this](Status error) { rpc_error_ = error; });
PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r1, .chunk = {}, .number = 11u);
- context.SendResponse(OkStatus(), r1);
+ context.SendServerStream(r1);
EXPECT_TRUE(active_);
EXPECT_EQ(responses_received_, 1);
EXPECT_EQ(last_response_number_, 11);
constexpr std::byte bad_payload[]{
std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};
- context.SendResponse(OkStatus(), bad_payload);
+ context.SendServerStream(bad_payload);
EXPECT_EQ(responses_received_, 1);
EXPECT_EQ(rpc_error_, Status::DataLoss());
PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r2, .chunk = {}, .number = 22u);
- context.SendResponse(OkStatus(), r2);
+ context.SendServerStream(r2);
EXPECT_TRUE(active_);
EXPECT_EQ(responses_received_, 2);
EXPECT_EQ(last_response_number_, 22);
diff --git a/pw_rpc/nanopb/nanopb_method_test.cc b/pw_rpc/nanopb/nanopb_method_test.cc
index d38c461..d212cba 100644
--- a/pw_rpc/nanopb/nanopb_method_test.cc
+++ b/pw_rpc/nanopb/nanopb_method_test.cc
@@ -168,7 +168,7 @@
const NanopbMethod& method =
std::get<1>(FakeService::kMethods).nanopb_method();
ServerContextForTest<FakeService> context(method);
- method.Invoke(context.get(), context.packet(request));
+ method.Invoke(context.get(), context.request(request));
const Packet& response = context.output().sent_packet();
EXPECT_EQ(response.status(), Status::Unauthenticated());
@@ -190,7 +190,7 @@
const NanopbMethod& method =
std::get<0>(FakeService::kMethods).nanopb_method();
ServerContextForTest<FakeService> context(method);
- method.Invoke(context.get(), context.packet(bad_payload));
+ method.Invoke(context.get(), context.request(bad_payload));
const Packet& packet = context.output().sent_packet();
EXPECT_EQ(PacketType::SERVER_ERROR, packet.type());
@@ -208,10 +208,11 @@
std::get<1>(FakeService::kMethods).nanopb_method();
// Output buffer is too small for the response, but can fit an error packet.
ServerContextForTest<FakeService, 22> context(method);
- ASSERT_LT(context.output().buffer_size(),
- context.packet(request).MinEncodedSizeBytes() + request.size() + 1);
+ ASSERT_LT(
+ context.output().buffer_size(),
+ context.request(request).MinEncodedSizeBytes() + request.size() + 1);
- method.Invoke(context.get(), context.packet(request));
+ method.Invoke(context.get(), context.request(request));
const Packet& packet = context.output().sent_packet();
EXPECT_EQ(PacketType::SERVER_ERROR, packet.type());
@@ -230,7 +231,7 @@
std::get<2>(FakeService::kMethods).nanopb_method();
ServerContextForTest<FakeService> context(method);
- method.Invoke(context.get(), context.packet(request));
+ method.Invoke(context.get(), context.request(request));
EXPECT_EQ(0u, context.output().packet_count());
EXPECT_EQ(555, last_request.integer);
@@ -241,13 +242,13 @@
std::get<2>(FakeService::kMethods).nanopb_method();
ServerContextForTest<FakeService> context(method);
- method.Invoke(context.get(), context.packet({}));
+ method.Invoke(context.get(), context.request({}));
EXPECT_EQ(OkStatus(), last_writer.Write({.value = 100}));
PW_ENCODE_PB(pw_rpc_test_TestResponse, payload, .value = 100);
std::array<byte, 128> encoded_response = {};
- auto encoded = context.packet(payload).Encode(encoded_response);
+ auto encoded = context.server_stream(payload).Encode(encoded_response);
ASSERT_EQ(OkStatus(), encoded.status());
ASSERT_EQ(encoded.value().size(), context.output().sent_data().size());
@@ -262,7 +263,7 @@
std::get<2>(FakeService::kMethods).nanopb_method();
ServerContextForTest<FakeService> context(method);
- method.Invoke(context.get(), context.packet({}));
+ method.Invoke(context.get(), context.request({}));
EXPECT_EQ(OkStatus(), last_writer.Finish());
EXPECT_TRUE(last_writer.Write({.value = 100}).IsFailedPrecondition());
@@ -273,7 +274,7 @@
std::get<2>(FakeService::kMethods).nanopb_method();
ServerContextForTest<FakeService> context(method);
- method.Invoke(context.get(), context.packet({}));
+ method.Invoke(context.get(), context.request({}));
ServerWriter<pw_rpc_test_TestResponse> new_writer = std::move(last_writer);
EXPECT_EQ(OkStatus(), new_writer.Write({.value = 100}));
@@ -289,20 +290,20 @@
const NanopbMethod& method =
std::get<2>(FakeService::kMethods).nanopb_method();
- constexpr size_t kNoPayloadPacketSize = 2 /* type */ + 2 /* channel */ +
- 5 /* service */ + 5 /* method */ +
- 2 /* payload */ + 2 /* status */;
+ constexpr size_t kNoPayloadPacketSize =
+ 2 /* type */ + 2 /* channel */ + 5 /* service */ + 5 /* method */ +
+ 0 /* payload (when empty) */ + 0 /* status (when OK)*/;
// Make the buffer barely fit a packet with no payload.
ServerContextForTest<FakeService, kNoPayloadPacketSize> context(method);
// Verify that the encoded size of a packet with an empty payload is correct.
std::array<byte, 128> encoded_response = {};
- auto encoded = context.packet({}).Encode(encoded_response);
+ auto encoded = context.request({}).Encode(encoded_response);
ASSERT_EQ(OkStatus(), encoded.status());
ASSERT_EQ(kNoPayloadPacketSize, encoded.value().size());
- method.Invoke(context.get(), context.packet({}));
+ method.Invoke(context.get(), context.request({}));
EXPECT_EQ(OkStatus(), last_writer.Write({})); // Barely fits
EXPECT_EQ(Status::Internal(), last_writer.Write({.value = 1})); // Too big
diff --git a/pw_rpc/nanopb/nanopb_method_union_test.cc b/pw_rpc/nanopb/nanopb_method_union_test.cc
index d1bfd7f..cb3ef71 100644
--- a/pw_rpc/nanopb/nanopb_method_union_test.cc
+++ b/pw_rpc/nanopb/nanopb_method_union_test.cc
@@ -92,7 +92,7 @@
const Method& method =
std::get<0>(FakeGeneratedServiceImpl::kMethods).method();
ServerContextForTest<FakeGeneratedServiceImpl> context(method);
- method.Invoke(context.get(), context.packet({}));
+ method.Invoke(context.get(), context.request({}));
const Packet& response = context.output().sent_packet();
EXPECT_EQ(response.status(), Status::Unknown());
@@ -106,12 +106,11 @@
std::get<1>(FakeGeneratedServiceImpl::kMethods).method();
ServerContextForTest<FakeGeneratedServiceImpl> context(method);
- method.Invoke(context.get(), context.packet(request));
+ method.Invoke(context.get(), context.request(request));
EXPECT_TRUE(last_raw_writer.open());
EXPECT_EQ(OkStatus(), last_raw_writer.Finish());
- EXPECT_EQ(context.output().sent_packet().type(),
- PacketType::SERVER_STREAM_END);
+ EXPECT_EQ(context.output().sent_packet().type(), PacketType::RESPONSE);
}
TEST(NanopbMethodUnion, Nanopb_CallsUnaryMethod) {
@@ -121,7 +120,7 @@
const Method& method =
std::get<2>(FakeGeneratedServiceImpl::kMethods).method();
ServerContextForTest<FakeGeneratedServiceImpl> context(method);
- method.Invoke(context.get(), context.packet(request));
+ method.Invoke(context.get(), context.request(request));
const Packet& response = context.output().sent_packet();
EXPECT_EQ(response.status(), Status::Unauthenticated());
@@ -146,14 +145,13 @@
std::get<3>(FakeGeneratedServiceImpl::kMethods).method();
ServerContextForTest<FakeGeneratedServiceImpl> context(method);
- method.Invoke(context.get(), context.packet(request));
+ method.Invoke(context.get(), context.request(request));
EXPECT_EQ(555, last_request.integer);
EXPECT_TRUE(last_writer.open());
EXPECT_EQ(OkStatus(), last_writer.Finish());
- EXPECT_EQ(context.output().sent_packet().type(),
- PacketType::SERVER_STREAM_END);
+ EXPECT_EQ(context.output().sent_packet().type(), PacketType::RESPONSE);
}
} // namespace
diff --git a/pw_rpc/nanopb/public/pw_rpc/nanopb_client_call.h b/pw_rpc/nanopb/public/pw_rpc/nanopb_client_call.h
index 1383dc6..e77db5f 100644
--- a/pw_rpc/nanopb/public/pw_rpc/nanopb_client_call.h
+++ b/pw_rpc/nanopb/public/pw_rpc/nanopb_client_call.h
@@ -200,8 +200,6 @@
} else {
callbacks_.InvokeRpcError(Status::DataLoss());
}
-
- Unregister();
}
void InvokeServerStreamingCallback(const internal::Packet& packet) {
@@ -210,12 +208,15 @@
return;
}
- if (packet.type() == internal::PacketType::SERVER_STREAM_END &&
- callbacks_.stream_end) {
- callbacks_.stream_end(packet.status());
+ if (packet.type() == internal::PacketType::RESPONSE) {
+ // The server sent the last packet; call the stream end callback.
+ if (callbacks_.stream_end) {
+ callbacks_.stream_end(packet.status());
+ }
return;
}
+ // Handle internal::PacketType::SERVER_STREAM packets.
ResponseBuffer response_struct{};
if (callbacks_.stream_response &&
diff --git a/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h b/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h
index b65ea22..5d836ec 100644
--- a/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h
+++ b/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h
@@ -17,6 +17,7 @@
#include <utility>
#include "pw_assert/assert.h"
+#include "pw_bytes/span.h"
#include "pw_containers/vector.h"
#include "pw_preprocessor/arguments.h"
#include "pw_rpc/channel.h"
@@ -25,6 +26,7 @@
#include "pw_rpc/internal/nanopb_method.h"
#include "pw_rpc/internal/packet.h"
#include "pw_rpc/internal/server.h"
+#include "pw_rpc_private/fake_channel_output.h"
namespace pw::rpc {
@@ -77,7 +79,7 @@
::pw::rpc::internal::Hash(#method), \
##__VA_ARGS__>
template <typename Service,
- auto method,
+ auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse = 4,
size_t kOutputSizeBytes = 128>
@@ -88,55 +90,46 @@
// A ChannelOutput implementation that stores the outgoing payloads and status.
template <typename Response>
-class MessageOutput final : public ChannelOutput {
+class MessageOutput final : public FakeChannelOutput {
public:
- MessageOutput(const internal::NanopbMethod& method,
+ MessageOutput(const internal::NanopbMethod& kMethod,
Vector<Response>& responses,
- std::span<std::byte> buffer)
- : ChannelOutput("internal::test::nanopb::MessageOutput"),
- method_(method),
- responses_(responses),
- buffer_(buffer) {
- clear();
- }
-
- Status last_status() const { return last_status_; }
- void set_last_status(Status status) { last_status_ = status; }
-
- size_t total_responses() const { return total_responses_; }
-
- bool stream_ended() const { return stream_ended_; }
-
- void clear();
+ ByteSpan packet_buffer,
+ bool server_streaming)
+ : FakeChannelOutput(packet_buffer, server_streaming),
+ method_(kMethod),
+ responses_(responses) {}
private:
- std::span<std::byte> AcquireBuffer() override { return buffer_; }
+ void AppendResponse(ConstByteSpan response) override {
+ // If we run out of space, the back message is always the most recent.
+ responses_.emplace_back();
+ responses_.back() = {};
+ PW_ASSERT(method_.DecodeResponse(response, &responses_.back()));
+ }
- Status SendAndReleaseBuffer(std::span<const std::byte> buffer) override;
+ void ClearResponses() override { responses_.clear(); }
const internal::NanopbMethod& method_;
Vector<Response>& responses_;
- std::span<std::byte> buffer_;
- size_t total_responses_;
- bool stream_ended_;
- Status last_status_;
};
// Collects everything needed to invoke a particular RPC.
template <typename Service,
- auto method,
+ auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kOutputSize>
struct InvocationContext {
- using Request = internal::Request<method>;
- using Response = internal::Response<method>;
+ using Request = internal::Request<kMethod>;
+ using Response = internal::Response<kMethod>;
template <typename... Args>
InvocationContext(Args&&... args)
: output(MethodLookup::GetNanopbMethod<Service, kMethodId>(),
responses,
- buffer),
+ buffer,
+ MethodTraits<decltype(kMethod)>::kServerStreaming),
channel(Channel::Create<123>(&output)),
server(std::span(&channel, 1)),
service(std::forward<Args>(args)...),
@@ -158,10 +151,13 @@
// Method invocation context for a unary RPC. Returns the status in call() and
// provides the response through the response() method.
-template <typename Service, auto method, uint32_t kMethodId, size_t kOutputSize>
+template <typename Service,
+ auto kMethod,
+ uint32_t kMethodId,
+ size_t kOutputSize>
class UnaryContext {
private:
- InvocationContext<Service, method, kMethodId, 1, kOutputSize> ctx_;
+ InvocationContext<Service, kMethod, kMethodId, 1, kOutputSize> ctx_;
public:
using Request = typename decltype(ctx_)::Request;
@@ -177,7 +173,7 @@
ctx_.output.clear();
ctx_.responses.emplace_back();
ctx_.responses.back() = {};
- return CallMethodImplFunction<method>(
+ return CallMethodImplFunction<kMethod>(
ctx_.call, request, ctx_.responses.back());
}
@@ -190,13 +186,14 @@
// Method invocation context for a server streaming RPC.
template <typename Service,
- auto method,
+ auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kOutputSize>
class ServerStreamingContext {
private:
- InvocationContext<Service, method, kMethodId, kMaxResponse, kOutputSize> ctx_;
+ InvocationContext<Service, kMethod, kMethodId, kMaxResponse, kOutputSize>
+ ctx_;
public:
using Request = typename decltype(ctx_)::Request;
@@ -211,7 +208,7 @@
void call(const Request& request) {
ctx_.output.clear();
internal::Responder server_writer(ctx_.call);
- return CallMethodImplFunction<method>(
+ return CallMethodImplFunction<kMethod>(
ctx_.call,
request,
static_cast<ServerWriter<Response>&>(server_writer));
@@ -235,7 +232,7 @@
size_t total_responses() const { return ctx_.output.total_responses(); }
// True if the stream has terminated.
- bool done() const { return ctx_.output.stream_ended(); }
+ bool done() const { return ctx_.output.done(); }
// The status of the stream. Only valid if done() is true.
Status status() const {
@@ -247,79 +244,41 @@
// Alias to select the type of the context object to use based on which type of
// RPC it is for.
template <typename Service,
- auto method,
+ auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kOutputSize>
using Context = std::tuple_element_t<
- static_cast<size_t>(internal::MethodTraits<decltype(method)>::kType),
- std::tuple<UnaryContext<Service, method, kMethodId, kOutputSize>,
+ static_cast<size_t>(internal::MethodTraits<decltype(kMethod)>::kType),
+ std::tuple<UnaryContext<Service, kMethod, kMethodId, kOutputSize>,
ServerStreamingContext<Service,
- method,
+ kMethod,
kMethodId,
kMaxResponse,
kOutputSize>
// TODO(hepler): Support client and bidi streaming
>>;
-template <typename Response>
-void MessageOutput<Response>::clear() {
- responses_.clear();
- total_responses_ = 0;
- stream_ended_ = false;
- last_status_ = Status::Unknown();
-}
-
-template <typename Response>
-Status MessageOutput<Response>::SendAndReleaseBuffer(
- std::span<const std::byte> buffer) {
- PW_ASSERT(!stream_ended_);
- PW_ASSERT(buffer.data() == buffer_.data());
-
- if (buffer.empty()) {
- return OkStatus();
- }
-
- Result<internal::Packet> result = internal::Packet::FromBuffer(buffer);
- PW_ASSERT(result.ok());
-
- last_status_ = result.value().status();
-
- switch (result.value().type()) {
- case internal::PacketType::RESPONSE:
- // If we run out of space, the back message is always the most recent.
- responses_.emplace_back();
- responses_.back() = {};
- PW_ASSERT(
- method_.DecodeResponse(result.value().payload(), &responses_.back()));
- total_responses_ += 1;
- break;
- case internal::PacketType::SERVER_STREAM_END:
- stream_ended_ = true;
- break;
- default:
- pw_assert_HandleFailure(); // Unhandled PacketType
- }
- return OkStatus();
-}
-
} // namespace internal::test::nanopb
template <typename Service,
- auto method,
+ auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kOutputSizeBytes>
class NanopbTestMethodContext
: public internal::test::nanopb::
- Context<Service, method, kMethodId, kMaxResponse, kOutputSizeBytes> {
+ Context<Service, kMethod, kMethodId, kMaxResponse, kOutputSizeBytes> {
public:
// Forwards constructor arguments to the service class.
template <typename... ServiceArgs>
NanopbTestMethodContext(ServiceArgs&&... service_args)
- : internal::test::nanopb::
- Context<Service, method, kMethodId, kMaxResponse, kOutputSizeBytes>(
- std::forward<ServiceArgs>(service_args)...) {}
+ : internal::test::nanopb::Context<Service,
+ kMethod,
+ kMethodId,
+ kMaxResponse,
+ kOutputSizeBytes>(
+ std::forward<ServiceArgs>(service_args)...) {}
};
} // namespace pw::rpc
diff --git a/pw_rpc/packet.cc b/pw_rpc/packet.cc
index 3739370..275acee 100644
--- a/pw_rpc/packet.cc
+++ b/pw_rpc/packet.cc
@@ -74,13 +74,20 @@
RpcPacket::Encoder rpc_packet(&encoder);
// The payload is encoded first, as it may share the encode buffer.
- rpc_packet.WritePayload(payload_);
+ if (!payload_.empty()) {
+ rpc_packet.WritePayload(payload_);
+ }
rpc_packet.WriteType(type_);
rpc_packet.WriteChannelId(channel_id_);
rpc_packet.WriteServiceId(service_id_);
rpc_packet.WriteMethodId(method_id_);
- rpc_packet.WriteStatus(status_.code());
+
+ // Status code 0 is OK. In protobufs, 0 is the default int value, so skip
+ // encoding it to save two bytes in the output.
+ if (status_.code() != 0) {
+ rpc_packet.WriteStatus(status_.code());
+ }
return encoder.Encode();
}
@@ -96,7 +103,7 @@
// Packet type always takes two bytes to encode (varint key + varint enum).
reserved_size += 2;
- // Status field always takes two bytes to encode (varint key + varint status).
+ // Status field takes up to two bytes to encode (varint key + varint status).
reserved_size += 2;
// Payload field takes at least two bytes to encode (varint key + length).
diff --git a/pw_rpc/packet_test.cc b/pw_rpc/packet_test.cc
index 534fb87..88e7c02 100644
--- a/pw_rpc/packet_test.cc
+++ b/pw_rpc/packet_test.cc
@@ -55,11 +55,12 @@
100,
0,
0,
- 0,
+ 0
- // Status
- MakeKey(6, protobuf::WireType::kVarint),
- 0x00>();
+ // Status (not encoded if it is zero)
+ // MakeKey(6, protobuf::WireType::kVarint),
+ // 0x00
+ >();
// Test that a default-constructed packet sets its members to the default
// protobuf values.
diff --git a/pw_rpc/public/pw_rpc/internal/packet.h b/pw_rpc/public/pw_rpc/internal/packet.h
index 2004d70..8adc11e 100644
--- a/pw_rpc/public/pw_rpc/internal/packet.h
+++ b/pw_rpc/public/pw_rpc/internal/packet.h
@@ -88,6 +88,9 @@
// Determines the space required to encode the packet proto fields for a
// response, excluding the payload. This may be used to split the buffer into
// reserved space and available space for the payload.
+ //
+ // This method allocates two bytes for the status. Status code 0 (OK) is not
+ // encoded since 0 is the default value.
size_t MinEncodedSizeBytes() const;
enum Destination : bool { kServer, kClient };
diff --git a/pw_rpc/public/pw_rpc/internal/responder.h b/pw_rpc/public/pw_rpc/internal/responder.h
index 26269d9..bfdf270 100644
--- a/pw_rpc/public/pw_rpc/internal/responder.h
+++ b/pw_rpc/public/pw_rpc/internal/responder.h
@@ -84,8 +84,6 @@
void Close();
- Packet ResponsePacket(std::span<const std::byte> payload = {}) const;
-
ServerCall call_;
Channel::OutputBuffer response_;
enum { kClosed, kOpen } state_;
diff --git a/pw_rpc/pw_rpc_private/fake_channel_output.h b/pw_rpc/pw_rpc_private/fake_channel_output.h
new file mode 100644
index 0000000..d46628e
--- /dev/null
+++ b/pw_rpc/pw_rpc_private/fake_channel_output.h
@@ -0,0 +1,59 @@
+// Copyright 2021 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+#pragma once
+
+#include "pw_bytes/span.h"
+#include "pw_rpc/channel.h"
+
+namespace pw::rpc::internal::test {
+
+// A ChannelOutput implementation that stores the outgoing payloads and status.
+class FakeChannelOutput : public ChannelOutput {
+ public:
+ constexpr FakeChannelOutput(ByteSpan buffer, bool server_streaming)
+ : ChannelOutput("pw::rpc::internal::test::FakeChannelOutput"),
+ packet_buffer_(buffer),
+ server_streaming_(server_streaming) {}
+
+ Status last_status() const { return last_status_; }
+ void set_last_status(Status status) { last_status_ = status; }
+
+ size_t total_responses() const { return total_responses_; }
+
+ // Set to true if a RESPONSE packet is seen.
+ bool done() const { return done_; }
+
+ void clear();
+
+ private:
+ ByteSpan AcquireBuffer() final { return packet_buffer_; }
+
+ Status SendAndReleaseBuffer(ConstByteSpan buffer) final;
+
+ virtual void AppendResponse(ConstByteSpan response) = 0;
+ virtual void ClearResponses() = 0;
+
+ void ProcessResponse(ConstByteSpan response) {
+ AppendResponse(response);
+ total_responses_ += 1;
+ }
+
+ const ByteSpan packet_buffer_;
+ size_t total_responses_ = 0;
+ Status last_status_;
+ bool done_ = false;
+ const bool server_streaming_;
+};
+
+} // namespace pw::rpc::internal::test
diff --git a/pw_rpc/pw_rpc_private/internal_test_utils.h b/pw_rpc/pw_rpc_private/internal_test_utils.h
index 5dbdb11..11ffefa 100644
--- a/pw_rpc/pw_rpc_private/internal_test_utils.h
+++ b/pw_rpc/pw_rpc_private/internal_test_utils.h
@@ -101,14 +101,31 @@
server_.RegisterService(service_);
}
- // Creates a response packet for this context's channel, service, and method.
- internal::Packet packet(std::span<const std::byte> payload) const {
+ // Create packets for this context's channel, service, and method.
+ internal::Packet request(std::span<const std::byte> payload) const {
+ return internal::Packet(internal::PacketType::REQUEST,
+ kChannelId,
+ kServiceId,
+ context_.method().id(),
+ payload);
+ }
+
+ internal::Packet response(Status status,
+ std::span<const std::byte> payload = {}) const {
return internal::Packet(internal::PacketType::RESPONSE,
kChannelId,
kServiceId,
context_.method().id(),
payload,
- OkStatus());
+ status);
+ }
+
+ internal::Packet server_stream(std::span<const std::byte> payload) const {
+ return internal::Packet(internal::PacketType::SERVER_STREAM,
+ kChannelId,
+ kServiceId,
+ context_.method().id(),
+ payload);
}
internal::ServerCall& get() { return context_; }
@@ -156,10 +173,14 @@
return client_.ProcessPacket(result.value_or(ConstByteSpan()));
}
- Status SendResponse(Status status, std::span<const std::byte> payload) {
+ Status SendResponse(Status status, std::span<const std::byte> payload = {}) {
return SendPacket(internal::PacketType::RESPONSE, status, payload);
}
+ Status SendServerStream(std::span<const std::byte> payload) {
+ return SendPacket(internal::PacketType::SERVER_STREAM, OkStatus(), payload);
+ }
+
private:
TestOutput<kOutputBufferSize> output_;
Channel channel_;
diff --git a/pw_rpc/py/pw_rpc/client.py b/pw_rpc/py/pw_rpc/client.py
index e990bce..057fc3e 100644
--- a/pw_rpc/py/pw_rpc/client.py
+++ b/pw_rpc/py/pw_rpc/client.py
@@ -247,26 +247,30 @@
def _decode_status(rpc: PendingRpc, packet) -> Optional[Status]:
- # Server streaming RPC packets never have a status; all other packets do.
- if packet.type == PacketType.RESPONSE and rpc.method.server_streaming:
+ if packet.type == PacketType.SERVER_STREAM:
return None
try:
return Status(packet.status)
except ValueError:
_LOG.warning('Illegal status code %d for %s', packet.status, rpc)
-
- return None
+ return Status.UNKNOWN
def _decode_payload(rpc: PendingRpc, packet):
- if packet.type == PacketType.RESPONSE:
- try:
- return packets.decode_payload(packet, rpc.method.response_type)
- except DecodeError as err:
- _LOG.warning('Failed to decode %s response for %s: %s',
- rpc.method.response_type.DESCRIPTOR.full_name,
- rpc.method.full_name, err)
+ if packet.type == PacketType.SERVER_ERROR:
+ return None
+
+ # Server streaming RPCs do not send a payload with their RESPONSE packet.
+ if packet.type == PacketType.RESPONSE and rpc.method.server_streaming:
+ return None
+
+ try:
+ return packets.decode_payload(packet, rpc.method.response_type)
+ except DecodeError as err:
+ _LOG.warning('Failed to decode %s response for %s: %s',
+ rpc.method.response_type.DESCRIPTOR.full_name,
+ rpc.method.full_name, err)
return None
@@ -326,6 +330,31 @@
f'services={[str(s) for s in self.services()]})')
+def _update_for_backwards_compatibility(rpc: PendingRpc,
+ packet: RpcPacket) -> None:
+ """Adapts server streaming RPC packets to the updated protocol if needed."""
+ # The protocol changes only affect server streaming RPCs.
+ if rpc.method.type is not Method.Type.SERVER_STREAMING:
+ return
+
+ # SERVER_STREAM_END packets are deprecated. They are equivalent to a
+ # RESPONSE packet.
+ if packet.type == PacketType.DEPRECATED_SERVER_STREAM_END:
+ packet.type = PacketType.RESPONSE
+ return
+
+ # Prior to the introduction of SERVER_STREAM packets, RESPONSE packets with
+ # a payload were used instead. If a non-zero payload is present, assume this
+ # RESPONSE is equivalent to a SERVER_STREAM packet.
+ #
+ # Note that the payload field is not 'optional', so an empty payload is
+ # equivalent to a payload that happens to encode to zero bytes. This would
+ # only affect server streaming RPCs on the old protocol that intentionally
+ # send empty payloads, which will not be an issue in practice.
+ if packet.type == PacketType.RESPONSE and packet.payload:
+ packet.type = PacketType.SERVER_STREAM
+
+
class Client:
"""Sends requests and handles responses for a set of channels.
@@ -424,15 +453,15 @@
_LOG.warning('%s', err)
return Status.OK
- status = _decode_status(rpc, packet)
+ _update_for_backwards_compatibility(rpc, packet)
- if packet.type not in (PacketType.RESPONSE,
- PacketType.SERVER_STREAM_END,
+ if packet.type not in (PacketType.RESPONSE, PacketType.SERVER_STREAM,
PacketType.SERVER_ERROR):
_LOG.error('%s: unexpected PacketType %s', rpc, packet.type)
_LOG.debug('Packet:\n%s', packet)
return Status.OK
+ status = _decode_status(rpc, packet)
payload = _decode_payload(rpc, packet)
try:
diff --git a/pw_rpc/py/pw_rpc/packets.py b/pw_rpc/py/pw_rpc/packets.py
index d36b7a4..1980b25 100644
--- a/pw_rpc/py/pw_rpc/packets.py
+++ b/pw_rpc/py/pw_rpc/packets.py
@@ -19,7 +19,7 @@
from pw_rpc.internal import packet_pb2
-def decode(data: bytes):
+def decode(data: bytes) -> packet_pb2.RpcPacket:
packet = packet_pb2.RpcPacket()
packet.MergeFromString(data)
return packet
@@ -73,5 +73,5 @@
method_id=method).SerializeToString()
-def for_server(packet):
+def for_server(packet: packet_pb2.RpcPacket) -> bool:
return packet.type % 2 == 0
diff --git a/pw_rpc/py/tests/callback_client_test.py b/pw_rpc/py/tests/callback_client_test.py
index 0373a07..4c6009f 100755
--- a/pw_rpc/py/tests/callback_client_test.py
+++ b/pw_rpc/py/tests/callback_client_test.py
@@ -58,6 +58,10 @@
method_stub.method)
+def _message_bytes(msg) -> bytes:
+ return msg if isinstance(msg, bytes) else msg.SerializeToString()
+
+
class CallbackClientImplTest(unittest.TestCase):
"""Tests the callback_client as used within a pw_rpc Client."""
def setUp(self):
@@ -77,7 +81,7 @@
channel_id: int,
method=None,
status: Status = Status.OK,
- response=b'',
+ payload=b'',
*,
ids: Tuple[int, int] = None,
process_status=Status.OK):
@@ -88,32 +92,27 @@
assert ids is not None and method is None
service_id, method_id = ids
- if isinstance(response, bytes):
- payload = response
- else:
- payload = response.SerializeToString()
+ self._next_packets.append((packet_pb2.RpcPacket(
+ type=packet_pb2.PacketType.RESPONSE,
+ channel_id=channel_id,
+ service_id=service_id,
+ method_id=method_id,
+ status=status.value,
+ payload=_message_bytes(payload)).SerializeToString(),
+ process_status))
- self._next_packets.append(
- (packet_pb2.RpcPacket(type=packet_pb2.PacketType.RESPONSE,
- channel_id=channel_id,
- service_id=service_id,
- method_id=method_id,
- status=status.value,
- payload=payload).SerializeToString(),
- process_status))
-
- def _enqueue_stream_end(self,
- channel_id: int,
- method,
- status: Status = Status.OK,
- process_status=Status.OK):
- self._next_packets.append(
- (packet_pb2.RpcPacket(type=packet_pb2.PacketType.SERVER_STREAM_END,
- channel_id=channel_id,
- service_id=method.service.id,
- method_id=method.id,
- status=status.value).SerializeToString(),
- process_status))
+ def _enqueue_server_stream(self,
+ channel_id: int,
+ method,
+ response,
+ process_status=Status.OK):
+ self._next_packets.append((packet_pb2.RpcPacket(
+ type=packet_pb2.PacketType.SERVER_STREAM,
+ channel_id=channel_id,
+ service_id=method.service.id,
+ method_id=method.id,
+ payload=_message_bytes(response)).SerializeToString(),
+ process_status))
def _enqueue_error(self,
channel_id: int,
@@ -288,9 +287,9 @@
rep2 = method.response_type(payload='?')
for _ in range(3):
- self._enqueue_response(1, method, response=rep1)
- self._enqueue_response(1, method, response=rep2)
- self._enqueue_stream_end(1, method, Status.ABORTED)
+ self._enqueue_server_stream(1, method, rep1)
+ self._enqueue_server_stream(1, method, rep2)
+ self._enqueue_response(1, method, Status.ABORTED)
self.assertEqual(
[rep1, rep2],
@@ -300,6 +299,35 @@
4,
self._sent_payload(method.request_type).magic_number)
+ def test_invoke_server_streaming_with_deprecated_packet_format(self):
+ method = self._service.SomeServerStreaming.method
+
+ rep1 = method.response_type(payload='!!!')
+ rep2 = method.response_type(payload='?')
+
+ for _ in range(3):
+ # The original packet format used RESPONSE packets for the server
+ # stream and a SERVER_STREAM_END packet as the last packet. These
+ # are converted to SERVER_STREAM packets followed by a RESPONSE.
+ self._enqueue_response(1, method, payload=rep1)
+ self._enqueue_response(1, method, payload=rep2)
+
+ self._next_packets.append((packet_pb2.RpcPacket(
+ type=packet_pb2.PacketType.DEPRECATED_SERVER_STREAM_END,
+ channel_id=1,
+ service_id=method.service.id,
+ method_id=method.id,
+ status=Status.INVALID_ARGUMENT.value).SerializeToString(),
+ Status.OK))
+
+ call = self._service.SomeServerStreaming(magic_number=4)
+ self.assertEqual([rep1, rep2], list(call))
+ self.assertIs(call.status, Status.INVALID_ARGUMENT)
+
+ self.assertEqual(
+ 4,
+ self._sent_payload(method.request_type).magic_number)
+
def test_invoke_server_streaming_with_callbacks(self):
method = self._service.SomeServerStreaming.method
@@ -307,9 +335,9 @@
rep2 = method.response_type(payload='?')
for _ in range(3):
- self._enqueue_response(1, method, response=rep1)
- self._enqueue_response(1, method, response=rep2)
- self._enqueue_stream_end(1, method, Status.ABORTED)
+ self._enqueue_server_stream(1, method, rep1)
+ self._enqueue_server_stream(1, method, rep2)
+ self._enqueue_response(1, method, Status.ABORTED)
callback = mock.Mock()
self._service.SomeServerStreaming.invoke(
@@ -330,7 +358,7 @@
stub = self._service.SomeServerStreaming
resp = stub.method.response_type(payload='!!!')
- self._enqueue_response(1, stub.method, response=resp)
+ self._enqueue_server_stream(1, stub.method, resp)
callback = mock.Mock()
call = stub.invoke(self._request(magic_number=3), callback)
@@ -344,8 +372,8 @@
self.assertEqual(self._last_request.type, packet_pb2.PacketType.CANCEL)
# Ensure the RPC can be called after being cancelled.
- self._enqueue_response(1, stub.method, response=resp)
- self._enqueue_stream_end(1, stub.method, Status.OK)
+ self._enqueue_server_stream(1, stub.method, resp)
+ self._enqueue_response(1, stub.method, Status.OK)
call = stub.invoke(self._request(magic_number=3), callback, callback)
diff --git a/pw_rpc/raw/BUILD.gn b/pw_rpc/raw/BUILD.gn
index 1f5d502..70c53cb 100644
--- a/pw_rpc/raw/BUILD.gn
+++ b/pw_rpc/raw/BUILD.gn
@@ -45,6 +45,7 @@
public = [ "public/pw_rpc/raw_test_method_context.h" ]
public_deps = [
":method",
+ "..:test_utils",
dir_pw_assert,
dir_pw_containers,
]
diff --git a/pw_rpc/raw/public/pw_rpc/internal/raw_method.h b/pw_rpc/raw/public/pw_rpc/internal/raw_method.h
index 7d86ebe..4e63bee 100644
--- a/pw_rpc/raw/public/pw_rpc/internal/raw_method.h
+++ b/pw_rpc/raw/public/pw_rpc/internal/raw_method.h
@@ -118,6 +118,8 @@
ServerContext&, ConstByteSpan, ByteSpan)> {
using Implementation = RawMethod;
static constexpr MethodType kType = MethodType::kUnary;
+ static constexpr bool kServerStreaming = false;
+ static constexpr bool kClientStreaming = false;
};
// MethodTraits specialization for a raw unary method.
@@ -126,6 +128,8 @@
ServerContext&, ConstByteSpan, ByteSpan)> {
using Implementation = RawMethod;
static constexpr MethodType kType = MethodType::kUnary;
+ static constexpr bool kServerStreaming = false;
+ static constexpr bool kClientStreaming = false;
using Service = T;
};
@@ -134,6 +138,8 @@
struct MethodTraits<void (*)(ServerContext&, ConstByteSpan, RawServerWriter&)> {
using Implementation = RawMethod;
static constexpr MethodType kType = MethodType::kServerStreaming;
+ static constexpr bool kServerStreaming = true;
+ static constexpr bool kClientStreaming = false;
};
// MethodTraits specialization for a raw server streaming method.
@@ -142,6 +148,8 @@
ServerContext&, ConstByteSpan, RawServerWriter&)> {
using Implementation = RawMethod;
static constexpr MethodType kType = MethodType::kServerStreaming;
+ static constexpr bool kServerStreaming = true;
+ static constexpr bool kClientStreaming = false;
using Service = T;
};
diff --git a/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h b/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h
index b9fc660..6ff8738 100644
--- a/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h
+++ b/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h
@@ -16,7 +16,6 @@
#include <type_traits>
#include "pw_assert/assert.h"
-#include "pw_bytes/span.h"
#include "pw_containers/vector.h"
#include "pw_rpc/channel.h"
#include "pw_rpc/internal/hash.h"
@@ -24,6 +23,7 @@
#include "pw_rpc/internal/packet.h"
#include "pw_rpc/internal/raw_method.h"
#include "pw_rpc/internal/server.h"
+#include "pw_rpc_private/fake_channel_output.h"
namespace pw::rpc {
@@ -79,7 +79,7 @@
::pw::rpc::internal::Hash(#method), \
##__VA_ARGS__>
template <typename Service,
- auto method,
+ auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse = 4,
size_t kOutputSizeBytes = 128>
@@ -90,57 +90,50 @@
// A ChannelOutput implementation that stores the outgoing payloads and status.
template <size_t kOutputSize>
-class MessageOutput final : public ChannelOutput {
+class MessageOutput final : public FakeChannelOutput {
public:
using ResponseBuffer = std::array<std::byte, kOutputSize>;
MessageOutput(Vector<ByteSpan>& responses,
Vector<ResponseBuffer>& buffers,
- ByteSpan packet_buffer)
- : ChannelOutput("internal::test::raw::MessageOutput"),
+ ByteSpan packet_buffer,
+ bool server_streaming)
+ : FakeChannelOutput(packet_buffer, server_streaming),
responses_(responses),
- buffers_(buffers),
- packet_buffer_(packet_buffer) {
- clear();
- }
-
- Status last_status() const { return last_status_; }
- void set_last_status(Status status) { last_status_ = status; }
-
- size_t total_responses() const { return total_responses_; }
-
- bool stream_ended() const { return stream_ended_; }
-
- void clear() {
- responses_.clear();
- buffers_.clear();
- total_responses_ = 0;
- stream_ended_ = false;
- last_status_ = Status::Unknown();
- }
+ buffers_(buffers) {}
private:
- ByteSpan AcquireBuffer() override { return packet_buffer_; }
+ void AppendResponse(ConstByteSpan response) override {
+ // If we run out of space, the back message is always the most recent.
+ buffers_.emplace_back();
+ buffers_.back() = {};
+ std::memcpy(&buffers_.back(), response.data(), response.size());
+ responses_.emplace_back();
+ responses_.back() = {buffers_.back().data(), response.size()};
+ }
- Status SendAndReleaseBuffer(std::span<const std::byte> buffer) override;
+ void ClearResponses() override {
+ responses_.clear();
+ buffers_.clear();
+ }
Vector<ByteSpan>& responses_;
Vector<ResponseBuffer>& buffers_;
- ByteSpan packet_buffer_;
- size_t total_responses_;
- bool stream_ended_;
- Status last_status_;
};
// Collects everything needed to invoke a particular RPC.
template <typename Service,
+ auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kOutputSize>
struct InvocationContext {
template <typename... Args>
InvocationContext(Args&&... args)
- : output(responses, buffers, packet_buffer),
+ : output(responses,
+ buffers,
+ packet_buffer,
+ MethodTraits<decltype(kMethod)>::kServerStreaming),
channel(Channel::Create<123>(&output)),
server(std::span(&channel, 1)),
service(std::forward<Args>(args)...),
@@ -163,10 +156,14 @@
// Method invocation context for a unary RPC. Returns the status in call() and
// provides the response through the response() method.
-template <typename Service, auto method, uint32_t kMethodId, size_t kOutputSize>
+template <typename Service,
+ auto kMethod,
+ uint32_t kMethodId,
+ size_t kOutputSize>
class UnaryContext {
private:
- using Context = InvocationContext<Service, kMethodId, 1, kOutputSize>;
+ using Context =
+ InvocationContext<Service, kMethod, kMethodId, 1, kOutputSize>;
Context ctx_;
public:
@@ -183,7 +180,7 @@
ctx_.responses.emplace_back();
auto& response = ctx_.responses.back();
response = {ctx_.buffers.back().data(), ctx_.buffers.back().size()};
- auto sws = CallMethodImplFunction<method>(ctx_.call, request, response);
+ auto sws = CallMethodImplFunction<kMethod>(ctx_.call, request, response);
response = response.first(sws.size());
return sws;
}
@@ -197,14 +194,14 @@
// Method invocation context for a server streaming RPC.
template <typename Service,
- auto method,
+ auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kOutputSize>
class ServerStreamingContext {
private:
using Context =
- InvocationContext<Service, kMethodId, kMaxResponse, kOutputSize>;
+ InvocationContext<Service, kMethod, kMethodId, kMaxResponse, kOutputSize>;
Context ctx_;
public:
@@ -217,7 +214,7 @@
void call(ConstByteSpan request) {
ctx_.output.clear();
Responder server_writer(ctx_.call);
- return CallMethodImplFunction<method>(
+ return CallMethodImplFunction<kMethod>(
ctx_.call, request, static_cast<RawServerWriter&>(server_writer));
}
@@ -239,7 +236,7 @@
size_t total_responses() const { return ctx_.output.total_responses(); }
// True if the stream has terminated.
- bool done() const { return ctx_.output.stream_ended(); }
+ bool done() const { return ctx_.output.done(); }
// The status of the stream. Only valid if done() is true.
Status status() const {
@@ -251,74 +248,41 @@
// Alias to select the type of the context object to use based on which type of
// RPC it is for.
template <typename Service,
- auto method,
+ auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kOutputSize>
using Context = std::tuple_element_t<
- static_cast<size_t>(MethodTraits<decltype(method)>::kType),
- std::tuple<UnaryContext<Service, method, kMethodId, kOutputSize>,
+ static_cast<size_t>(MethodTraits<decltype(kMethod)>::kType),
+ std::tuple<UnaryContext<Service, kMethod, kMethodId, kOutputSize>,
ServerStreamingContext<Service,
- method,
+ kMethod,
kMethodId,
kMaxResponse,
kOutputSize>
// TODO(hepler): Support client and bidi streaming
>>;
-template <size_t kOutputSize>
-Status MessageOutput<kOutputSize>::SendAndReleaseBuffer(
- std::span<const std::byte> buffer) {
- PW_ASSERT(!stream_ended_);
- PW_ASSERT(buffer.data() == packet_buffer_.data());
-
- if (buffer.empty()) {
- return OkStatus();
- }
-
- Result<internal::Packet> result = internal::Packet::FromBuffer(buffer);
- PW_ASSERT(result.ok());
-
- last_status_ = result.value().status();
-
- switch (result.value().type()) {
- case internal::PacketType::RESPONSE: {
- // If we run out of space, the back message is always the most recent.
- buffers_.emplace_back();
- buffers_.back() = {};
- auto response = result.value().payload();
- std::memcpy(&buffers_.back(), response.data(), response.size());
- responses_.emplace_back();
- responses_.back() = {buffers_.back().data(), response.size()};
- total_responses_ += 1;
- break;
- }
- case internal::PacketType::SERVER_STREAM_END:
- stream_ended_ = true;
- break;
- default:
- pw_assert_HandleFailure(); // Unhandled PacketType
- }
- return OkStatus();
-}
-
} // namespace internal::test::raw
template <typename Service,
- auto method,
+ auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kOutputSizeBytes>
class RawTestMethodContext
: public internal::test::raw::
- Context<Service, method, kMethodId, kMaxResponse, kOutputSizeBytes> {
+ Context<Service, kMethod, kMethodId, kMaxResponse, kOutputSizeBytes> {
public:
// Forwards constructor arguments to the service class.
template <typename... ServiceArgs>
RawTestMethodContext(ServiceArgs&&... service_args)
- : internal::test::raw::
- Context<Service, method, kMethodId, kMaxResponse, kOutputSizeBytes>(
- std::forward<ServiceArgs>(service_args)...) {}
+ : internal::test::raw::Context<Service,
+ kMethod,
+ kMethodId,
+ kMaxResponse,
+ kOutputSizeBytes>(
+ std::forward<ServiceArgs>(service_args)...) {}
};
} // namespace pw::rpc
diff --git a/pw_rpc/raw/raw_method_test.cc b/pw_rpc/raw/raw_method_test.cc
index fd60781..a3df85d 100644
--- a/pw_rpc/raw/raw_method_test.cc
+++ b/pw_rpc/raw/raw_method_test.cc
@@ -149,7 +149,7 @@
const RawMethod& method = std::get<0>(FakeService::kMethods).raw_method();
ServerContextForTest<FakeService> context(method);
- method.Invoke(context.get(), context.packet(encoder.Encode().value()));
+ method.Invoke(context.get(), context.request(encoder.Encode().value()));
EXPECT_EQ(last_request.integer, 456);
EXPECT_EQ(last_request.status_code, 7u);
@@ -174,7 +174,7 @@
const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method();
ServerContextForTest<FakeService> context(method);
- method.Invoke(context.get(), context.packet(encoder.Encode().value()));
+ method.Invoke(context.get(), context.request(encoder.Encode().value()));
EXPECT_EQ(0u, context.output().packet_count());
EXPECT_EQ(777, last_request.integer);
@@ -187,7 +187,7 @@
const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method();
ServerContextForTest<FakeService> context(method);
- method.Invoke(context.get(), context.packet({}));
+ method.Invoke(context.get(), context.request({}));
auto buffer = last_writer.PayloadBuffer();
@@ -197,7 +197,7 @@
EXPECT_EQ(last_writer.Write(buffer.first(data.size())), OkStatus());
const internal::Packet& packet = context.output().sent_packet();
- EXPECT_EQ(packet.type(), internal::PacketType::RESPONSE);
+ EXPECT_EQ(packet.type(), internal::PacketType::SERVER_STREAM);
EXPECT_EQ(packet.channel_id(), context.channel_id());
EXPECT_EQ(packet.service_id(), context.service_id());
EXPECT_EQ(packet.method_id(), context.get().method().id());
@@ -209,13 +209,13 @@
const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method();
ServerContextForTest<FakeService> context(method);
- method.Invoke(context.get(), context.packet({}));
+ method.Invoke(context.get(), context.request({}));
constexpr auto data = bytes::Array<0x0d, 0x06, 0xf0, 0x0d>();
EXPECT_EQ(last_writer.Write(data), OkStatus());
const internal::Packet& packet = context.output().sent_packet();
- EXPECT_EQ(packet.type(), internal::PacketType::RESPONSE);
+ EXPECT_EQ(packet.type(), internal::PacketType::SERVER_STREAM);
EXPECT_EQ(packet.channel_id(), context.channel_id());
EXPECT_EQ(packet.service_id(), context.service_id());
EXPECT_EQ(packet.method_id(), context.get().method().id());
@@ -227,7 +227,7 @@
const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method();
ServerContextForTest<FakeService> context(method);
- method.Invoke(context.get(), context.packet({}));
+ method.Invoke(context.get(), context.request({}));
EXPECT_EQ(OkStatus(), last_writer.Finish());
constexpr auto data = bytes::Array<0x0d, 0x06, 0xf0, 0x0d>();
@@ -238,7 +238,7 @@
const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method();
ServerContextForTest<FakeService, 16> context(method);
- method.Invoke(context.get(), context.packet({}));
+ method.Invoke(context.get(), context.request({}));
constexpr auto data =
bytes::Array<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16>();
@@ -250,7 +250,7 @@
const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method();
ServerContextForTest<FakeService> context(method);
- method.Invoke(context.get(), context.packet({}));
+ method.Invoke(context.get(), context.request({}));
{
RawServerWriter writer = std::move(last_writer);
@@ -261,7 +261,7 @@
auto output = context.output();
EXPECT_EQ(output.packet_count(), 1u);
- EXPECT_EQ(output.sent_packet().type(), PacketType::SERVER_STREAM_END);
+ EXPECT_EQ(output.sent_packet().type(), PacketType::RESPONSE);
}
} // namespace
diff --git a/pw_rpc/raw/raw_method_union_test.cc b/pw_rpc/raw/raw_method_union_test.cc
index f048366..aa30c7a 100644
--- a/pw_rpc/raw/raw_method_union_test.cc
+++ b/pw_rpc/raw/raw_method_union_test.cc
@@ -107,7 +107,7 @@
const Method& method =
std::get<1>(FakeGeneratedServiceImpl::kMethods).method();
ServerContextForTest<FakeGeneratedServiceImpl> context(method);
- method.Invoke(context.get(), context.packet(encoder.Encode().value()));
+ method.Invoke(context.get(), context.request(encoder.Encode().value()));
EXPECT_EQ(last_request.integer, 456);
EXPECT_EQ(last_request.status_code, 7u);
@@ -133,7 +133,7 @@
std::get<2>(FakeGeneratedServiceImpl::kMethods).method();
ServerContextForTest<FakeGeneratedServiceImpl> context(method);
- method.Invoke(context.get(), context.packet(encoder.Encode().value()));
+ method.Invoke(context.get(), context.request(encoder.Encode().value()));
EXPECT_EQ(0u, context.output().packet_count());
EXPECT_EQ(777, last_request.integer);
diff --git a/pw_rpc/responder.cc b/pw_rpc/responder.cc
index 1c20e8c..be9abb7 100644
--- a/pw_rpc/responder.cc
+++ b/pw_rpc/responder.cc
@@ -20,6 +20,30 @@
#include "pw_rpc/internal/server.h"
namespace pw::rpc::internal {
+namespace {
+
+Packet ResponsePacket(const ServerCall& call,
+ uint32_t method_id,
+ Status status) {
+ return Packet(PacketType::RESPONSE,
+ call.channel().id(),
+ call.service().id(),
+ method_id,
+ {},
+ status);
+}
+
+Packet StreamPacket(const ServerCall& call,
+ uint32_t method_id,
+ std::span<const std::byte> payload) {
+ return Packet(PacketType::SERVER_STREAM,
+ call.channel().id(),
+ call.service().id(),
+ method_id,
+ payload);
+}
+
+} // namespace
Responder::Responder(ServerCall& call) : call_(call), state_(kOpen) {
call_.server().RegisterResponder(*this);
@@ -58,13 +82,8 @@
Close();
- // Send a control packet indicating that the stream (and RPC) has terminated.
- return call_.channel().Send(Packet(PacketType::SERVER_STREAM_END,
- call_.channel().id(),
- call_.service().id(),
- method().id(),
- {},
- status));
+ // Send a packet indicating that the RPC has terminated.
+ return call_.channel().Send(ResponsePacket(call_, method_id(), status));
}
std::span<std::byte> Responder::AcquirePayloadBuffer() {
@@ -75,12 +94,13 @@
response_ = call_.channel().AcquireBuffer();
}
- return response_.payload(ResponsePacket());
+ return response_.payload(StreamPacket(call_, method_id(), {}));
}
Status Responder::ReleasePayloadBuffer(std::span<const std::byte> payload) {
PW_DCHECK(open());
- return call_.channel().Send(response_, ResponsePacket(payload));
+ return call_.channel().Send(response_,
+ StreamPacket(call_, method_id(), payload));
}
Status Responder::ReleasePayloadBuffer() {
@@ -98,12 +118,4 @@
state_ = kClosed;
}
-Packet Responder::ResponsePacket(std::span<const std::byte> payload) const {
- return Packet(PacketType::RESPONSE,
- call_.channel().id(),
- call_.service().id(),
- method().id(),
- payload);
-}
-
} // namespace pw::rpc::internal
diff --git a/pw_rpc/responder_test.cc b/pw_rpc/responder_test.cc
index 28b4910..8ae515d 100644
--- a/pw_rpc/responder_test.cc
+++ b/pw_rpc/responder_test.cc
@@ -112,14 +112,14 @@
EXPECT_TRUE(writers.empty());
}
-TEST(ServerWriter, Finish_SendsCancellationPacket) {
+TEST(ServerWriter, Finish_SendsResponse) {
ServerContextForTest<TestService> context(TestService::method.method());
FakeServerWriter writer(context.get());
EXPECT_EQ(OkStatus(), writer.Finish());
const Packet& packet = context.output().sent_packet();
- EXPECT_EQ(packet.type(), PacketType::SERVER_STREAM_END);
+ EXPECT_EQ(packet.type(), PacketType::RESPONSE);
EXPECT_EQ(packet.channel_id(), context.channel_id());
EXPECT_EQ(packet.service_id(), context.service_id());
EXPECT_EQ(packet.method_id(), context.get().method().id());
@@ -166,7 +166,7 @@
ASSERT_EQ(OkStatus(), writer.Write(data));
byte encoded[64];
- auto result = context.packet(data).Encode(encoded);
+ auto result = context.server_stream(data).Encode(encoded);
ASSERT_EQ(OkStatus(), result.status());
EXPECT_EQ(result.value().size(), context.output().sent_data().size());
diff --git a/pw_rpc/server.cc b/pw_rpc/server.cc
index 1c9e813..c5e29b1 100644
--- a/pw_rpc/server.cc
+++ b/pw_rpc/server.cc
@@ -150,7 +150,7 @@
channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()));
PW_LOG_WARN("Received CANCEL packet for method that is not pending");
} else {
- writer->Finish(Status::Cancelled());
+ writer->Close();
}
}
diff --git a/pw_rpc/server_test.cc b/pw_rpc/server_test.cc
index 8894b5a..f76168a 100644
--- a/pw_rpc/server_test.cc
+++ b/pw_rpc/server_test.cc
@@ -247,18 +247,12 @@
EXPECT_FALSE(writer_.open());
}
-TEST_F(MethodPending, ProcessPacket_Cancel_SendsStreamEndPacket) {
+TEST_F(MethodPending, ProcessPacket_Cancel_SendsNoResponse) {
EXPECT_EQ(OkStatus(),
server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 1, 42, 100),
output_));
- const Packet& packet = output_.sent_packet();
- EXPECT_EQ(packet.type(), PacketType::SERVER_STREAM_END);
- EXPECT_EQ(packet.channel_id(), 1u);
- EXPECT_EQ(packet.service_id(), 42u);
- EXPECT_EQ(packet.method_id(), 100u);
- EXPECT_TRUE(packet.payload().empty());
- EXPECT_EQ(packet.status(), Status::Cancelled());
+ EXPECT_EQ(output_.packet_count(), 0u);
}
TEST_F(MethodPending,