pw_rpc: Don't respond to error messages
Servers and client should never send errors in response to other errors.
Doing so could result in an infinite cycle of error packets.
Change-Id: If70f06fdcbd62b386a6103b532182314ec5face0
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/60447
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
diff --git a/pw_rpc/client.cc b/pw_rpc/client.cc
index ead04c1..f8c1b29 100644
--- a/pw_rpc/client.cc
+++ b/pw_rpc/client.cc
@@ -62,8 +62,11 @@
if (call == calls_.end()) {
PW_LOG_WARN("RPC client received a packet for a request it did not make");
- channel->Send(Packet::ClientError(packet, Status::FailedPrecondition()))
- .IgnoreError(); // TODO(pwbug/387): Handle Status properly
+ // Don't send responses to errors to avoid infinite error cycles.
+ if (packet.type() != PacketType::SERVER_ERROR) {
+ channel->Send(Packet::ClientError(packet, Status::FailedPrecondition()))
+ .IgnoreError();
+ }
return Status::NotFound();
}
diff --git a/pw_rpc/client_test.cc b/pw_rpc/client_test.cc
index ad3c1a8..0a61828 100644
--- a/pw_rpc/client_test.cc
+++ b/pw_rpc/client_test.cc
@@ -69,6 +69,16 @@
EXPECT_EQ(packet.status(), Status::FailedPrecondition());
}
+TEST(Client, ProcessPacket_ServerErrorOnUnregisteredCall_SendsNothing) {
+ internal::ClientContextForTest context;
+
+ EXPECT_EQ(
+ context.SendPacket(internal::PacketType::SERVER_ERROR, OkStatus(), {}),
+ Status::NotFound());
+
+ EXPECT_EQ(context.output().packet_count(), 0u);
+}
+
TEST(Client, ProcessPacket_ReturnsDataLossOnBadPacket) {
internal::ClientContextForTest context;
diff --git a/pw_rpc/py/pw_rpc/callback_client/call.py b/pw_rpc/py/pw_rpc/callback_client/call.py
index f758f01..abe1083 100644
--- a/pw_rpc/py/pw_rpc/callback_client/call.py
+++ b/pw_rpc/py/pw_rpc/callback_client/call.py
@@ -106,7 +106,7 @@
_LOG.info('%s completed: %s', self._rpc, status)
def _default_error(self, error: Status) -> None:
- _LOG.warning('%s termianted due to an error: %s', self._rpc, error)
+ _LOG.warning('%s terminated due to an error: %s', self._rpc, error)
@property
def method(self) -> Method:
diff --git a/pw_rpc/py/pw_rpc/client.py b/pw_rpc/py/pw_rpc/client.py
index 7f115a9..a70a1eb 100644
--- a/pw_rpc/py/pw_rpc/client.py
+++ b/pw_rpc/py/pw_rpc/client.py
@@ -557,5 +557,7 @@
def _send_client_error(client: ChannelClient, packet: RpcPacket,
error: Status) -> None:
- client.channel.output( # type: ignore
- packets.encode_client_error(packet, error))
+ # Never send responses to SERVER_ERRORs.
+ if packet.type != PacketType.SERVER_ERROR:
+ client.channel.output( # type: ignore
+ packets.encode_client_error(packet, error))
diff --git a/pw_rpc/py/tests/callback_client_test.py b/pw_rpc/py/tests/callback_client_test.py
index c0a2489..04f1388 100755
--- a/pw_rpc/py/tests/callback_client_test.py
+++ b/pw_rpc/py/tests/callback_client_test.py
@@ -117,16 +117,16 @@
def _enqueue_error(self,
channel_id: int,
+ service,
method,
status: Status,
process_status=Status.OK) -> None:
- self._next_packets.append(
- (packet_pb2.RpcPacket(type=packet_pb2.PacketType.SERVER_ERROR,
- channel_id=channel_id,
- service_id=method.service.id,
- method_id=method.id,
- status=status.value).SerializeToString(),
- process_status))
+ self._next_packets.append((packet_pb2.RpcPacket(
+ type=packet_pb2.PacketType.SERVER_ERROR,
+ channel_id=channel_id,
+ service_id=service if isinstance(service, int) else service.id,
+ method_id=method if isinstance(method, int) else method.id,
+ status=status.value).SerializeToString(), process_status))
def _handle_packet(self, data: bytes) -> None:
if self.output_exception:
@@ -203,6 +203,29 @@
self.assertIs(Status.OK, status)
self.assertEqual('', response.payload)
+ def test_server_error_for_unknown_call_sends_no_errors(self) -> None:
+ method = self._service.SomeUnary.method
+ service_id = method.service.id
+
+ # Unknown channel
+ self._enqueue_error(999,
+ service_id,
+ method,
+ Status.NOT_FOUND,
+ process_status=Status.NOT_FOUND)
+ # Bad service
+ self._enqueue_error(1, 999, method.id, Status.INVALID_ARGUMENT)
+ # Bad method
+ self._enqueue_error(1, service_id, 999, Status.INVALID_ARGUMENT)
+ # For RPC not pending
+ self._enqueue_error(1, service_id,
+ self._service.SomeBidiStreaming.method.id,
+ Status.NOT_FOUND)
+
+ self._process_enqueued_packets()
+
+ self.assertEqual(self.requests, [])
+
def test_exception_if_payload_fails_to_decode(self) -> None:
method = self._service.SomeUnary.method
@@ -315,7 +338,8 @@
def test_blocking_server_error(self) -> None:
for _ in range(3):
- self._enqueue_error(1, self.method, Status.NOT_FOUND)
+ self._enqueue_error(1, self.method.service, self.method,
+ Status.NOT_FOUND)
with self.assertRaises(callback_client.RpcError) as context:
self._service.SomeUnary(
@@ -577,7 +601,8 @@
requests = [self.method.request_type(magic_number=123)]
# Send after len(requests) and the client stream end packet.
- self._enqueue_error(1, self.method, Status.NOT_FOUND)
+ self._enqueue_error(1, self.method.service, self.method,
+ Status.NOT_FOUND)
with self.assertRaises(callback_client.RpcError) as context:
self.rpc(requests)
@@ -678,7 +703,8 @@
for _ in range(3):
stream = self._service.SomeClientStreaming.invoke()
- self._enqueue_error(1, self.method, Status.INVALID_ARGUMENT)
+ self._enqueue_error(1, self.method.service, self.method,
+ Status.INVALID_ARGUMENT)
stream.send(magic_number=2**32 - 1)
with self.assertRaises(callback_client.RpcError) as context:
@@ -691,7 +717,8 @@
stream = self._service.SomeClientStreaming.invoke()
# Error will be sent in response to the CLIENT_STREAM_END packet.
- self._enqueue_error(1, self.method, Status.INVALID_ARGUMENT)
+ self._enqueue_error(1, self.method.service, self.method,
+ Status.INVALID_ARGUMENT)
with self.assertRaises(callback_client.RpcError) as context:
stream.finish_and_wait()
@@ -719,7 +746,8 @@
self.assertEqual(result, call.finish_and_wait())
def test_nonblocking_finish_after_error(self) -> None:
- self._enqueue_error(1, self.method, Status.UNAVAILABLE)
+ self._enqueue_error(1, self.method.service, self.method,
+ Status.UNAVAILABLE)
call = self.rpc.invoke()
@@ -766,7 +794,8 @@
requests = [self.method.request_type(magic_number=123)]
# Send after len(requests) and the client stream end packet.
- self._enqueue_error(1, self.method, Status.NOT_FOUND)
+ self._enqueue_error(1, self.method.service, self.method,
+ Status.NOT_FOUND)
with self.assertRaises(callback_client.RpcError) as context:
self.rpc(requests)
@@ -861,7 +890,8 @@
self.assertFalse(stream.completed())
self.assertEqual([rep1], responses)
- self._enqueue_error(1, self.method, Status.OUT_OF_RANGE)
+ self._enqueue_error(1, self.method.service, self.method,
+ Status.OUT_OF_RANGE)
stream.send(magic_number=99999)
self.assertTrue(stream.completed())
@@ -879,7 +909,8 @@
stream = self._service.SomeBidiStreaming.invoke()
# Error will be sent in response to the CLIENT_STREAM_END packet.
- self._enqueue_error(1, self.method, Status.INVALID_ARGUMENT)
+ self._enqueue_error(1, self.method.service, self.method,
+ Status.INVALID_ARGUMENT)
with self.assertRaises(callback_client.RpcError) as context:
stream.finish_and_wait()
@@ -910,7 +941,8 @@
def test_nonblocking_finish_after_error(self) -> None:
reply = self.method.response_type(payload='!?')
self._enqueue_server_stream(1, self.method, reply)
- self._enqueue_error(1, self.method, Status.UNAVAILABLE)
+ self._enqueue_error(1, self.method.service, self.method,
+ Status.UNAVAILABLE)
call = self.rpc.invoke()
diff --git a/pw_rpc/server.cc b/pw_rpc/server.cc
index c7c83db..4be0881 100644
--- a/pw_rpc/server.cc
+++ b/pw_rpc/server.cc
@@ -85,21 +85,27 @@
// If the requested channel doesn't exist, try to dynamically assign one.
channel = AssignChannel(packet.channel_id(), interface);
if (channel == nullptr) {
- // If a channel can't be assigned, send a RESOURCE_EXHAUSTED error.
- internal::Channel temp_channel(packet.channel_id(), &interface);
- temp_channel
- .Send(Packet::ServerError(packet, Status::ResourceExhausted()))
- .IgnoreError(); // TODO(pwbug/387): Handle Status properly
- return OkStatus(); // OK since the packet was handled
+ // If a channel can't be assigned, send a RESOURCE_EXHAUSTED error. Never
+ // send responses to error messages, though, to avoid infinite cycles.
+ if (packet.type() != PacketType::CLIENT_ERROR) {
+ internal::Channel temp_channel(packet.channel_id(), &interface);
+ temp_channel
+ .Send(Packet::ServerError(packet, Status::ResourceExhausted()))
+ .IgnoreError();
+ }
+ return OkStatus(); // OK since the packet was handled
}
}
const auto [service, method] = FindMethod(packet);
if (method == nullptr) {
- channel->Send(Packet::ServerError(packet, Status::NotFound()))
- .IgnoreError(); // TODO(pwbug/387): Handle Status properly
- return OkStatus();
+ // Don't send responses to errors to avoid infinite error cycles.
+ if (packet.type() != PacketType::CLIENT_ERROR) {
+ channel->Send(Packet::ServerError(packet, Status::NotFound()))
+ .IgnoreError();
+ }
+ return OkStatus(); // OK since the packet was handled.
}
// Find an existing reader/writer for this RPC, if any.
diff --git a/pw_rpc/server_test.cc b/pw_rpc/server_test.cc
index 42d3872..8f68de2 100644
--- a/pw_rpc/server_test.cc
+++ b/pw_rpc/server_test.cc
@@ -170,6 +170,14 @@
EXPECT_EQ(0u, service_.method(200).last_channel_id());
}
+TEST_F(BasicServer, ProcessPacket_ClientErrorWithInvalidMethod_NoResponse) {
+ EXPECT_EQ(OkStatus(),
+ server_.ProcessPacket(
+ EncodeRequest(PacketType::CLIENT_ERROR, 1, 42, 101), output_));
+
+ EXPECT_EQ(0u, output_.packet_count());
+}
+
TEST_F(BasicServer, ProcessPacket_InvalidMethod_SendsError) {
EXPECT_EQ(OkStatus(),
server_.ProcessPacket(EncodeRequest(PacketType::REQUEST, 1, 42, 27),
@@ -221,6 +229,18 @@
EXPECT_EQ(packet.method_id(), 27u);
}
+TEST_F(BasicServer, ProcessPacket_ClientErrorOnUnassignedChannel_NoResponse) {
+ channels_[2] = Channel::Create<3>(&output_); // Occupy only available channel
+
+ EXPECT_EQ(
+ OkStatus(),
+ server_.ProcessPacket(
+ EncodeRequest(PacketType::CLIENT_ERROR, /*channel_id=*/99, 42, 27),
+ output_));
+
+ EXPECT_EQ(0u, output_.packet_count());
+}
+
TEST_F(BasicServer, ProcessPacket_Cancel_MethodNotActive_SendsError) {
// Set up a fake ServerWriter representing an ongoing RPC.
EXPECT_EQ(OkStatus(),