pw_rpc: Update protocol for server streams Update the server-to-client RPC packet types so that the packet type unambiguously indicates whether it is the first or last packet. This is already the case for client-to-server RPC packet types. - Have RESPONSE always be the last packet in the stream. For RPCs without a server stream, it includes a payload. Remove SERVER_STREAM_END. - Introduce the SERVER_STREAM packet, to parallel the CLIENT_STREAM packet. - Update the server and client code and tests. Test that old-style streaming RPCs still work correctly. - Refactor the duplicate MessageOutput class into a FakeChannelOutput used by both Nanopb and raw RPCs. - In C++, don't encode default-valued payload and status fields. Change-Id: I218772dad6c2981dda5f032f298ea43ee5e08b4d Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/49822 Commit-Queue: Wyatt Hepler <hepler@google.com> Pigweed-Auto-Submit: Wyatt Hepler <hepler@google.com> Reviewed-by: Alexei Frolov <frolv@google.com>
diff --git a/pw_rpc/BUILD.bazel b/pw_rpc/BUILD.bazel index 82d9230..6cdf448 100644 --- a/pw_rpc/BUILD.bazel +++ b/pw_rpc/BUILD.bazel
@@ -44,14 +44,14 @@ pw_cc_library( name = "server", srcs = [ - "responder.cc", - "public/pw_rpc/internal/responder.h", "public/pw_rpc/internal/call.h", "public/pw_rpc/internal/hash.h", "public/pw_rpc/internal/method.h", "public/pw_rpc/internal/method_lookup.h", "public/pw_rpc/internal/method_union.h", + "public/pw_rpc/internal/responder.h", "public/pw_rpc/internal/server.h", + "responder.cc", "server.cc", "service.cc", ], @@ -115,8 +115,10 @@ pw_cc_library( name = "internal_test_utils", + srcs = ["fake_channel_output.cc"], hdrs = [ "public/pw_rpc/internal/test_method.h", + "pw_rpc_private/fake_channel_output.h", "pw_rpc_private/internal_test_utils.h", "pw_rpc_private/method_impl_tester.h", ], @@ -128,6 +130,8 @@ deps = [ ":client", ":server", + "//pw_assert", + "//pw_bytes", "//pw_span", ], )
diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn index d5f4fd5..25953c4 100644 --- a/pw_rpc/BUILD.gn +++ b/pw_rpc/BUILD.gn
@@ -127,13 +127,20 @@ pw_source_set("test_utils") { public = [ "public/pw_rpc/internal/test_method.h", + "pw_rpc_private/fake_channel_output.h", "pw_rpc_private/internal_test_utils.h", "pw_rpc_private/method_impl_tester.h", ] - public_configs = [ ":private_includes" ] + sources = [ "fake_channel_output.cc" ] + public_configs = [ + ":public_include_path", + ":private_includes", + ] public_deps = [ ":client", ":server", + dir_pw_assert, + dir_pw_bytes, ] visibility = [ "./*" ] }
diff --git a/pw_rpc/CMakeLists.txt b/pw_rpc/CMakeLists.txt index 4a34370..c1b9658 100644 --- a/pw_rpc/CMakeLists.txt +++ b/pw_rpc/CMakeLists.txt
@@ -72,8 +72,16 @@ pw_sync.mutex ) -add_library(pw_rpc.test_utils INTERFACE) -target_include_directories(pw_rpc.test_utils INTERFACE .) +pw_add_module_library(pw_rpc.test_utils + SOURCES + fake_channel_output.cc + PUBLIC_DEPS + pw_assert + pw_bytes + pw_rpc.client + pw_rpc.server +) +target_include_directories(pw_rpc.test_utils PUBLIC .) pw_proto_library(pw_rpc.protos SOURCES
diff --git a/pw_rpc/channel_test.cc b/pw_rpc/channel_test.cc index 7687a2d..9dc0c20 100644 --- a/pw_rpc/channel_test.cc +++ b/pw_rpc/channel_test.cc
@@ -37,10 +37,11 @@ EXPECT_EQ(nullptr, NameTester(nullptr).name()); } -constexpr Packet kTestPacket(PacketType::RESPONSE, 1, 42, 100); +constexpr Packet kTestPacket( + PacketType::RESPONSE, 1, 42, 100, {}, Status::NotFound()); const size_t kReservedSize = 2 /* type */ + 2 /* channel */ + 5 /* service */ + 5 /* method */ + 2 /* payload key */ + - 2 /* status */; + 2 /* status (if not OK) */; enum class ChannelId { kOne = 1, @@ -67,7 +68,7 @@ } TEST(Channel, OutputBuffer_TooSmall) { - TestOutput<kReservedSize - 1> output; + TestOutput<kReservedSize - 2 /* payload key & size */ - 1> output; internal::Channel channel(100, &output); Channel::OutputBuffer output_buffer = channel.AcquireBuffer();
diff --git a/pw_rpc/client.cc b/pw_rpc/client.cc index b5f3f5d..b49a9a3 100644 --- a/pw_rpc/client.cc +++ b/pw_rpc/client.cc
@@ -70,10 +70,10 @@ case PacketType::RESPONSE: case PacketType::SERVER_ERROR: call->HandleResponse(packet); + call->Unregister(); break; - case PacketType::SERVER_STREAM_END: + case PacketType::SERVER_STREAM: call->HandleResponse(packet); - RemoveCall(*call); break; default: return Status::Unimplemented();
diff --git a/pw_rpc/docs.rst b/pw_rpc/docs.rst index 6ff8e43..1383159 100644 --- a/pw_rpc/docs.rst +++ b/pw_rpc/docs.rst
@@ -374,6 +374,15 @@ | | - payload | | | | +-------------------+-------------------------------------+ +| CLIENT_STREAM_END | Client stream is complete | +| | | +| | .. code-block:: text | +| | | +| | - channel_id | +| | - service_id | +| | - method_id | +| | | ++-------------------+-------------------------------------+ | CLIENT_ERROR | Received unexpected packet | | | | | | .. code-block:: text | @@ -393,15 +402,6 @@ | | - method_id | | | | +-------------------+-------------------------------------+ -| CLIENT_STREAM_END | Client stream is complete | -| | | -| | .. code-block:: text | -| | | -| | - channel_id | -| | - service_id | -| | - method_id | -| | | -+-------------------+-------------------------------------+ **Errors** @@ -421,7 +421,19 @@ +-------------------+-------------------------------------+ | packet type | description | +===================+=====================================+ -| RESPONSE | RPC response | +| RESPONSE | The RPC is complete | +| | | +| | .. code-block:: text | +| | | +| | - channel_id | +| | - service_id | +| | - method_id | +| | - status | +| | - payload | +| | (unary & client streaming only) | +| | | ++-------------------+-------------------------------------+ +| SERVER_STREAM | Message in a server stream | | | | | | .. code-block:: text | | | | @@ -429,18 +441,6 @@ | | - service_id | | | - method_id | | | - payload | -| | - status | -| | (unary & client streaming only) | -| | | -+-------------------+-------------------------------------+ -| SERVER_STREAM_END | Server stream and RPC finished | -| | | -| | .. code-block:: text | -| | | -| | - channel_id | -| | - service_id | -| | - method_id | -| | - status | | | | +-------------------+-------------------------------------+ | SERVER_ERROR | Received unexpected packet | @@ -474,6 +474,21 @@ the protocol for calling service methods of each type: unary, server streaming, client streaming, and bidirectional streaming. +The basic flow for all RPC invocations is as follows: + + * Client sends a ``REQUEST`` packet. Includes a payload for unary & server + streaming RPCs. + * For client and bidirectional streaming RPCs, the client may send any number + of ``CLIENT_STREAM`` packets with payloads. + * For server and bidirectional streaming RPCs, the server may send any number + of ``SERVER_STREAM`` packets. + * The server sends a ``RESPONSE`` packet. Includes a payload for unary & + client streaming RPCs. The RPC is complete. + +The client may cancel an ongoing RPC at any time by sending a ``CANCEL`` packet. +The server may finish an ongoing RPC at any time by sending the ``RESPONSE`` +packet. + Unary RPC ^^^^^^^^^ In a unary RPC, the client sends a single request and the server sends a single @@ -522,7 +537,7 @@ Server streaming RPC ^^^^^^^^^^^^^^^^^^^^ In a server streaming RPC, the client sends a single request and the server -sends any number of responses followed by a ``SERVER_STREAM_END`` packet. +sends any number of ``SERVER_STREAM`` packets followed by a ``RESPONSE`` packet. .. seqdiag:: :scale: 110 @@ -538,12 +553,12 @@ client <-- server [ noactivate, label = "messages (zero or more)", - rightnote = "PacketType.RESPONSE\nchannel ID\nservice ID\nmethod ID\npayload" + rightnote = "PacketType.SERVER_STREAM\nchannel ID\nservice ID\nmethod ID\npayload" ]; client <- server [ label = "done", - rightnote = "PacketType.SERVER_STREAM_END\nchannel ID\nservice ID\nmethod ID\nstatus" + rightnote = "PacketType.RESPONSE\nchannel ID\nservice ID\nmethod ID\nstatus" ]; } @@ -564,7 +579,7 @@ client <-- server [ noactivate, label = "messages (zero or more)", - rightnote = "PacketType.RESPONSE\nchannel ID\nservice ID\nmethod ID\npayload" + rightnote = "PacketType.SERVER_STREAM\nchannel ID\nservice ID\nmethod ID\npayload" ]; client -> server [ @@ -579,7 +594,7 @@ In a client streaming RPC, the client starts the RPC by sending a ``REQUEST`` packet with no payload. It then sends any number of messages in ``CLIENT_STREAM`` packets, followed by a ``CLIENT_STREAM_END``. The server sends -a single response to finish the RPC. +a single ``RESPONSE`` to finish the RPC. .. seqdiag:: :scale: 110 @@ -643,8 +658,8 @@ In a bidirectional streaming RPC, the client sends any number of requests and the server sends any number of responses. The client invokes the RPC by sending a ``REQUEST`` with no payload. It sends a ``CLIENT_STREAM_END`` packet when it -has finished sending requests. The server sends a ``SERVER_STREAM_END`` packet -to finish the RPC. +has finished sending requests. The server sends a ``RESPONSE`` packet to finish +the RPC. .. seqdiag:: :scale: 110 @@ -668,7 +683,7 @@ client <-- server [ noactivate, label = "messages (zero or more)", - rightnote = "PacketType.RESPONSE\nchannel ID\nservice ID\nmethod ID\npayload" + rightnote = "PacketType.SERVER_STREAM\nchannel ID\nservice ID\nmethod ID\npayload" ]; client -> server [ @@ -679,13 +694,13 @@ client <- server [ label = "done", - rightnote = "PacketType.SERVER_STREAM_END\nchannel ID\nservice ID\nmethod ID\nstatus" + rightnote = "PacketType.RESPONSE\nchannel ID\nservice ID\nmethod ID\nstatus" ]; } -The server may finish the RPC at any time by sending the ``SERVER_STREAM_END`` -packet, even if it has not received the ``CLIENT_STREAM_END`` packet. The client -may terminate the RPC at any time by sending a ``CANCEL`` packet. +The server may finish the RPC at any time by sending the ``RESPONSE`` packet, +even if it has not received the ``CLIENT_STREAM_END`` packet. The client may +terminate the RPC at any time by sending a ``CANCEL`` packet. .. seqdiag:: :scale: 110 @@ -707,13 +722,14 @@ client <-- server [ noactivate, label = "messages (zero or more)", - rightnote = "PacketType.RESPONSE\nchannel ID\nservice ID\nmethod ID\npayload" + rightnote = "PacketType.SERVER_STREAM\nchannel ID\nservice ID\nmethod ID\npayload" ]; client -> server [ noactivate, label = "cancel", - leftnote = "PacketType.CANCEL\nchannel ID\nservice ID\nmethod ID" ]; + leftnote = "PacketType.CANCEL\nchannel ID\nservice ID\nmethod ID" + ]; } RPC server
diff --git a/pw_rpc/fake_channel_output.cc b/pw_rpc/fake_channel_output.cc new file mode 100644 index 0000000..e68a879 --- /dev/null +++ b/pw_rpc/fake_channel_output.cc
@@ -0,0 +1,62 @@ +// Copyright 2021 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "pw_rpc_private/fake_channel_output.h" + +#include "pw_assert/check.h" +#include "pw_result/result.h" +#include "pw_rpc/internal/packet.h" + +namespace pw::rpc::internal::test { + +void FakeChannelOutput::clear() { + ClearResponses(); + total_responses_ = 0; + last_status_ = Status::Unknown(); + done_ = false; +} + +Status FakeChannelOutput::SendAndReleaseBuffer( + std::span<const std::byte> buffer) { + PW_CHECK(!done_); + PW_CHECK_PTR_EQ(buffer.data(), packet_buffer_.data()); + + if (buffer.empty()) { + return OkStatus(); + } + + Result<Packet> result = Packet::FromBuffer(buffer); + PW_CHECK_OK(result.status()); + + last_status_ = result.value().status(); + + switch (result.value().type()) { + case PacketType::RESPONSE: + // Server streaming RPCs don't have a payload in their response packet. + if (!server_streaming_) { + ProcessResponse(result.value().payload()); + } + done_ = true; + break; + case PacketType::SERVER_STREAM: + ProcessResponse(result.value().payload()); + break; + default: + PW_CRASH("Unhandled PacketType %d", + static_cast<int>(result.value().type())); + } + return OkStatus(); +} + +} // namespace pw::rpc::internal::test
diff --git a/pw_rpc/internal/packet.proto b/pw_rpc/internal/packet.proto index 4438247..ebcc93e 100644 --- a/pw_rpc/internal/packet.proto +++ b/pw_rpc/internal/packet.proto
@@ -23,10 +23,11 @@ // Client-to-server packets - // A request from a client for a service method. + // The client invokes an RPC. Always the first packet. REQUEST = 0; - // A message in a client stream. + // A message in a client stream. Always sent after a REQUEST and before a + // CLIENT_STREAM_END. CLIENT_STREAM = 2; // The client received a packet for an RPC it did not request. @@ -40,14 +41,18 @@ // Server-to-client packets - // A response from a server for a service method. + // The RPC has finished. RESPONSE = 1; - // A server streaming or bidirectional RPC has completed. - SERVER_STREAM_END = 3; + // Deprecated, do not use. Formerly was used as the last packet in a server + // stream. + DEPRECATED_SERVER_STREAM_END = 3; // The server was unable to process a request. SERVER_ERROR = 5; + + // A message in a server stream. + SERVER_STREAM = 7; } message RpcPacket {
diff --git a/pw_rpc/nanopb/BUILD.gn b/pw_rpc/nanopb/BUILD.gn index 72cc611..17e11b9 100644 --- a/pw_rpc/nanopb/BUILD.gn +++ b/pw_rpc/nanopb/BUILD.gn
@@ -72,6 +72,7 @@ public = [ "public/pw_rpc/nanopb_test_method_context.h" ] public_deps = [ ":method", + "..:test_utils", dir_pw_assert, dir_pw_containers, ]
diff --git a/pw_rpc/nanopb/codegen_test.cc b/pw_rpc/nanopb/codegen_test.cc index 12ea706..36f0a4c 100644 --- a/pw_rpc/nanopb/codegen_test.cc +++ b/pw_rpc/nanopb/codegen_test.cc
@@ -215,12 +215,11 @@ PW_ENCODE_PB( pw_rpc_test_TestStreamResponse, response, .chunk = {}, .number = 11u); - context.SendResponse(OkStatus(), response); + context.SendServerStream(response); EXPECT_TRUE(result.active); EXPECT_EQ(result.response_value, 11); - context.SendPacket(internal::PacketType::SERVER_STREAM_END, - Status::NotFound()); + context.SendResponse(Status::NotFound()); EXPECT_FALSE(result.active); EXPECT_EQ(result.stream_status, Status::NotFound()); }
diff --git a/pw_rpc/nanopb/nanopb_client_call_test.cc b/pw_rpc/nanopb/nanopb_client_call_test.cc index ddac23b..c9d87d2 100644 --- a/pw_rpc/nanopb/nanopb_client_call_test.cc +++ b/pw_rpc/nanopb/nanopb_client_call_test.cc
@@ -248,19 +248,19 @@ }); PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r1, .chunk = {}, .number = 11u); - context.SendResponse(OkStatus(), r1); + context.SendServerStream(r1); EXPECT_TRUE(active_); EXPECT_EQ(responses_received_, 1); EXPECT_EQ(last_response_number_, 11); PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r2, .chunk = {}, .number = 22u); - context.SendResponse(OkStatus(), r2); + context.SendServerStream(r2); EXPECT_TRUE(active_); EXPECT_EQ(responses_received_, 2); EXPECT_EQ(last_response_number_, 22); PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r3, .chunk = {}, .number = 33u); - context.SendResponse(OkStatus(), r3); + context.SendServerStream(r3); EXPECT_TRUE(active_); EXPECT_EQ(responses_received_, 3); EXPECT_EQ(last_response_number_, 33); @@ -283,19 +283,18 @@ }); PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r1, .chunk = {}, .number = 11u); - context.SendResponse(OkStatus(), r1); + context.SendServerStream(r1); EXPECT_TRUE(active_); PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r2, .chunk = {}, .number = 22u); - context.SendResponse(OkStatus(), r2); + context.SendServerStream(r2); EXPECT_TRUE(active_); // Close the stream. - context.SendPacket(internal::PacketType::SERVER_STREAM_END, - Status::NotFound()); + context.SendResponse(Status::NotFound()); PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r3, .chunk = {}, .number = 33u); - context.SendResponse(OkStatus(), r3); + context.SendServerStream(r3); EXPECT_FALSE(active_); EXPECT_EQ(responses_received_, 2); @@ -316,19 +315,19 @@ [this](Status error) { rpc_error_ = error; }); PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r1, .chunk = {}, .number = 11u); - context.SendResponse(OkStatus(), r1); + context.SendServerStream(r1); EXPECT_TRUE(active_); EXPECT_EQ(responses_received_, 1); EXPECT_EQ(last_response_number_, 11); constexpr std::byte bad_payload[]{ std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}}; - context.SendResponse(OkStatus(), bad_payload); + context.SendServerStream(bad_payload); EXPECT_EQ(responses_received_, 1); EXPECT_EQ(rpc_error_, Status::DataLoss()); PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r2, .chunk = {}, .number = 22u); - context.SendResponse(OkStatus(), r2); + context.SendServerStream(r2); EXPECT_TRUE(active_); EXPECT_EQ(responses_received_, 2); EXPECT_EQ(last_response_number_, 22);
diff --git a/pw_rpc/nanopb/nanopb_method_test.cc b/pw_rpc/nanopb/nanopb_method_test.cc index d38c461..d212cba 100644 --- a/pw_rpc/nanopb/nanopb_method_test.cc +++ b/pw_rpc/nanopb/nanopb_method_test.cc
@@ -168,7 +168,7 @@ const NanopbMethod& method = std::get<1>(FakeService::kMethods).nanopb_method(); ServerContextForTest<FakeService> context(method); - method.Invoke(context.get(), context.packet(request)); + method.Invoke(context.get(), context.request(request)); const Packet& response = context.output().sent_packet(); EXPECT_EQ(response.status(), Status::Unauthenticated()); @@ -190,7 +190,7 @@ const NanopbMethod& method = std::get<0>(FakeService::kMethods).nanopb_method(); ServerContextForTest<FakeService> context(method); - method.Invoke(context.get(), context.packet(bad_payload)); + method.Invoke(context.get(), context.request(bad_payload)); const Packet& packet = context.output().sent_packet(); EXPECT_EQ(PacketType::SERVER_ERROR, packet.type()); @@ -208,10 +208,11 @@ std::get<1>(FakeService::kMethods).nanopb_method(); // Output buffer is too small for the response, but can fit an error packet. ServerContextForTest<FakeService, 22> context(method); - ASSERT_LT(context.output().buffer_size(), - context.packet(request).MinEncodedSizeBytes() + request.size() + 1); + ASSERT_LT( + context.output().buffer_size(), + context.request(request).MinEncodedSizeBytes() + request.size() + 1); - method.Invoke(context.get(), context.packet(request)); + method.Invoke(context.get(), context.request(request)); const Packet& packet = context.output().sent_packet(); EXPECT_EQ(PacketType::SERVER_ERROR, packet.type()); @@ -230,7 +231,7 @@ std::get<2>(FakeService::kMethods).nanopb_method(); ServerContextForTest<FakeService> context(method); - method.Invoke(context.get(), context.packet(request)); + method.Invoke(context.get(), context.request(request)); EXPECT_EQ(0u, context.output().packet_count()); EXPECT_EQ(555, last_request.integer); @@ -241,13 +242,13 @@ std::get<2>(FakeService::kMethods).nanopb_method(); ServerContextForTest<FakeService> context(method); - method.Invoke(context.get(), context.packet({})); + method.Invoke(context.get(), context.request({})); EXPECT_EQ(OkStatus(), last_writer.Write({.value = 100})); PW_ENCODE_PB(pw_rpc_test_TestResponse, payload, .value = 100); std::array<byte, 128> encoded_response = {}; - auto encoded = context.packet(payload).Encode(encoded_response); + auto encoded = context.server_stream(payload).Encode(encoded_response); ASSERT_EQ(OkStatus(), encoded.status()); ASSERT_EQ(encoded.value().size(), context.output().sent_data().size()); @@ -262,7 +263,7 @@ std::get<2>(FakeService::kMethods).nanopb_method(); ServerContextForTest<FakeService> context(method); - method.Invoke(context.get(), context.packet({})); + method.Invoke(context.get(), context.request({})); EXPECT_EQ(OkStatus(), last_writer.Finish()); EXPECT_TRUE(last_writer.Write({.value = 100}).IsFailedPrecondition()); @@ -273,7 +274,7 @@ std::get<2>(FakeService::kMethods).nanopb_method(); ServerContextForTest<FakeService> context(method); - method.Invoke(context.get(), context.packet({})); + method.Invoke(context.get(), context.request({})); ServerWriter<pw_rpc_test_TestResponse> new_writer = std::move(last_writer); EXPECT_EQ(OkStatus(), new_writer.Write({.value = 100})); @@ -289,20 +290,20 @@ const NanopbMethod& method = std::get<2>(FakeService::kMethods).nanopb_method(); - constexpr size_t kNoPayloadPacketSize = 2 /* type */ + 2 /* channel */ + - 5 /* service */ + 5 /* method */ + - 2 /* payload */ + 2 /* status */; + constexpr size_t kNoPayloadPacketSize = + 2 /* type */ + 2 /* channel */ + 5 /* service */ + 5 /* method */ + + 0 /* payload (when empty) */ + 0 /* status (when OK)*/; // Make the buffer barely fit a packet with no payload. ServerContextForTest<FakeService, kNoPayloadPacketSize> context(method); // Verify that the encoded size of a packet with an empty payload is correct. std::array<byte, 128> encoded_response = {}; - auto encoded = context.packet({}).Encode(encoded_response); + auto encoded = context.request({}).Encode(encoded_response); ASSERT_EQ(OkStatus(), encoded.status()); ASSERT_EQ(kNoPayloadPacketSize, encoded.value().size()); - method.Invoke(context.get(), context.packet({})); + method.Invoke(context.get(), context.request({})); EXPECT_EQ(OkStatus(), last_writer.Write({})); // Barely fits EXPECT_EQ(Status::Internal(), last_writer.Write({.value = 1})); // Too big
diff --git a/pw_rpc/nanopb/nanopb_method_union_test.cc b/pw_rpc/nanopb/nanopb_method_union_test.cc index d1bfd7f..cb3ef71 100644 --- a/pw_rpc/nanopb/nanopb_method_union_test.cc +++ b/pw_rpc/nanopb/nanopb_method_union_test.cc
@@ -92,7 +92,7 @@ const Method& method = std::get<0>(FakeGeneratedServiceImpl::kMethods).method(); ServerContextForTest<FakeGeneratedServiceImpl> context(method); - method.Invoke(context.get(), context.packet({})); + method.Invoke(context.get(), context.request({})); const Packet& response = context.output().sent_packet(); EXPECT_EQ(response.status(), Status::Unknown()); @@ -106,12 +106,11 @@ std::get<1>(FakeGeneratedServiceImpl::kMethods).method(); ServerContextForTest<FakeGeneratedServiceImpl> context(method); - method.Invoke(context.get(), context.packet(request)); + method.Invoke(context.get(), context.request(request)); EXPECT_TRUE(last_raw_writer.open()); EXPECT_EQ(OkStatus(), last_raw_writer.Finish()); - EXPECT_EQ(context.output().sent_packet().type(), - PacketType::SERVER_STREAM_END); + EXPECT_EQ(context.output().sent_packet().type(), PacketType::RESPONSE); } TEST(NanopbMethodUnion, Nanopb_CallsUnaryMethod) { @@ -121,7 +120,7 @@ const Method& method = std::get<2>(FakeGeneratedServiceImpl::kMethods).method(); ServerContextForTest<FakeGeneratedServiceImpl> context(method); - method.Invoke(context.get(), context.packet(request)); + method.Invoke(context.get(), context.request(request)); const Packet& response = context.output().sent_packet(); EXPECT_EQ(response.status(), Status::Unauthenticated()); @@ -146,14 +145,13 @@ std::get<3>(FakeGeneratedServiceImpl::kMethods).method(); ServerContextForTest<FakeGeneratedServiceImpl> context(method); - method.Invoke(context.get(), context.packet(request)); + method.Invoke(context.get(), context.request(request)); EXPECT_EQ(555, last_request.integer); EXPECT_TRUE(last_writer.open()); EXPECT_EQ(OkStatus(), last_writer.Finish()); - EXPECT_EQ(context.output().sent_packet().type(), - PacketType::SERVER_STREAM_END); + EXPECT_EQ(context.output().sent_packet().type(), PacketType::RESPONSE); } } // namespace
diff --git a/pw_rpc/nanopb/public/pw_rpc/nanopb_client_call.h b/pw_rpc/nanopb/public/pw_rpc/nanopb_client_call.h index 1383dc6..e77db5f 100644 --- a/pw_rpc/nanopb/public/pw_rpc/nanopb_client_call.h +++ b/pw_rpc/nanopb/public/pw_rpc/nanopb_client_call.h
@@ -200,8 +200,6 @@ } else { callbacks_.InvokeRpcError(Status::DataLoss()); } - - Unregister(); } void InvokeServerStreamingCallback(const internal::Packet& packet) { @@ -210,12 +208,15 @@ return; } - if (packet.type() == internal::PacketType::SERVER_STREAM_END && - callbacks_.stream_end) { - callbacks_.stream_end(packet.status()); + if (packet.type() == internal::PacketType::RESPONSE) { + // The server sent the last packet; call the stream end callback. + if (callbacks_.stream_end) { + callbacks_.stream_end(packet.status()); + } return; } + // Handle internal::PacketType::SERVER_STREAM packets. ResponseBuffer response_struct{}; if (callbacks_.stream_response &&
diff --git a/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h b/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h index b65ea22..5d836ec 100644 --- a/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h +++ b/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h
@@ -17,6 +17,7 @@ #include <utility> #include "pw_assert/assert.h" +#include "pw_bytes/span.h" #include "pw_containers/vector.h" #include "pw_preprocessor/arguments.h" #include "pw_rpc/channel.h" @@ -25,6 +26,7 @@ #include "pw_rpc/internal/nanopb_method.h" #include "pw_rpc/internal/packet.h" #include "pw_rpc/internal/server.h" +#include "pw_rpc_private/fake_channel_output.h" namespace pw::rpc { @@ -77,7 +79,7 @@ ::pw::rpc::internal::Hash(#method), \ ##__VA_ARGS__> template <typename Service, - auto method, + auto kMethod, uint32_t kMethodId, size_t kMaxResponse = 4, size_t kOutputSizeBytes = 128> @@ -88,55 +90,46 @@ // A ChannelOutput implementation that stores the outgoing payloads and status. template <typename Response> -class MessageOutput final : public ChannelOutput { +class MessageOutput final : public FakeChannelOutput { public: - MessageOutput(const internal::NanopbMethod& method, + MessageOutput(const internal::NanopbMethod& kMethod, Vector<Response>& responses, - std::span<std::byte> buffer) - : ChannelOutput("internal::test::nanopb::MessageOutput"), - method_(method), - responses_(responses), - buffer_(buffer) { - clear(); - } - - Status last_status() const { return last_status_; } - void set_last_status(Status status) { last_status_ = status; } - - size_t total_responses() const { return total_responses_; } - - bool stream_ended() const { return stream_ended_; } - - void clear(); + ByteSpan packet_buffer, + bool server_streaming) + : FakeChannelOutput(packet_buffer, server_streaming), + method_(kMethod), + responses_(responses) {} private: - std::span<std::byte> AcquireBuffer() override { return buffer_; } + void AppendResponse(ConstByteSpan response) override { + // If we run out of space, the back message is always the most recent. + responses_.emplace_back(); + responses_.back() = {}; + PW_ASSERT(method_.DecodeResponse(response, &responses_.back())); + } - Status SendAndReleaseBuffer(std::span<const std::byte> buffer) override; + void ClearResponses() override { responses_.clear(); } const internal::NanopbMethod& method_; Vector<Response>& responses_; - std::span<std::byte> buffer_; - size_t total_responses_; - bool stream_ended_; - Status last_status_; }; // Collects everything needed to invoke a particular RPC. template <typename Service, - auto method, + auto kMethod, uint32_t kMethodId, size_t kMaxResponse, size_t kOutputSize> struct InvocationContext { - using Request = internal::Request<method>; - using Response = internal::Response<method>; + using Request = internal::Request<kMethod>; + using Response = internal::Response<kMethod>; template <typename... Args> InvocationContext(Args&&... args) : output(MethodLookup::GetNanopbMethod<Service, kMethodId>(), responses, - buffer), + buffer, + MethodTraits<decltype(kMethod)>::kServerStreaming), channel(Channel::Create<123>(&output)), server(std::span(&channel, 1)), service(std::forward<Args>(args)...), @@ -158,10 +151,13 @@ // Method invocation context for a unary RPC. Returns the status in call() and // provides the response through the response() method. -template <typename Service, auto method, uint32_t kMethodId, size_t kOutputSize> +template <typename Service, + auto kMethod, + uint32_t kMethodId, + size_t kOutputSize> class UnaryContext { private: - InvocationContext<Service, method, kMethodId, 1, kOutputSize> ctx_; + InvocationContext<Service, kMethod, kMethodId, 1, kOutputSize> ctx_; public: using Request = typename decltype(ctx_)::Request; @@ -177,7 +173,7 @@ ctx_.output.clear(); ctx_.responses.emplace_back(); ctx_.responses.back() = {}; - return CallMethodImplFunction<method>( + return CallMethodImplFunction<kMethod>( ctx_.call, request, ctx_.responses.back()); } @@ -190,13 +186,14 @@ // Method invocation context for a server streaming RPC. template <typename Service, - auto method, + auto kMethod, uint32_t kMethodId, size_t kMaxResponse, size_t kOutputSize> class ServerStreamingContext { private: - InvocationContext<Service, method, kMethodId, kMaxResponse, kOutputSize> ctx_; + InvocationContext<Service, kMethod, kMethodId, kMaxResponse, kOutputSize> + ctx_; public: using Request = typename decltype(ctx_)::Request; @@ -211,7 +208,7 @@ void call(const Request& request) { ctx_.output.clear(); internal::Responder server_writer(ctx_.call); - return CallMethodImplFunction<method>( + return CallMethodImplFunction<kMethod>( ctx_.call, request, static_cast<ServerWriter<Response>&>(server_writer)); @@ -235,7 +232,7 @@ size_t total_responses() const { return ctx_.output.total_responses(); } // True if the stream has terminated. - bool done() const { return ctx_.output.stream_ended(); } + bool done() const { return ctx_.output.done(); } // The status of the stream. Only valid if done() is true. Status status() const { @@ -247,79 +244,41 @@ // Alias to select the type of the context object to use based on which type of // RPC it is for. template <typename Service, - auto method, + auto kMethod, uint32_t kMethodId, size_t kMaxResponse, size_t kOutputSize> using Context = std::tuple_element_t< - static_cast<size_t>(internal::MethodTraits<decltype(method)>::kType), - std::tuple<UnaryContext<Service, method, kMethodId, kOutputSize>, + static_cast<size_t>(internal::MethodTraits<decltype(kMethod)>::kType), + std::tuple<UnaryContext<Service, kMethod, kMethodId, kOutputSize>, ServerStreamingContext<Service, - method, + kMethod, kMethodId, kMaxResponse, kOutputSize> // TODO(hepler): Support client and bidi streaming >>; -template <typename Response> -void MessageOutput<Response>::clear() { - responses_.clear(); - total_responses_ = 0; - stream_ended_ = false; - last_status_ = Status::Unknown(); -} - -template <typename Response> -Status MessageOutput<Response>::SendAndReleaseBuffer( - std::span<const std::byte> buffer) { - PW_ASSERT(!stream_ended_); - PW_ASSERT(buffer.data() == buffer_.data()); - - if (buffer.empty()) { - return OkStatus(); - } - - Result<internal::Packet> result = internal::Packet::FromBuffer(buffer); - PW_ASSERT(result.ok()); - - last_status_ = result.value().status(); - - switch (result.value().type()) { - case internal::PacketType::RESPONSE: - // If we run out of space, the back message is always the most recent. - responses_.emplace_back(); - responses_.back() = {}; - PW_ASSERT( - method_.DecodeResponse(result.value().payload(), &responses_.back())); - total_responses_ += 1; - break; - case internal::PacketType::SERVER_STREAM_END: - stream_ended_ = true; - break; - default: - pw_assert_HandleFailure(); // Unhandled PacketType - } - return OkStatus(); -} - } // namespace internal::test::nanopb template <typename Service, - auto method, + auto kMethod, uint32_t kMethodId, size_t kMaxResponse, size_t kOutputSizeBytes> class NanopbTestMethodContext : public internal::test::nanopb:: - Context<Service, method, kMethodId, kMaxResponse, kOutputSizeBytes> { + Context<Service, kMethod, kMethodId, kMaxResponse, kOutputSizeBytes> { public: // Forwards constructor arguments to the service class. template <typename... ServiceArgs> NanopbTestMethodContext(ServiceArgs&&... service_args) - : internal::test::nanopb:: - Context<Service, method, kMethodId, kMaxResponse, kOutputSizeBytes>( - std::forward<ServiceArgs>(service_args)...) {} + : internal::test::nanopb::Context<Service, + kMethod, + kMethodId, + kMaxResponse, + kOutputSizeBytes>( + std::forward<ServiceArgs>(service_args)...) {} }; } // namespace pw::rpc
diff --git a/pw_rpc/packet.cc b/pw_rpc/packet.cc index 3739370..275acee 100644 --- a/pw_rpc/packet.cc +++ b/pw_rpc/packet.cc
@@ -74,13 +74,20 @@ RpcPacket::Encoder rpc_packet(&encoder); // The payload is encoded first, as it may share the encode buffer. - rpc_packet.WritePayload(payload_); + if (!payload_.empty()) { + rpc_packet.WritePayload(payload_); + } rpc_packet.WriteType(type_); rpc_packet.WriteChannelId(channel_id_); rpc_packet.WriteServiceId(service_id_); rpc_packet.WriteMethodId(method_id_); - rpc_packet.WriteStatus(status_.code()); + + // Status code 0 is OK. In protobufs, 0 is the default int value, so skip + // encoding it to save two bytes in the output. + if (status_.code() != 0) { + rpc_packet.WriteStatus(status_.code()); + } return encoder.Encode(); } @@ -96,7 +103,7 @@ // Packet type always takes two bytes to encode (varint key + varint enum). reserved_size += 2; - // Status field always takes two bytes to encode (varint key + varint status). + // Status field takes up to two bytes to encode (varint key + varint status). reserved_size += 2; // Payload field takes at least two bytes to encode (varint key + length).
diff --git a/pw_rpc/packet_test.cc b/pw_rpc/packet_test.cc index 534fb87..88e7c02 100644 --- a/pw_rpc/packet_test.cc +++ b/pw_rpc/packet_test.cc
@@ -55,11 +55,12 @@ 100, 0, 0, - 0, + 0 - // Status - MakeKey(6, protobuf::WireType::kVarint), - 0x00>(); + // Status (not encoded if it is zero) + // MakeKey(6, protobuf::WireType::kVarint), + // 0x00 + >(); // Test that a default-constructed packet sets its members to the default // protobuf values.
diff --git a/pw_rpc/public/pw_rpc/internal/packet.h b/pw_rpc/public/pw_rpc/internal/packet.h index 2004d70..8adc11e 100644 --- a/pw_rpc/public/pw_rpc/internal/packet.h +++ b/pw_rpc/public/pw_rpc/internal/packet.h
@@ -88,6 +88,9 @@ // Determines the space required to encode the packet proto fields for a // response, excluding the payload. This may be used to split the buffer into // reserved space and available space for the payload. + // + // This method allocates two bytes for the status. Status code 0 (OK) is not + // encoded since 0 is the default value. size_t MinEncodedSizeBytes() const; enum Destination : bool { kServer, kClient };
diff --git a/pw_rpc/public/pw_rpc/internal/responder.h b/pw_rpc/public/pw_rpc/internal/responder.h index 26269d9..bfdf270 100644 --- a/pw_rpc/public/pw_rpc/internal/responder.h +++ b/pw_rpc/public/pw_rpc/internal/responder.h
@@ -84,8 +84,6 @@ void Close(); - Packet ResponsePacket(std::span<const std::byte> payload = {}) const; - ServerCall call_; Channel::OutputBuffer response_; enum { kClosed, kOpen } state_;
diff --git a/pw_rpc/pw_rpc_private/fake_channel_output.h b/pw_rpc/pw_rpc_private/fake_channel_output.h new file mode 100644 index 0000000..d46628e --- /dev/null +++ b/pw_rpc/pw_rpc_private/fake_channel_output.h
@@ -0,0 +1,59 @@ +// Copyright 2021 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +#pragma once + +#include "pw_bytes/span.h" +#include "pw_rpc/channel.h" + +namespace pw::rpc::internal::test { + +// A ChannelOutput implementation that stores the outgoing payloads and status. +class FakeChannelOutput : public ChannelOutput { + public: + constexpr FakeChannelOutput(ByteSpan buffer, bool server_streaming) + : ChannelOutput("pw::rpc::internal::test::FakeChannelOutput"), + packet_buffer_(buffer), + server_streaming_(server_streaming) {} + + Status last_status() const { return last_status_; } + void set_last_status(Status status) { last_status_ = status; } + + size_t total_responses() const { return total_responses_; } + + // Set to true if a RESPONSE packet is seen. + bool done() const { return done_; } + + void clear(); + + private: + ByteSpan AcquireBuffer() final { return packet_buffer_; } + + Status SendAndReleaseBuffer(ConstByteSpan buffer) final; + + virtual void AppendResponse(ConstByteSpan response) = 0; + virtual void ClearResponses() = 0; + + void ProcessResponse(ConstByteSpan response) { + AppendResponse(response); + total_responses_ += 1; + } + + const ByteSpan packet_buffer_; + size_t total_responses_ = 0; + Status last_status_; + bool done_ = false; + const bool server_streaming_; +}; + +} // namespace pw::rpc::internal::test
diff --git a/pw_rpc/pw_rpc_private/internal_test_utils.h b/pw_rpc/pw_rpc_private/internal_test_utils.h index 5dbdb11..11ffefa 100644 --- a/pw_rpc/pw_rpc_private/internal_test_utils.h +++ b/pw_rpc/pw_rpc_private/internal_test_utils.h
@@ -101,14 +101,31 @@ server_.RegisterService(service_); } - // Creates a response packet for this context's channel, service, and method. - internal::Packet packet(std::span<const std::byte> payload) const { + // Create packets for this context's channel, service, and method. + internal::Packet request(std::span<const std::byte> payload) const { + return internal::Packet(internal::PacketType::REQUEST, + kChannelId, + kServiceId, + context_.method().id(), + payload); + } + + internal::Packet response(Status status, + std::span<const std::byte> payload = {}) const { return internal::Packet(internal::PacketType::RESPONSE, kChannelId, kServiceId, context_.method().id(), payload, - OkStatus()); + status); + } + + internal::Packet server_stream(std::span<const std::byte> payload) const { + return internal::Packet(internal::PacketType::SERVER_STREAM, + kChannelId, + kServiceId, + context_.method().id(), + payload); } internal::ServerCall& get() { return context_; } @@ -156,10 +173,14 @@ return client_.ProcessPacket(result.value_or(ConstByteSpan())); } - Status SendResponse(Status status, std::span<const std::byte> payload) { + Status SendResponse(Status status, std::span<const std::byte> payload = {}) { return SendPacket(internal::PacketType::RESPONSE, status, payload); } + Status SendServerStream(std::span<const std::byte> payload) { + return SendPacket(internal::PacketType::SERVER_STREAM, OkStatus(), payload); + } + private: TestOutput<kOutputBufferSize> output_; Channel channel_;
diff --git a/pw_rpc/py/pw_rpc/client.py b/pw_rpc/py/pw_rpc/client.py index e990bce..057fc3e 100644 --- a/pw_rpc/py/pw_rpc/client.py +++ b/pw_rpc/py/pw_rpc/client.py
@@ -247,26 +247,30 @@ def _decode_status(rpc: PendingRpc, packet) -> Optional[Status]: - # Server streaming RPC packets never have a status; all other packets do. - if packet.type == PacketType.RESPONSE and rpc.method.server_streaming: + if packet.type == PacketType.SERVER_STREAM: return None try: return Status(packet.status) except ValueError: _LOG.warning('Illegal status code %d for %s', packet.status, rpc) - - return None + return Status.UNKNOWN def _decode_payload(rpc: PendingRpc, packet): - if packet.type == PacketType.RESPONSE: - try: - return packets.decode_payload(packet, rpc.method.response_type) - except DecodeError as err: - _LOG.warning('Failed to decode %s response for %s: %s', - rpc.method.response_type.DESCRIPTOR.full_name, - rpc.method.full_name, err) + if packet.type == PacketType.SERVER_ERROR: + return None + + # Server streaming RPCs do not send a payload with their RESPONSE packet. + if packet.type == PacketType.RESPONSE and rpc.method.server_streaming: + return None + + try: + return packets.decode_payload(packet, rpc.method.response_type) + except DecodeError as err: + _LOG.warning('Failed to decode %s response for %s: %s', + rpc.method.response_type.DESCRIPTOR.full_name, + rpc.method.full_name, err) return None @@ -326,6 +330,31 @@ f'services={[str(s) for s in self.services()]})') +def _update_for_backwards_compatibility(rpc: PendingRpc, + packet: RpcPacket) -> None: + """Adapts server streaming RPC packets to the updated protocol if needed.""" + # The protocol changes only affect server streaming RPCs. + if rpc.method.type is not Method.Type.SERVER_STREAMING: + return + + # SERVER_STREAM_END packets are deprecated. They are equivalent to a + # RESPONSE packet. + if packet.type == PacketType.DEPRECATED_SERVER_STREAM_END: + packet.type = PacketType.RESPONSE + return + + # Prior to the introduction of SERVER_STREAM packets, RESPONSE packets with + # a payload were used instead. If a non-zero payload is present, assume this + # RESPONSE is equivalent to a SERVER_STREAM packet. + # + # Note that the payload field is not 'optional', so an empty payload is + # equivalent to a payload that happens to encode to zero bytes. This would + # only affect server streaming RPCs on the old protocol that intentionally + # send empty payloads, which will not be an issue in practice. + if packet.type == PacketType.RESPONSE and packet.payload: + packet.type = PacketType.SERVER_STREAM + + class Client: """Sends requests and handles responses for a set of channels. @@ -424,15 +453,15 @@ _LOG.warning('%s', err) return Status.OK - status = _decode_status(rpc, packet) + _update_for_backwards_compatibility(rpc, packet) - if packet.type not in (PacketType.RESPONSE, - PacketType.SERVER_STREAM_END, + if packet.type not in (PacketType.RESPONSE, PacketType.SERVER_STREAM, PacketType.SERVER_ERROR): _LOG.error('%s: unexpected PacketType %s', rpc, packet.type) _LOG.debug('Packet:\n%s', packet) return Status.OK + status = _decode_status(rpc, packet) payload = _decode_payload(rpc, packet) try:
diff --git a/pw_rpc/py/pw_rpc/packets.py b/pw_rpc/py/pw_rpc/packets.py index d36b7a4..1980b25 100644 --- a/pw_rpc/py/pw_rpc/packets.py +++ b/pw_rpc/py/pw_rpc/packets.py
@@ -19,7 +19,7 @@ from pw_rpc.internal import packet_pb2 -def decode(data: bytes): +def decode(data: bytes) -> packet_pb2.RpcPacket: packet = packet_pb2.RpcPacket() packet.MergeFromString(data) return packet @@ -73,5 +73,5 @@ method_id=method).SerializeToString() -def for_server(packet): +def for_server(packet: packet_pb2.RpcPacket) -> bool: return packet.type % 2 == 0
diff --git a/pw_rpc/py/tests/callback_client_test.py b/pw_rpc/py/tests/callback_client_test.py index 0373a07..4c6009f 100755 --- a/pw_rpc/py/tests/callback_client_test.py +++ b/pw_rpc/py/tests/callback_client_test.py
@@ -58,6 +58,10 @@ method_stub.method) +def _message_bytes(msg) -> bytes: + return msg if isinstance(msg, bytes) else msg.SerializeToString() + + class CallbackClientImplTest(unittest.TestCase): """Tests the callback_client as used within a pw_rpc Client.""" def setUp(self): @@ -77,7 +81,7 @@ channel_id: int, method=None, status: Status = Status.OK, - response=b'', + payload=b'', *, ids: Tuple[int, int] = None, process_status=Status.OK): @@ -88,32 +92,27 @@ assert ids is not None and method is None service_id, method_id = ids - if isinstance(response, bytes): - payload = response - else: - payload = response.SerializeToString() + self._next_packets.append((packet_pb2.RpcPacket( + type=packet_pb2.PacketType.RESPONSE, + channel_id=channel_id, + service_id=service_id, + method_id=method_id, + status=status.value, + payload=_message_bytes(payload)).SerializeToString(), + process_status)) - self._next_packets.append( - (packet_pb2.RpcPacket(type=packet_pb2.PacketType.RESPONSE, - channel_id=channel_id, - service_id=service_id, - method_id=method_id, - status=status.value, - payload=payload).SerializeToString(), - process_status)) - - def _enqueue_stream_end(self, - channel_id: int, - method, - status: Status = Status.OK, - process_status=Status.OK): - self._next_packets.append( - (packet_pb2.RpcPacket(type=packet_pb2.PacketType.SERVER_STREAM_END, - channel_id=channel_id, - service_id=method.service.id, - method_id=method.id, - status=status.value).SerializeToString(), - process_status)) + def _enqueue_server_stream(self, + channel_id: int, + method, + response, + process_status=Status.OK): + self._next_packets.append((packet_pb2.RpcPacket( + type=packet_pb2.PacketType.SERVER_STREAM, + channel_id=channel_id, + service_id=method.service.id, + method_id=method.id, + payload=_message_bytes(response)).SerializeToString(), + process_status)) def _enqueue_error(self, channel_id: int, @@ -288,9 +287,9 @@ rep2 = method.response_type(payload='?') for _ in range(3): - self._enqueue_response(1, method, response=rep1) - self._enqueue_response(1, method, response=rep2) - self._enqueue_stream_end(1, method, Status.ABORTED) + self._enqueue_server_stream(1, method, rep1) + self._enqueue_server_stream(1, method, rep2) + self._enqueue_response(1, method, Status.ABORTED) self.assertEqual( [rep1, rep2], @@ -300,6 +299,35 @@ 4, self._sent_payload(method.request_type).magic_number) + def test_invoke_server_streaming_with_deprecated_packet_format(self): + method = self._service.SomeServerStreaming.method + + rep1 = method.response_type(payload='!!!') + rep2 = method.response_type(payload='?') + + for _ in range(3): + # The original packet format used RESPONSE packets for the server + # stream and a SERVER_STREAM_END packet as the last packet. These + # are converted to SERVER_STREAM packets followed by a RESPONSE. + self._enqueue_response(1, method, payload=rep1) + self._enqueue_response(1, method, payload=rep2) + + self._next_packets.append((packet_pb2.RpcPacket( + type=packet_pb2.PacketType.DEPRECATED_SERVER_STREAM_END, + channel_id=1, + service_id=method.service.id, + method_id=method.id, + status=Status.INVALID_ARGUMENT.value).SerializeToString(), + Status.OK)) + + call = self._service.SomeServerStreaming(magic_number=4) + self.assertEqual([rep1, rep2], list(call)) + self.assertIs(call.status, Status.INVALID_ARGUMENT) + + self.assertEqual( + 4, + self._sent_payload(method.request_type).magic_number) + def test_invoke_server_streaming_with_callbacks(self): method = self._service.SomeServerStreaming.method @@ -307,9 +335,9 @@ rep2 = method.response_type(payload='?') for _ in range(3): - self._enqueue_response(1, method, response=rep1) - self._enqueue_response(1, method, response=rep2) - self._enqueue_stream_end(1, method, Status.ABORTED) + self._enqueue_server_stream(1, method, rep1) + self._enqueue_server_stream(1, method, rep2) + self._enqueue_response(1, method, Status.ABORTED) callback = mock.Mock() self._service.SomeServerStreaming.invoke( @@ -330,7 +358,7 @@ stub = self._service.SomeServerStreaming resp = stub.method.response_type(payload='!!!') - self._enqueue_response(1, stub.method, response=resp) + self._enqueue_server_stream(1, stub.method, resp) callback = mock.Mock() call = stub.invoke(self._request(magic_number=3), callback) @@ -344,8 +372,8 @@ self.assertEqual(self._last_request.type, packet_pb2.PacketType.CANCEL) # Ensure the RPC can be called after being cancelled. - self._enqueue_response(1, stub.method, response=resp) - self._enqueue_stream_end(1, stub.method, Status.OK) + self._enqueue_server_stream(1, stub.method, resp) + self._enqueue_response(1, stub.method, Status.OK) call = stub.invoke(self._request(magic_number=3), callback, callback)
diff --git a/pw_rpc/raw/BUILD.gn b/pw_rpc/raw/BUILD.gn index 1f5d502..70c53cb 100644 --- a/pw_rpc/raw/BUILD.gn +++ b/pw_rpc/raw/BUILD.gn
@@ -45,6 +45,7 @@ public = [ "public/pw_rpc/raw_test_method_context.h" ] public_deps = [ ":method", + "..:test_utils", dir_pw_assert, dir_pw_containers, ]
diff --git a/pw_rpc/raw/public/pw_rpc/internal/raw_method.h b/pw_rpc/raw/public/pw_rpc/internal/raw_method.h index 7d86ebe..4e63bee 100644 --- a/pw_rpc/raw/public/pw_rpc/internal/raw_method.h +++ b/pw_rpc/raw/public/pw_rpc/internal/raw_method.h
@@ -118,6 +118,8 @@ ServerContext&, ConstByteSpan, ByteSpan)> { using Implementation = RawMethod; static constexpr MethodType kType = MethodType::kUnary; + static constexpr bool kServerStreaming = false; + static constexpr bool kClientStreaming = false; }; // MethodTraits specialization for a raw unary method. @@ -126,6 +128,8 @@ ServerContext&, ConstByteSpan, ByteSpan)> { using Implementation = RawMethod; static constexpr MethodType kType = MethodType::kUnary; + static constexpr bool kServerStreaming = false; + static constexpr bool kClientStreaming = false; using Service = T; }; @@ -134,6 +138,8 @@ struct MethodTraits<void (*)(ServerContext&, ConstByteSpan, RawServerWriter&)> { using Implementation = RawMethod; static constexpr MethodType kType = MethodType::kServerStreaming; + static constexpr bool kServerStreaming = true; + static constexpr bool kClientStreaming = false; }; // MethodTraits specialization for a raw server streaming method. @@ -142,6 +148,8 @@ ServerContext&, ConstByteSpan, RawServerWriter&)> { using Implementation = RawMethod; static constexpr MethodType kType = MethodType::kServerStreaming; + static constexpr bool kServerStreaming = true; + static constexpr bool kClientStreaming = false; using Service = T; };
diff --git a/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h b/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h index b9fc660..6ff8738 100644 --- a/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h +++ b/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h
@@ -16,7 +16,6 @@ #include <type_traits> #include "pw_assert/assert.h" -#include "pw_bytes/span.h" #include "pw_containers/vector.h" #include "pw_rpc/channel.h" #include "pw_rpc/internal/hash.h" @@ -24,6 +23,7 @@ #include "pw_rpc/internal/packet.h" #include "pw_rpc/internal/raw_method.h" #include "pw_rpc/internal/server.h" +#include "pw_rpc_private/fake_channel_output.h" namespace pw::rpc { @@ -79,7 +79,7 @@ ::pw::rpc::internal::Hash(#method), \ ##__VA_ARGS__> template <typename Service, - auto method, + auto kMethod, uint32_t kMethodId, size_t kMaxResponse = 4, size_t kOutputSizeBytes = 128> @@ -90,57 +90,50 @@ // A ChannelOutput implementation that stores the outgoing payloads and status. template <size_t kOutputSize> -class MessageOutput final : public ChannelOutput { +class MessageOutput final : public FakeChannelOutput { public: using ResponseBuffer = std::array<std::byte, kOutputSize>; MessageOutput(Vector<ByteSpan>& responses, Vector<ResponseBuffer>& buffers, - ByteSpan packet_buffer) - : ChannelOutput("internal::test::raw::MessageOutput"), + ByteSpan packet_buffer, + bool server_streaming) + : FakeChannelOutput(packet_buffer, server_streaming), responses_(responses), - buffers_(buffers), - packet_buffer_(packet_buffer) { - clear(); - } - - Status last_status() const { return last_status_; } - void set_last_status(Status status) { last_status_ = status; } - - size_t total_responses() const { return total_responses_; } - - bool stream_ended() const { return stream_ended_; } - - void clear() { - responses_.clear(); - buffers_.clear(); - total_responses_ = 0; - stream_ended_ = false; - last_status_ = Status::Unknown(); - } + buffers_(buffers) {} private: - ByteSpan AcquireBuffer() override { return packet_buffer_; } + void AppendResponse(ConstByteSpan response) override { + // If we run out of space, the back message is always the most recent. + buffers_.emplace_back(); + buffers_.back() = {}; + std::memcpy(&buffers_.back(), response.data(), response.size()); + responses_.emplace_back(); + responses_.back() = {buffers_.back().data(), response.size()}; + } - Status SendAndReleaseBuffer(std::span<const std::byte> buffer) override; + void ClearResponses() override { + responses_.clear(); + buffers_.clear(); + } Vector<ByteSpan>& responses_; Vector<ResponseBuffer>& buffers_; - ByteSpan packet_buffer_; - size_t total_responses_; - bool stream_ended_; - Status last_status_; }; // Collects everything needed to invoke a particular RPC. template <typename Service, + auto kMethod, uint32_t kMethodId, size_t kMaxResponse, size_t kOutputSize> struct InvocationContext { template <typename... Args> InvocationContext(Args&&... args) - : output(responses, buffers, packet_buffer), + : output(responses, + buffers, + packet_buffer, + MethodTraits<decltype(kMethod)>::kServerStreaming), channel(Channel::Create<123>(&output)), server(std::span(&channel, 1)), service(std::forward<Args>(args)...), @@ -163,10 +156,14 @@ // Method invocation context for a unary RPC. Returns the status in call() and // provides the response through the response() method. -template <typename Service, auto method, uint32_t kMethodId, size_t kOutputSize> +template <typename Service, + auto kMethod, + uint32_t kMethodId, + size_t kOutputSize> class UnaryContext { private: - using Context = InvocationContext<Service, kMethodId, 1, kOutputSize>; + using Context = + InvocationContext<Service, kMethod, kMethodId, 1, kOutputSize>; Context ctx_; public: @@ -183,7 +180,7 @@ ctx_.responses.emplace_back(); auto& response = ctx_.responses.back(); response = {ctx_.buffers.back().data(), ctx_.buffers.back().size()}; - auto sws = CallMethodImplFunction<method>(ctx_.call, request, response); + auto sws = CallMethodImplFunction<kMethod>(ctx_.call, request, response); response = response.first(sws.size()); return sws; } @@ -197,14 +194,14 @@ // Method invocation context for a server streaming RPC. template <typename Service, - auto method, + auto kMethod, uint32_t kMethodId, size_t kMaxResponse, size_t kOutputSize> class ServerStreamingContext { private: using Context = - InvocationContext<Service, kMethodId, kMaxResponse, kOutputSize>; + InvocationContext<Service, kMethod, kMethodId, kMaxResponse, kOutputSize>; Context ctx_; public: @@ -217,7 +214,7 @@ void call(ConstByteSpan request) { ctx_.output.clear(); Responder server_writer(ctx_.call); - return CallMethodImplFunction<method>( + return CallMethodImplFunction<kMethod>( ctx_.call, request, static_cast<RawServerWriter&>(server_writer)); } @@ -239,7 +236,7 @@ size_t total_responses() const { return ctx_.output.total_responses(); } // True if the stream has terminated. - bool done() const { return ctx_.output.stream_ended(); } + bool done() const { return ctx_.output.done(); } // The status of the stream. Only valid if done() is true. Status status() const { @@ -251,74 +248,41 @@ // Alias to select the type of the context object to use based on which type of // RPC it is for. template <typename Service, - auto method, + auto kMethod, uint32_t kMethodId, size_t kMaxResponse, size_t kOutputSize> using Context = std::tuple_element_t< - static_cast<size_t>(MethodTraits<decltype(method)>::kType), - std::tuple<UnaryContext<Service, method, kMethodId, kOutputSize>, + static_cast<size_t>(MethodTraits<decltype(kMethod)>::kType), + std::tuple<UnaryContext<Service, kMethod, kMethodId, kOutputSize>, ServerStreamingContext<Service, - method, + kMethod, kMethodId, kMaxResponse, kOutputSize> // TODO(hepler): Support client and bidi streaming >>; -template <size_t kOutputSize> -Status MessageOutput<kOutputSize>::SendAndReleaseBuffer( - std::span<const std::byte> buffer) { - PW_ASSERT(!stream_ended_); - PW_ASSERT(buffer.data() == packet_buffer_.data()); - - if (buffer.empty()) { - return OkStatus(); - } - - Result<internal::Packet> result = internal::Packet::FromBuffer(buffer); - PW_ASSERT(result.ok()); - - last_status_ = result.value().status(); - - switch (result.value().type()) { - case internal::PacketType::RESPONSE: { - // If we run out of space, the back message is always the most recent. - buffers_.emplace_back(); - buffers_.back() = {}; - auto response = result.value().payload(); - std::memcpy(&buffers_.back(), response.data(), response.size()); - responses_.emplace_back(); - responses_.back() = {buffers_.back().data(), response.size()}; - total_responses_ += 1; - break; - } - case internal::PacketType::SERVER_STREAM_END: - stream_ended_ = true; - break; - default: - pw_assert_HandleFailure(); // Unhandled PacketType - } - return OkStatus(); -} - } // namespace internal::test::raw template <typename Service, - auto method, + auto kMethod, uint32_t kMethodId, size_t kMaxResponse, size_t kOutputSizeBytes> class RawTestMethodContext : public internal::test::raw:: - Context<Service, method, kMethodId, kMaxResponse, kOutputSizeBytes> { + Context<Service, kMethod, kMethodId, kMaxResponse, kOutputSizeBytes> { public: // Forwards constructor arguments to the service class. template <typename... ServiceArgs> RawTestMethodContext(ServiceArgs&&... service_args) - : internal::test::raw:: - Context<Service, method, kMethodId, kMaxResponse, kOutputSizeBytes>( - std::forward<ServiceArgs>(service_args)...) {} + : internal::test::raw::Context<Service, + kMethod, + kMethodId, + kMaxResponse, + kOutputSizeBytes>( + std::forward<ServiceArgs>(service_args)...) {} }; } // namespace pw::rpc
diff --git a/pw_rpc/raw/raw_method_test.cc b/pw_rpc/raw/raw_method_test.cc index fd60781..a3df85d 100644 --- a/pw_rpc/raw/raw_method_test.cc +++ b/pw_rpc/raw/raw_method_test.cc
@@ -149,7 +149,7 @@ const RawMethod& method = std::get<0>(FakeService::kMethods).raw_method(); ServerContextForTest<FakeService> context(method); - method.Invoke(context.get(), context.packet(encoder.Encode().value())); + method.Invoke(context.get(), context.request(encoder.Encode().value())); EXPECT_EQ(last_request.integer, 456); EXPECT_EQ(last_request.status_code, 7u); @@ -174,7 +174,7 @@ const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method(); ServerContextForTest<FakeService> context(method); - method.Invoke(context.get(), context.packet(encoder.Encode().value())); + method.Invoke(context.get(), context.request(encoder.Encode().value())); EXPECT_EQ(0u, context.output().packet_count()); EXPECT_EQ(777, last_request.integer); @@ -187,7 +187,7 @@ const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method(); ServerContextForTest<FakeService> context(method); - method.Invoke(context.get(), context.packet({})); + method.Invoke(context.get(), context.request({})); auto buffer = last_writer.PayloadBuffer(); @@ -197,7 +197,7 @@ EXPECT_EQ(last_writer.Write(buffer.first(data.size())), OkStatus()); const internal::Packet& packet = context.output().sent_packet(); - EXPECT_EQ(packet.type(), internal::PacketType::RESPONSE); + EXPECT_EQ(packet.type(), internal::PacketType::SERVER_STREAM); EXPECT_EQ(packet.channel_id(), context.channel_id()); EXPECT_EQ(packet.service_id(), context.service_id()); EXPECT_EQ(packet.method_id(), context.get().method().id()); @@ -209,13 +209,13 @@ const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method(); ServerContextForTest<FakeService> context(method); - method.Invoke(context.get(), context.packet({})); + method.Invoke(context.get(), context.request({})); constexpr auto data = bytes::Array<0x0d, 0x06, 0xf0, 0x0d>(); EXPECT_EQ(last_writer.Write(data), OkStatus()); const internal::Packet& packet = context.output().sent_packet(); - EXPECT_EQ(packet.type(), internal::PacketType::RESPONSE); + EXPECT_EQ(packet.type(), internal::PacketType::SERVER_STREAM); EXPECT_EQ(packet.channel_id(), context.channel_id()); EXPECT_EQ(packet.service_id(), context.service_id()); EXPECT_EQ(packet.method_id(), context.get().method().id()); @@ -227,7 +227,7 @@ const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method(); ServerContextForTest<FakeService> context(method); - method.Invoke(context.get(), context.packet({})); + method.Invoke(context.get(), context.request({})); EXPECT_EQ(OkStatus(), last_writer.Finish()); constexpr auto data = bytes::Array<0x0d, 0x06, 0xf0, 0x0d>(); @@ -238,7 +238,7 @@ const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method(); ServerContextForTest<FakeService, 16> context(method); - method.Invoke(context.get(), context.packet({})); + method.Invoke(context.get(), context.request({})); constexpr auto data = bytes::Array<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16>(); @@ -250,7 +250,7 @@ const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method(); ServerContextForTest<FakeService> context(method); - method.Invoke(context.get(), context.packet({})); + method.Invoke(context.get(), context.request({})); { RawServerWriter writer = std::move(last_writer); @@ -261,7 +261,7 @@ auto output = context.output(); EXPECT_EQ(output.packet_count(), 1u); - EXPECT_EQ(output.sent_packet().type(), PacketType::SERVER_STREAM_END); + EXPECT_EQ(output.sent_packet().type(), PacketType::RESPONSE); } } // namespace
diff --git a/pw_rpc/raw/raw_method_union_test.cc b/pw_rpc/raw/raw_method_union_test.cc index f048366..aa30c7a 100644 --- a/pw_rpc/raw/raw_method_union_test.cc +++ b/pw_rpc/raw/raw_method_union_test.cc
@@ -107,7 +107,7 @@ const Method& method = std::get<1>(FakeGeneratedServiceImpl::kMethods).method(); ServerContextForTest<FakeGeneratedServiceImpl> context(method); - method.Invoke(context.get(), context.packet(encoder.Encode().value())); + method.Invoke(context.get(), context.request(encoder.Encode().value())); EXPECT_EQ(last_request.integer, 456); EXPECT_EQ(last_request.status_code, 7u); @@ -133,7 +133,7 @@ std::get<2>(FakeGeneratedServiceImpl::kMethods).method(); ServerContextForTest<FakeGeneratedServiceImpl> context(method); - method.Invoke(context.get(), context.packet(encoder.Encode().value())); + method.Invoke(context.get(), context.request(encoder.Encode().value())); EXPECT_EQ(0u, context.output().packet_count()); EXPECT_EQ(777, last_request.integer);
diff --git a/pw_rpc/responder.cc b/pw_rpc/responder.cc index 1c20e8c..be9abb7 100644 --- a/pw_rpc/responder.cc +++ b/pw_rpc/responder.cc
@@ -20,6 +20,30 @@ #include "pw_rpc/internal/server.h" namespace pw::rpc::internal { +namespace { + +Packet ResponsePacket(const ServerCall& call, + uint32_t method_id, + Status status) { + return Packet(PacketType::RESPONSE, + call.channel().id(), + call.service().id(), + method_id, + {}, + status); +} + +Packet StreamPacket(const ServerCall& call, + uint32_t method_id, + std::span<const std::byte> payload) { + return Packet(PacketType::SERVER_STREAM, + call.channel().id(), + call.service().id(), + method_id, + payload); +} + +} // namespace Responder::Responder(ServerCall& call) : call_(call), state_(kOpen) { call_.server().RegisterResponder(*this); @@ -58,13 +82,8 @@ Close(); - // Send a control packet indicating that the stream (and RPC) has terminated. - return call_.channel().Send(Packet(PacketType::SERVER_STREAM_END, - call_.channel().id(), - call_.service().id(), - method().id(), - {}, - status)); + // Send a packet indicating that the RPC has terminated. + return call_.channel().Send(ResponsePacket(call_, method_id(), status)); } std::span<std::byte> Responder::AcquirePayloadBuffer() { @@ -75,12 +94,13 @@ response_ = call_.channel().AcquireBuffer(); } - return response_.payload(ResponsePacket()); + return response_.payload(StreamPacket(call_, method_id(), {})); } Status Responder::ReleasePayloadBuffer(std::span<const std::byte> payload) { PW_DCHECK(open()); - return call_.channel().Send(response_, ResponsePacket(payload)); + return call_.channel().Send(response_, + StreamPacket(call_, method_id(), payload)); } Status Responder::ReleasePayloadBuffer() { @@ -98,12 +118,4 @@ state_ = kClosed; } -Packet Responder::ResponsePacket(std::span<const std::byte> payload) const { - return Packet(PacketType::RESPONSE, - call_.channel().id(), - call_.service().id(), - method().id(), - payload); -} - } // namespace pw::rpc::internal
diff --git a/pw_rpc/responder_test.cc b/pw_rpc/responder_test.cc index 28b4910..8ae515d 100644 --- a/pw_rpc/responder_test.cc +++ b/pw_rpc/responder_test.cc
@@ -112,14 +112,14 @@ EXPECT_TRUE(writers.empty()); } -TEST(ServerWriter, Finish_SendsCancellationPacket) { +TEST(ServerWriter, Finish_SendsResponse) { ServerContextForTest<TestService> context(TestService::method.method()); FakeServerWriter writer(context.get()); EXPECT_EQ(OkStatus(), writer.Finish()); const Packet& packet = context.output().sent_packet(); - EXPECT_EQ(packet.type(), PacketType::SERVER_STREAM_END); + EXPECT_EQ(packet.type(), PacketType::RESPONSE); EXPECT_EQ(packet.channel_id(), context.channel_id()); EXPECT_EQ(packet.service_id(), context.service_id()); EXPECT_EQ(packet.method_id(), context.get().method().id()); @@ -166,7 +166,7 @@ ASSERT_EQ(OkStatus(), writer.Write(data)); byte encoded[64]; - auto result = context.packet(data).Encode(encoded); + auto result = context.server_stream(data).Encode(encoded); ASSERT_EQ(OkStatus(), result.status()); EXPECT_EQ(result.value().size(), context.output().sent_data().size());
diff --git a/pw_rpc/server.cc b/pw_rpc/server.cc index 1c9e813..c5e29b1 100644 --- a/pw_rpc/server.cc +++ b/pw_rpc/server.cc
@@ -150,7 +150,7 @@ channel.Send(Packet::ServerError(packet, Status::FailedPrecondition())); PW_LOG_WARN("Received CANCEL packet for method that is not pending"); } else { - writer->Finish(Status::Cancelled()); + writer->Close(); } }
diff --git a/pw_rpc/server_test.cc b/pw_rpc/server_test.cc index 8894b5a..f76168a 100644 --- a/pw_rpc/server_test.cc +++ b/pw_rpc/server_test.cc
@@ -247,18 +247,12 @@ EXPECT_FALSE(writer_.open()); } -TEST_F(MethodPending, ProcessPacket_Cancel_SendsStreamEndPacket) { +TEST_F(MethodPending, ProcessPacket_Cancel_SendsNoResponse) { EXPECT_EQ(OkStatus(), server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 1, 42, 100), output_)); - const Packet& packet = output_.sent_packet(); - EXPECT_EQ(packet.type(), PacketType::SERVER_STREAM_END); - EXPECT_EQ(packet.channel_id(), 1u); - EXPECT_EQ(packet.service_id(), 42u); - EXPECT_EQ(packet.method_id(), 100u); - EXPECT_TRUE(packet.payload().empty()); - EXPECT_EQ(packet.status(), Status::Cancelled()); + EXPECT_EQ(output_.packet_count(), 0u); } TEST_F(MethodPending,