pw_rpc: Locking improvements
- Expand locking annotations for internal::Call class members.
- Have the ChannelOutput::Send call release the RPC lock right before
sending the packet. This ensures the Call object's OutputBuffer
reference is cleared while the lock is held.
- Avoid holding the RPC lock while logging.
- Clear the encoding buffer in FakeChannelOutput to catch code that
uses stale references.
- Have FakeChannelOutput check that its buffers are properly acquired
and released.
- Mark PayloadBuffer() as [[nodiscard]].
Change-Id: I948e4ab531f45f424f2cd8cd3fa671bbb1509a85
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/76640
Pigweed-Auto-Submit: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
Commit-Queue: Auto-Submit <auto-submit@pigweed.google.com.iam.gserviceaccount.com>
diff --git a/pw_rpc/call.cc b/pw_rpc/call.cc
index d047dba..13e3dea 100644
--- a/pw_rpc/call.cc
+++ b/pw_rpc/call.cc
@@ -64,9 +64,9 @@
}
void Call::MoveFrom(Call& other) {
- PW_DCHECK(!active());
+ PW_DCHECK(!active_locked());
- if (!other.active()) {
+ if (!other.active_locked()) {
return; // Nothing else to do; this call is already closed.
}
@@ -104,7 +104,7 @@
ConstByteSpan response,
Status status) {
rpc_lock().lock();
- if (!active()) {
+ if (!active_locked()) {
rpc_lock().unlock();
return Status::FailedPrecondition();
}
@@ -112,7 +112,7 @@
return SendPacket(type, response, status);
}
-ByteSpan Call::PayloadBuffer() {
+ByteSpan Call::PayloadBufferLocked() {
// Only allow having one active buffer at a time.
if (response_.empty()) {
response_ = channel().AcquireBuffer();
@@ -126,7 +126,7 @@
Status Call::Write(ConstByteSpan payload) {
rpc_lock().lock();
- if (!active()) {
+ if (!active_locked()) {
rpc_lock().unlock();
return Status::FailedPrecondition();
}
@@ -139,10 +139,10 @@
const Packet packet = MakePacket(type, payload, status);
if (!buffer().Contains(payload)) {
- ByteSpan buffer = PayloadBuffer();
+ ByteSpan buffer = PayloadBufferLocked();
if (payload.size() > buffer.size()) {
- ReleasePayloadBuffer();
+ ReleasePayloadBufferLocked();
rpc_lock().unlock();
return Status::OutOfRange();
}
@@ -150,17 +150,16 @@
std::copy_n(payload.data(), payload.size(), buffer.data());
}
- rpc_lock().unlock();
- return channel().Send(response_, packet);
+ return channel().SendBuffer(response_, packet);
}
-void Call::ReleasePayloadBuffer() {
- PW_DCHECK(active());
+void Call::ReleasePayloadBufferLocked() {
+ PW_DCHECK(active_locked());
channel().Release(response_);
}
void Call::Close() {
- if (active()) {
+ if (active_locked()) {
endpoint().UnregisterCall(*this);
}
diff --git a/pw_rpc/call_test.cc b/pw_rpc/call_test.cc
index a58fd07..25ecfcb 100644
--- a/pw_rpc/call_test.cc
+++ b/pw_rpc/call_test.cc
@@ -167,6 +167,7 @@
TEST(ServerWriter, DefaultConstructor_NoClientStream) {
FakeServerWriter writer;
+ LockGuard lock(rpc_lock());
EXPECT_FALSE(writer.as_server_call().has_client_stream());
EXPECT_FALSE(writer.as_server_call().client_stream_open());
}
@@ -175,6 +176,7 @@
ServerContextForTest<TestService> context(TestService::method.method());
FakeServerWriter writer(context.get());
+ LockGuard lock(rpc_lock());
EXPECT_FALSE(writer.as_server_call().has_client_stream());
EXPECT_FALSE(writer.as_server_call().client_stream_open());
}
@@ -182,6 +184,7 @@
TEST(ServerReader, DefaultConstructor_ClientStreamClosed) {
test::FakeServerReader reader;
EXPECT_FALSE(reader.as_server_call().active());
+ LockGuard lock(rpc_lock());
EXPECT_FALSE(reader.as_server_call().client_stream_open());
}
@@ -189,6 +192,7 @@
ServerContextForTest<TestService> context(TestService::method.method());
test::FakeServerReader reader(context.get());
+ LockGuard lock(rpc_lock());
EXPECT_TRUE(reader.as_server_call().has_client_stream());
EXPECT_TRUE(reader.as_server_call().client_stream_open());
}
@@ -198,11 +202,14 @@
test::FakeServerReader reader(context.get());
EXPECT_TRUE(reader.as_server_call().active());
+ rpc_lock().lock();
EXPECT_TRUE(reader.as_server_call().client_stream_open());
+ rpc_lock().unlock();
EXPECT_EQ(OkStatus(),
reader.as_server_call().CloseAndSendResponse(OkStatus()));
EXPECT_FALSE(reader.as_server_call().active());
+ LockGuard lock(rpc_lock());
EXPECT_FALSE(reader.as_server_call().client_stream_open());
}
@@ -211,11 +218,12 @@
test::FakeServerReader reader(context.get());
EXPECT_TRUE(reader.active());
- EXPECT_TRUE(reader.as_server_call().client_stream_open());
rpc_lock().lock();
+ EXPECT_TRUE(reader.as_server_call().client_stream_open());
reader.as_server_call().HandleClientStreamEnd();
EXPECT_TRUE(reader.active());
+ LockGuard lock(rpc_lock());
EXPECT_FALSE(reader.as_server_call().client_stream_open());
}
@@ -224,9 +232,12 @@
test::FakeServerReaderWriter reader_writer(context.get());
test::FakeServerReaderWriter destination;
+ rpc_lock().lock();
EXPECT_FALSE(destination.as_server_call().client_stream_open());
+ rpc_lock().unlock();
destination = std::move(reader_writer);
+ LockGuard lock(rpc_lock());
EXPECT_TRUE(destination.as_server_call().has_client_stream());
EXPECT_TRUE(destination.as_server_call().client_stream_open());
}
diff --git a/pw_rpc/channel.cc b/pw_rpc/channel.cc
index 3991834..14f165c 100644
--- a/pw_rpc/channel.cc
+++ b/pw_rpc/channel.cc
@@ -31,21 +31,26 @@
: std::span<byte>();
}
-Status Channel::Send(OutputBuffer& buffer, const internal::Packet& packet) {
+Status Channel::SendBuffer(OutputBuffer& buffer,
+ const internal::Packet& packet) {
Result encoded = packet.Encode(buffer.buffer_);
if (!encoded.ok()) {
+ const std::span released_buffer = buffer.buffer_;
+ buffer.buffer_ = {};
+ rpc_lock().unlock();
+
+ output().DiscardBuffer(released_buffer);
PW_LOG_ERROR(
"Failed to encode RPC packet type %u to channel %u buffer, status %u",
static_cast<unsigned>(packet.type()),
static_cast<unsigned>(id()),
encoded.status().code());
- output().DiscardBuffer(buffer.buffer_);
- buffer.buffer_ = {};
return Status::Internal();
}
buffer.buffer_ = {};
+ rpc_lock().unlock();
Status status = output().SendAndReleaseBuffer(encoded.value());
if (!status.ok()) {
diff --git a/pw_rpc/channel_test.cc b/pw_rpc/channel_test.cc
index ad30441..4f21b0e 100644
--- a/pw_rpc/channel_test.cc
+++ b/pw_rpc/channel_test.cc
@@ -74,7 +74,8 @@
Channel::OutputBuffer output_buffer = channel.AcquireBuffer();
EXPECT_TRUE(output_buffer.payload(kTestPacket).empty());
- EXPECT_EQ(Status::Internal(), channel.Send(output_buffer, kTestPacket));
+ rpc_lock().lock();
+ EXPECT_EQ(Status::Internal(), channel.SendBuffer(output_buffer, kTestPacket));
}
TEST(Channel, OutputBuffer_ExactFit) {
@@ -87,7 +88,8 @@
EXPECT_EQ(payload.size(), output.buffer().size() - kReservedSize);
EXPECT_EQ(output.buffer().data() + kReservedSize, payload.data());
- EXPECT_EQ(OkStatus(), channel.Send(output_buffer, kTestPacket));
+ rpc_lock().lock();
+ EXPECT_EQ(OkStatus(), channel.SendBuffer(output_buffer, kTestPacket));
}
TEST(Channel, OutputBuffer_PayloadDoesNotFit_ReportsError) {
@@ -98,6 +100,7 @@
byte data[1] = {};
packet.set_payload(data);
+ rpc_lock().lock();
EXPECT_EQ(Status::Internal(), channel.Send(packet));
}
@@ -111,7 +114,8 @@
EXPECT_EQ(payload.size(), output.buffer().size() - kReservedSize);
EXPECT_EQ(output.buffer().data() + kReservedSize, payload.data());
- EXPECT_EQ(OkStatus(), channel.Send(output_buffer, kTestPacket));
+ rpc_lock().lock();
+ EXPECT_EQ(OkStatus(), channel.SendBuffer(output_buffer, kTestPacket));
}
TEST(Channel, OutputBuffer_ReturnsStatusFromChannelOutputSend) {
@@ -121,7 +125,8 @@
Channel::OutputBuffer output_buffer = channel.AcquireBuffer();
output.set_send_status(Status::Aborted());
- EXPECT_EQ(Status::Aborted(), channel.Send(output_buffer, kTestPacket));
+ rpc_lock().lock();
+ EXPECT_EQ(Status::Aborted(), channel.SendBuffer(output_buffer, kTestPacket));
}
TEST(Channel, OutputBuffer_Contains_FalseWhenEmpty) {
diff --git a/pw_rpc/client.cc b/pw_rpc/client.cc
index 02c4fb9..f0f8181 100644
--- a/pw_rpc/client.cc
+++ b/pw_rpc/client.cc
@@ -44,20 +44,21 @@
internal::Channel* channel = GetInternalChannel(packet.channel_id());
if (channel == nullptr) {
- PW_LOG_WARN("RPC client received a packet for an unregistered channel");
internal::rpc_lock().unlock();
+ PW_LOG_WARN("RPC client received a packet for an unregistered channel");
return Status::Unavailable();
}
if (call == nullptr || call->id() != packet.call_id()) {
- internal::rpc_lock().unlock();
// The call for the packet does not exist. If the packet is a server stream
// message, notify the server so that it can kill the stream. Otherwise,
// silently drop the packet (as it would terminate the RPC anyway).
if (packet.type() == PacketType::SERVER_STREAM) {
- PW_LOG_WARN("RPC client received stream message for an unknown call");
channel->Send(Packet::ClientError(packet, Status::FailedPrecondition()))
.IgnoreError();
+ PW_LOG_WARN("RPC client received stream message for an unknown call");
+ } else {
+ internal::rpc_lock().unlock();
}
return OkStatus(); // OK since the packet was handled
}
@@ -80,8 +81,9 @@
if (call->has_server_stream()) {
call->HandlePayload(packet.payload());
} else {
- PW_LOG_DEBUG("Received SERVER_STREAM for RPC without a server stream");
call->HandleError(Status::InvalidArgument());
+ PW_LOG_DEBUG("Received SERVER_STREAM for RPC without a server stream");
+ internal::rpc_lock().lock();
// Report the error to the server so it can abort the RPC.
channel->Send(Packet::ClientError(packet, Status::InvalidArgument()))
.IgnoreError(); // Errors are logged in Channel::Send.
diff --git a/pw_rpc/endpoint.cc b/pw_rpc/endpoint.cc
index 0ffee0f..9edec4b 100644
--- a/pw_rpc/endpoint.cc
+++ b/pw_rpc/endpoint.cc
@@ -63,8 +63,8 @@
void Endpoint::RegisterCall(Call& call) {
LockGuard lock(rpc_lock());
- Call* const existing_call =
- FindCallById(call.channel_id(), call.service_id(), call.method_id());
+ Call* const existing_call = FindCallById(
+ call.channel_id_locked(), call.service_id(), call.method_id());
RegisterUniqueCall(call);
@@ -99,8 +99,8 @@
uint32_t service_id,
uint32_t method_id) {
for (Call& call : calls_) {
- if (channel_id == call.channel_id() && service_id == call.service_id() &&
- method_id == call.method_id()) {
+ if (channel_id == call.channel_id_locked() &&
+ service_id == call.service_id() && method_id == call.method_id()) {
return &call;
}
}
diff --git a/pw_rpc/fake_channel_output.cc b/pw_rpc/fake_channel_output.cc
index 10f0b52..8319002 100644
--- a/pw_rpc/fake_channel_output.cc
+++ b/pw_rpc/fake_channel_output.cc
@@ -25,6 +25,10 @@
namespace pw::rpc::internal::test {
+FakeChannelOutput::~FakeChannelOutput() {
+ PW_CHECK(!buffer_acquired_, "A FakeChannelOutput buffer was never released");
+}
+
void FakeChannelOutput::clear() {
payloads_.clear();
packets_.clear();
@@ -32,9 +36,21 @@
return_after_packet_count_ = -1;
}
-Status FakeChannelOutput::SendAndReleaseBuffer(
- std::span<const std::byte> buffer) {
+ByteSpan FakeChannelOutput::AcquireBuffer() {
+ PW_CHECK(!buffer_acquired_,
+ "Multiple FakeChannelOutput buffers were acquired");
+ buffer_acquired_ = true;
+
+ return encoding_buffer_;
+}
+
+Status FakeChannelOutput::HandlePacket(std::span<const std::byte> buffer) {
PW_CHECK_PTR_EQ(buffer.data(), encoding_buffer_.data());
+ PW_CHECK_UINT_LE(buffer.size(), encoding_buffer_.size());
+
+ PW_CHECK(buffer_acquired_,
+ "Releasing a FakeChannelOutput buffer that wasn't acquired");
+ buffer_acquired_ = false;
// If the buffer is empty, this is just releasing an unused buffer.
if (buffer.empty()) {
@@ -60,10 +76,11 @@
static_cast<unsigned>(packets_.size()));
packets_.push_back(*result);
+ Packet& packet = packets_.back();
- CopyPayloadToBuffer(packets_.back().payload());
+ CopyPayloadToBuffer(packet);
- switch (result.value().type()) {
+ switch (packet.type()) {
case PacketType::REQUEST:
return OkStatus();
case PacketType::RESPONSE:
@@ -72,15 +89,14 @@
case PacketType::CLIENT_STREAM:
return OkStatus();
case PacketType::DEPRECATED_SERVER_STREAM_END:
- PW_CRASH("Deprecated PacketType %d",
- static_cast<int>(result.value().type()));
+ PW_CRASH("Deprecated PacketType %d", static_cast<int>(packet.type()));
case PacketType::CLIENT_ERROR:
PW_LOG_WARN("FakeChannelOutput received client error: %s",
- result.value().status().str());
+ packet.status().str());
return OkStatus();
case PacketType::SERVER_ERROR:
PW_LOG_WARN("FakeChannelOutput received server error: %s",
- result.value().status().str());
+ packet.status().str());
return OkStatus();
case PacketType::DEPRECATED_CANCEL:
case PacketType::SERVER_STREAM:
@@ -90,7 +106,8 @@
PW_CRASH("Unhandled PacketType %d", static_cast<int>(result.value().type()));
}
-void FakeChannelOutput::CopyPayloadToBuffer(const ConstByteSpan& payload) {
+void FakeChannelOutput::CopyPayloadToBuffer(Packet& packet) {
+ const ConstByteSpan& payload = packet.payload();
if (payload.empty()) {
return;
}
@@ -105,7 +122,7 @@
const size_t start = payloads_.size();
payloads_.resize(payloads_.size() + payload.size());
std::memcpy(&payloads_[start], payload.data(), payload.size());
- packets_.back().set_payload(std::span(&payloads_[start], payload.size()));
+ packet.set_payload(std::span(&payloads_[start], payload.size()));
}
void FakeChannelOutput::LogPackets() const {
diff --git a/pw_rpc/fake_channel_output_test.cc b/pw_rpc/fake_channel_output_test.cc
index 4c74af4..31c91ba 100644
--- a/pw_rpc/fake_channel_output_test.cc
+++ b/pw_rpc/fake_channel_output_test.cc
@@ -33,6 +33,11 @@
constexpr std::array<std::byte, 3> kPayload = {
std::byte(1), std::byte(2), std::byte(3)};
+Status LockAndSend(Channel& channel, const Packet& packet) {
+ rpc_lock().lock();
+ return channel.Send(packet);
+}
+
class TestFakeChannelOutput final
: public FakeChannelOutputBuffer<kOutputSize, 9, 128> {
public:
@@ -57,7 +62,7 @@
kMethodId,
kCallId,
kPayload);
- ASSERT_EQ(channel.Send(server_stream_packet), OkStatus());
+ ASSERT_EQ(LockAndSend(channel, server_stream_packet), OkStatus());
ASSERT_EQ(output.last_response(type).size(), kPayload.size());
EXPECT_EQ(
std::memcmp(
@@ -83,22 +88,22 @@
kMethodId,
kCallId,
kPayload);
- EXPECT_EQ(channel.Send(response_packet), OkStatus());
+ EXPECT_EQ(LockAndSend(channel, response_packet), OkStatus());
EXPECT_EQ(output.total_payloads(type), 1u);
EXPECT_EQ(output.total_packets(), 1u);
EXPECT_TRUE(output.done());
// Multiple calls will return the same error status.
output.set_send_status(Status::Unknown());
- EXPECT_EQ(channel.Send(response_packet), Status::Unknown());
- EXPECT_EQ(channel.Send(response_packet), Status::Unknown());
- EXPECT_EQ(channel.Send(response_packet), Status::Unknown());
+ EXPECT_EQ(LockAndSend(channel, response_packet), Status::Unknown());
+ EXPECT_EQ(LockAndSend(channel, response_packet), Status::Unknown());
+ EXPECT_EQ(LockAndSend(channel, response_packet), Status::Unknown());
EXPECT_EQ(output.total_payloads(type), 1u);
EXPECT_EQ(output.total_packets(), 1u);
// Turn off error status behavior.
output.set_send_status(OkStatus());
- EXPECT_EQ(channel.Send(response_packet), OkStatus());
+ EXPECT_EQ(LockAndSend(channel, response_packet), OkStatus());
EXPECT_EQ(output.total_payloads(type), 2u);
EXPECT_EQ(output.total_packets(), 2u);
@@ -108,7 +113,7 @@
kMethodId,
kCallId,
kPayload);
- EXPECT_EQ(channel.Send(server_stream_packet), OkStatus());
+ EXPECT_EQ(LockAndSend(channel, server_stream_packet), OkStatus());
ASSERT_EQ(output.last_response(type).size(), kPayload.size());
EXPECT_EQ(
std::memcmp(
@@ -133,11 +138,11 @@
const int packet_count_fail = 4;
output.set_send_status(Status::Unknown(), packet_count_fail);
for (int i = 0; i < packet_count_fail; ++i) {
- EXPECT_EQ(channel.Send(response_packet), OkStatus());
+ EXPECT_EQ(LockAndSend(channel, response_packet), OkStatus());
}
- EXPECT_EQ(channel.Send(response_packet), Status::Unknown());
+ EXPECT_EQ(LockAndSend(channel, response_packet), Status::Unknown());
for (int i = 0; i < packet_count_fail; ++i) {
- EXPECT_EQ(channel.Send(response_packet), OkStatus());
+ EXPECT_EQ(LockAndSend(channel, response_packet), OkStatus());
}
const size_t total_response_packets =
@@ -147,7 +152,7 @@
// Turn off error status behavior.
output.set_send_status(OkStatus());
- EXPECT_EQ(channel.Send(response_packet), OkStatus());
+ EXPECT_EQ(LockAndSend(channel, response_packet), OkStatus());
EXPECT_EQ(output.total_payloads(type), total_response_packets + 1);
EXPECT_EQ(output.total_packets(), total_response_packets + 1);
}
@@ -161,7 +166,7 @@
kMethodId,
kCallId,
kPayload);
- ASSERT_EQ(channel.Send(response_packet), OkStatus());
+ ASSERT_EQ(LockAndSend(channel, response_packet), OkStatus());
ASSERT_EQ(output.last_response(MethodType::kUnary).size(), kPayload.size());
EXPECT_EQ(std::memcmp(output.last_response(MethodType::kUnary).data(),
kPayload.data(),
@@ -174,7 +179,7 @@
output.clear();
const internal::Packet packet_empty_payload(
PacketType::RESPONSE, kChannelId, kServiceId, kMethodId, kCallId, {});
- EXPECT_EQ(channel.Send(packet_empty_payload), OkStatus());
+ EXPECT_EQ(LockAndSend(channel, packet_empty_payload), OkStatus());
EXPECT_EQ(output.last_response(MethodType::kUnary).size(), 0u);
EXPECT_EQ(output.total_payloads(MethodType::kUnary), 1u);
EXPECT_EQ(output.total_packets(), 1u);
@@ -186,7 +191,7 @@
kMethodId,
kCallId,
kPayload);
- ASSERT_EQ(channel.Send(server_stream_packet), OkStatus());
+ ASSERT_EQ(LockAndSend(channel, server_stream_packet), OkStatus());
ASSERT_EQ(output.total_payloads(MethodType::kServerStreaming), 1u);
ASSERT_EQ(output.last_response(MethodType::kServerStreaming).size(),
kPayload.size());
diff --git a/pw_rpc/nanopb/method.cc b/pw_rpc/nanopb/method.cc
index 9fa27ff..43af85f 100644
--- a/pw_rpc/nanopb/method.cc
+++ b/pw_rpc/nanopb/method.cc
@@ -60,9 +60,10 @@
return true;
}
+ rpc_lock().lock();
+ channel.Send(Packet::ServerError(request, Status::DataLoss())).IgnoreError();
PW_LOG_WARN("Nanopb failed to decode request payload from channel %u",
unsigned(channel.id()));
- channel.Send(Packet::ServerError(request, Status::DataLoss()));
return false;
}
diff --git a/pw_rpc/nanopb/public/pw_rpc/nanopb/client_reader_writer.h b/pw_rpc/nanopb/public/pw_rpc/nanopb/client_reader_writer.h
index 83995a2..c0a7fe6 100644
--- a/pw_rpc/nanopb/public/pw_rpc/nanopb/client_reader_writer.h
+++ b/pw_rpc/nanopb/public/pw_rpc/nanopb/client_reader_writer.h
@@ -91,7 +91,7 @@
} else {
// TODO: it's silly to lock this just to call the callback
rpc_lock().lock();
- on_error(Status::DataLoss());
+ CallOnError(Status::DataLoss());
}
}
});
@@ -189,7 +189,7 @@
nanopb_on_next_(response_struct);
} else {
rpc_lock().lock();
- on_error(Status::DataLoss());
+ CallOnError(Status::DataLoss());
}
}
});
diff --git a/pw_rpc/public/pw_rpc/internal/call.h b/pw_rpc/public/pw_rpc/internal/call.h
index eec5870..1cdaab8 100644
--- a/pw_rpc/public/pw_rpc/internal/call.h
+++ b/pw_rpc/public/pw_rpc/internal/call.h
@@ -58,15 +58,31 @@
Call& operator=(Call&&) = delete;
// True if the Call is active and ready to send responses.
- [[nodiscard]] bool active() const { return rpc_state_ == kActive; }
+ [[nodiscard]] bool active() const PW_LOCKS_EXCLUDED(rpc_lock()) {
+ LockGuard lock(rpc_lock());
+ return active_locked();
+ }
- uint32_t id() const { return id_; }
+ [[nodiscard]] bool active_locked() const
+ PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) {
+ return rpc_state_ == kActive;
+ }
- uint32_t channel_id() const {
+ uint32_t id() const PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) { return id_; }
+
+ uint32_t channel_id() const PW_LOCKS_EXCLUDED(rpc_lock()) {
+ LockGuard lock(rpc_lock());
+ return channel_id_locked();
+ }
+ uint32_t channel_id_locked() const PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) {
return channel_ == nullptr ? Channel::kUnassignedChannelId : channel().id();
}
- uint32_t service_id() const { return service_id_; }
- uint32_t method_id() const { return method_id_; }
+ uint32_t service_id() const PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) {
+ return service_id_;
+ }
+ uint32_t method_id() const PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) {
+ return method_id_;
+ }
// Closes the Call and sends a RESPONSE packet, if it is active. Returns the
// status from sending the packet, or FAILED_PRECONDITION if the Call is not
@@ -107,10 +123,11 @@
}
}
- // Precondition: rpc_lock() must be held.
+ // Handles an error condition for the call. This closes the call and calls the
+ // on_error callback, if set.
void HandleError(Status status) PW_UNLOCK_FUNCTION(rpc_lock()) {
Close();
- on_error(status);
+ CallOnError(status);
}
// Replaces this Call with a new Call object for the same RPC.
@@ -130,19 +147,29 @@
HandleError(Status::Cancelled());
}
- bool has_client_stream() const { return HasClientStream(type_); }
- bool has_server_stream() const { return HasServerStream(type_); }
+ bool has_client_stream() const PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) {
+ return HasClientStream(type_);
+ }
+ bool has_server_stream() const PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) {
+ return HasServerStream(type_);
+ }
- bool client_stream_open() const {
+ bool client_stream_open() const PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) {
return client_stream_state_ == kClientStreamActive;
}
// Acquires a buffer into which to write a payload or returns a previously
// acquired buffer. The Call MUST be active when this is called!
- ByteSpan PayloadBuffer();
+ [[nodiscard]] ByteSpan PayloadBuffer() PW_LOCKS_EXCLUDED(rpc_lock()) {
+ LockGuard lock(rpc_lock());
+ return PayloadBufferLocked();
+ }
// Releases the buffer without sending a packet.
- void ReleasePayloadBuffer();
+ void ReleasePayloadBuffer() PW_LOCKS_EXCLUDED(rpc_lock()) {
+ LockGuard lock(rpc_lock());
+ ReleasePayloadBufferLocked();
+ }
// Keep this public so the Nanopb implementation can set it from a helper
// function.
@@ -186,8 +213,12 @@
// This call must be in a closed state when this is called.
void MoveFrom(Call& other) PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock());
- Endpoint& endpoint() const { return *endpoint_; }
- Channel& channel() const { return *channel_; }
+ Endpoint& endpoint() const PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) {
+ return *endpoint_;
+ }
+ Channel& channel() const PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) {
+ return *channel_;
+ }
void set_on_next_locked(Function<void(ConstByteSpan)>&& on_next) {
on_next_ = std::move(on_next);
@@ -199,7 +230,7 @@
// Calls the on_error callback without closing the RPC. This is used when the
// call has already completed.
- void on_error(Status error) PW_UNLOCK_FUNCTION(rpc_lock()) {
+ void CallOnError(Status error) PW_UNLOCK_FUNCTION(rpc_lock()) {
const bool invoke = on_error_ != nullptr;
rpc_lock().unlock();
@@ -208,7 +239,7 @@
}
}
- void MarkClientStreamCompleted() {
+ void MarkClientStreamCompleted() PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) {
client_stream_state_ = kClientStreamInactive;
}
@@ -249,32 +280,42 @@
MethodType type,
CallType call_type);
+ [[nodiscard]] ByteSpan PayloadBufferLocked()
+ PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock());
+ void ReleasePayloadBufferLocked() PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock());
+
Packet MakePacket(PacketType type,
ConstByteSpan payload,
- Status status = OkStatus()) const {
- return Packet(
- type, channel_id(), service_id(), method_id(), id_, payload, status);
+ Status status = OkStatus()) const
+ PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) {
+ return Packet(type,
+ channel_id_locked(),
+ service_id(),
+ method_id(),
+ id_,
+ payload,
+ status);
}
Status CloseAndSendFinalPacket(PacketType type,
ConstByteSpan response,
Status status) PW_LOCKS_EXCLUDED(rpc_lock());
- internal::Endpoint* endpoint_;
- internal::Channel* channel_;
- uint32_t id_;
- uint32_t service_id_;
- uint32_t method_id_;
+ internal::Endpoint* endpoint_ PW_GUARDED_BY(rpc_lock());
+ internal::Channel* channel_ PW_GUARDED_BY(rpc_lock());
+ uint32_t id_ PW_GUARDED_BY(rpc_lock());
+ uint32_t service_id_ PW_GUARDED_BY(rpc_lock());
+ uint32_t method_id_ PW_GUARDED_BY(rpc_lock());
- enum : bool { kInactive, kActive } rpc_state_;
- MethodType type_;
- CallType call_type_;
+ enum : bool { kInactive, kActive } rpc_state_ PW_GUARDED_BY(rpc_lock());
+ MethodType type_ PW_GUARDED_BY(rpc_lock());
+ CallType call_type_ PW_GUARDED_BY(rpc_lock());
enum : bool {
kClientStreamInactive,
kClientStreamActive,
- } client_stream_state_;
+ } client_stream_state_ PW_GUARDED_BY(rpc_lock());
- Channel::OutputBuffer response_;
+ Channel::OutputBuffer response_ PW_GUARDED_BY(rpc_lock());
// Called when the RPC is terminated due to an error.
Function<void(Status error)> on_error_;
diff --git a/pw_rpc/public/pw_rpc/internal/channel.h b/pw_rpc/public/pw_rpc/internal/channel.h
index 98e41fa..2767b72 100644
--- a/pw_rpc/public/pw_rpc/internal/channel.h
+++ b/pw_rpc/public/pw_rpc/internal/channel.h
@@ -17,6 +17,7 @@
#include "pw_assert/assert.h"
#include "pw_rpc/channel.h"
+#include "pw_rpc/internal/lock.h"
#include "pw_status/status.h"
namespace pw::rpc::internal {
@@ -74,12 +75,15 @@
return OutputBuffer(output().AcquireBuffer());
}
- Status Send(const internal::Packet& packet) {
+ // Sends an RPC packet. Acquires and uses a ChannelOutput buffer.
+ Status Send(const internal::Packet& packet) PW_UNLOCK_FUNCTION(rpc_lock()) {
OutputBuffer buffer = AcquireBuffer();
- return Send(buffer, packet);
+ return SendBuffer(buffer, packet);
}
- Status Send(OutputBuffer& buffer, const internal::Packet& packet);
+ // Sends an RPC packet using the provided output buffer.
+ Status SendBuffer(OutputBuffer& buffer, const internal::Packet& packet)
+ PW_UNLOCK_FUNCTION(rpc_lock());
void Release(OutputBuffer& buffer) {
output().DiscardBuffer(buffer.buffer_);
diff --git a/pw_rpc/public/pw_rpc/internal/fake_channel_output.h b/pw_rpc/public/pw_rpc/internal/fake_channel_output.h
index 2267472..e7f6dd2 100644
--- a/pw_rpc/public/pw_rpc/internal/fake_channel_output.h
+++ b/pw_rpc/public/pw_rpc/internal/fake_channel_output.h
@@ -36,6 +36,8 @@
FakeChannelOutput(const FakeChannelOutput&) = delete;
FakeChannelOutput(FakeChannelOutput&&) = delete;
+ ~FakeChannelOutput();
+
FakeChannelOutput& operator=(const FakeChannelOutput&) = delete;
FakeChannelOutput& operator=(FakeChannelOutput&&) = delete;
@@ -144,19 +146,27 @@
private:
friend class rpc::FakeServer;
- ByteSpan AcquireBuffer() final { return encoding_buffer_; }
+ ByteSpan AcquireBuffer() final;
// Processes buffer according to packet type and `return_after_packet_count_`
// value as follows:
// When positive, returns `send_status_` once,
// When equals 0, returns `send_status_` in all future calls,
// When negative, ignores `send_status_` processes buffer.
- Status SendAndReleaseBuffer(ConstByteSpan buffer) final;
+ Status SendAndReleaseBuffer(ConstByteSpan buffer) final {
+ const Status status = HandlePacket(buffer);
+ // Clear the encoding buffer to catch code that uses this buffer after
+ // releasing it.
+ std::fill(encoding_buffer_.begin(), encoding_buffer_.end(), std::byte{0});
+ return status;
+ }
- void CopyPayloadToBuffer(const ConstByteSpan& payload);
+ Status HandlePacket(ConstByteSpan buffer);
+ void CopyPayloadToBuffer(Packet& packet);
int return_after_packet_count_ = -1;
unsigned total_response_packets_ = 0;
+ bool buffer_acquired_ = false;
Vector<Packet>& packets_;
Vector<std::byte>& payloads_;
@@ -171,14 +181,15 @@
class FakeChannelOutputBuffer : public FakeChannelOutput {
protected:
constexpr FakeChannelOutputBuffer()
- : FakeChannelOutput(packets_, payloads_, encoding_buffer),
- encoding_buffer{},
- payloads_ {}
+ : FakeChannelOutput(
+ packets_array_, payloads_array_, encoding_buffer_array_),
+ encoding_buffer_array_{},
+ payloads_array_ {}
{}
- std::byte encoding_buffer[kOutputSizeBytes];
- Vector<std::byte, kPayloadsBufferSizeBytes> payloads_;
- Vector<Packet, kMaxPackets> packets_;
+ std::byte encoding_buffer_array_[kOutputSizeBytes];
+ Vector<std::byte, kPayloadsBufferSizeBytes> payloads_array_;
+ Vector<Packet, kMaxPackets> packets_array_;
};
} // namespace internal::test
diff --git a/pw_rpc/public/pw_rpc/internal/open_call.h b/pw_rpc/public/pw_rpc/internal/open_call.h
index 1c42ebe..64f29ce 100644
--- a/pw_rpc/public/pw_rpc/internal/open_call.h
+++ b/pw_rpc/public/pw_rpc/internal/open_call.h
@@ -50,7 +50,7 @@
"streaming RPCs.");
}
- // TODO(hepler): Update the CallContext to store the ID instead, and lookup
+ // TODO(pwbug/505): Update the CallContext to store the ID instead, and lookup
// the channel by ID.
rpc::Channel* channel = server.GetChannel(channel_id);
PW_ASSERT(channel != nullptr);
diff --git a/pw_rpc/raw/server_reader_writer_test.cc b/pw_rpc/raw/server_reader_writer_test.cc
index e2e1412..6be3174 100644
--- a/pw_rpc/raw/server_reader_writer_test.cc
+++ b/pw_rpc/raw/server_reader_writer_test.cc
@@ -146,6 +146,32 @@
EXPECT_EQ(completions.back(), OkStatus());
}
+TEST(RawUnaryResponder, Move_DifferentActiveCalls_ClosesFirstOnly) {
+ ReaderWriterTestContext ctx;
+ RawUnaryResponder active_call =
+ RawUnaryResponder::Open<TestService::TestUnaryRpc>(
+ ctx.server, ctx.channel.id(), ctx.service);
+
+ std::span buffer = active_call.PayloadBuffer();
+ ASSERT_FALSE(buffer.empty());
+
+ RawUnaryResponder new_active_call =
+ RawUnaryResponder::Open<TestService::TestAnotherUnaryRpc>(
+ ctx.server, ctx.channel.id(), ctx.service);
+
+ EXPECT_TRUE(active_call.active());
+ EXPECT_TRUE(new_active_call.active());
+
+ active_call = std::move(new_active_call);
+
+ const auto completions = ctx.output.completions<TestService::TestUnaryRpc>();
+ ASSERT_EQ(completions.size(), 1u);
+ EXPECT_EQ(completions.back(), OkStatus());
+
+ EXPECT_TRUE(
+ ctx.output.completions<TestService::TestAnotherUnaryRpc>().empty());
+}
+
TEST(RawUnaryResponder, ReplaceActiveCall_DoesNotFinishCall) {
ReaderWriterTestContext ctx;
RawUnaryResponder active_call =
diff --git a/pw_rpc/server.cc b/pw_rpc/server.cc
index d5f9d9b..8da8e9e 100644
--- a/pw_rpc/server.cc
+++ b/pw_rpc/server.cc
@@ -64,7 +64,6 @@
// If the requested channel doesn't exist, try to dynamically assign one.
channel = AssignChannel(packet.channel_id(), *interface);
if (channel == nullptr) {
- internal::rpc_lock().unlock();
// If a channel can't be assigned, send a RESOURCE_EXHAUSTED error. Never
// send responses to error messages, though, to avoid infinite cycles.
if (packet.type() != PacketType::CLIENT_ERROR) {
@@ -72,6 +71,8 @@
temp_channel
.Send(Packet::ServerError(packet, Status::ResourceExhausted()))
.IgnoreError();
+ } else {
+ internal::rpc_lock().unlock();
}
return OkStatus(); // OK since the packet was handled
}
@@ -80,11 +81,12 @@
const auto [service, method] = FindMethod(packet);
if (method == nullptr) {
- internal::rpc_lock().unlock();
// Don't send responses to errors to avoid infinite error cycles.
if (packet.type() != PacketType::CLIENT_ERROR) {
channel->Send(Packet::ServerError(packet, Status::NotFound()))
.IgnoreError();
+ } else {
+ internal::rpc_lock().unlock();
}
return OkStatus(); // OK since the packet was handled.
}
@@ -140,26 +142,23 @@
internal::Channel& channel,
internal::ServerCall* call) const {
if (call == nullptr || call->id() != packet.call_id()) {
- internal::rpc_lock().unlock();
+ channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
+ .IgnoreError(); // Errors are logged in Channel::Send.
PW_LOG_DEBUG(
"Received client stream packet for %u:%08x/%08x, which is not pending",
static_cast<unsigned>(packet.channel_id()),
static_cast<unsigned>(packet.service_id()),
static_cast<unsigned>(packet.method_id()));
- channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
- .IgnoreError(); // Errors are logged in Channel::Send.
return;
}
if (!call->has_client_stream()) {
- internal::rpc_lock().unlock();
channel.Send(Packet::ServerError(packet, Status::InvalidArgument()))
.IgnoreError(); // Errors are logged in Channel::Send.
return;
}
if (!call->client_stream_open()) {
- internal::rpc_lock().unlock();
channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
.IgnoreError(); // Errors are logged in Channel::Send.
return;
diff --git a/pw_rpc/server_call.cc b/pw_rpc/server_call.cc
index 8975c77..2109112 100644
--- a/pw_rpc/server_call.cc
+++ b/pw_rpc/server_call.cc
@@ -18,7 +18,7 @@
void ServerCall::MoveServerCallFrom(ServerCall& other) {
// If this call is active, finish it first.
- if (active()) {
+ if (active_locked()) {
Close();
SendPacket(PacketType::RESPONSE,
{},