pw_rpc: Ensure locks are held during Call::Call
Thread safety analysis doesn't check constructors or destructors.
This CL introduces `Locked...` arguments to constructors to ensure
that they are called under the proper conditions. The CL fixes a number
of callsites which were previously not acquiring the lock, and
rearranges a number of callsites in order to prevent deadlocks.
Change-Id: Ife28f1131ab8620014af3c60e691fba89b476e82
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/110271
Commit-Queue: Auto-Submit <auto-submit@pigweed.google.com.iam.gserviceaccount.com>
Reviewed-by: Wyatt Hepler <hepler@google.com>
Pigweed-Auto-Submit: Taylor Cramer <cramertj@google.com>
diff --git a/pw_rpc/call.cc b/pw_rpc/call.cc
index e4afcb5..e323cfc 100644
--- a/pw_rpc/call.cc
+++ b/pw_rpc/call.cc
@@ -22,8 +22,18 @@
namespace pw::rpc::internal {
+// Creates an active server-side Call.
+Call::Call(const LockedCallContext& context, MethodType type)
+ : Call(context.server().ClaimLocked(),
+ context.call_id(),
+ context.channel_id(),
+ UnwrapServiceId(context.service().service_id()),
+ context.method().id(),
+ type,
+ kServerCall) {}
+
// Creates an active client-side call, assigning it a new ID.
-Call::Call(Endpoint& client,
+Call::Call(LockedEndpoint& client,
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id,
@@ -36,7 +46,7 @@
type,
kClientCall) {}
-Call::Call(Endpoint& endpoint_ref,
+Call::Call(LockedEndpoint& endpoint_ref,
uint32_t call_id,
uint32_t channel_id,
uint32_t service_id,
diff --git a/pw_rpc/call_test.cc b/pw_rpc/call_test.cc
index 208342d..18f337e7 100644
--- a/pw_rpc/call_test.cc
+++ b/pw_rpc/call_test.cc
@@ -39,209 +39,215 @@
constexpr Packet kPacket(PacketType::REQUEST, 99, 16, 8);
-using pw::rpc::internal::test::FakeServerWriter;
-using std::byte;
+using ::pw::rpc::internal::test::FakeServerReader;
+using ::pw::rpc::internal::test::FakeServerReaderWriter;
+using ::pw::rpc::internal::test::FakeServerWriter;
+using ::std::byte;
+using ::testing::Test;
-TEST(ServerWriter, ConstructWithContext_StartsOpen) {
- ServerContextForTest<TestService> context(TestService::method.method());
+class ServerWriterTest : public Test {
+ public:
+ ServerWriterTest() : context_(TestService::method.method()) {
+ rpc_lock().lock();
+ FakeServerWriter writer_temp(context_.get().ClaimLocked());
+ rpc_lock().unlock();
+ writer_ = std::move(writer_temp);
+ }
- FakeServerWriter writer(context.get());
+ ServerContextForTest<TestService> context_;
+ FakeServerWriter writer_;
+};
- EXPECT_TRUE(writer.active());
+TEST_F(ServerWriterTest, ConstructWithContext_StartsOpen) {
+ EXPECT_TRUE(writer_.active());
}
-TEST(ServerWriter, Move_ClosesOriginal) {
- ServerContextForTest<TestService> context(TestService::method.method());
-
- FakeServerWriter moved(context.get());
- FakeServerWriter writer(std::move(moved));
+TEST_F(ServerWriterTest, Move_ClosesOriginal) {
+ FakeServerWriter moved(std::move(writer_));
#ifndef __clang_analyzer__
- EXPECT_FALSE(moved.active());
+ EXPECT_FALSE(writer_.active());
#endif // ignore use-after-move
- EXPECT_TRUE(writer.active());
+ EXPECT_TRUE(moved.active());
}
-TEST(ServerWriter, DefaultConstruct_Closed) {
+TEST_F(ServerWriterTest, DefaultConstruct_Closed) {
FakeServerWriter writer;
-
EXPECT_FALSE(writer.active());
}
-TEST(ServerWriter, Construct_RegistersWithServer) PW_NO_LOCK_SAFETY_ANALYSIS {
- ServerContextForTest<TestService> context(TestService::method.method());
- FakeServerWriter writer(context.get());
-
- Call* call = context.server().FindCall(kPacket);
+TEST_F(ServerWriterTest, Construct_RegistersWithServer) {
+ std::lock_guard lock(rpc_lock());
+ Call* call = context_.server().FindCall(kPacket);
ASSERT_NE(call, nullptr);
- EXPECT_EQ(static_cast<void*>(call), static_cast<void*>(&writer));
+ EXPECT_EQ(static_cast<void*>(call), static_cast<void*>(&writer_));
}
-TEST(ServerWriter, Destruct_RemovesFromServer) PW_NO_LOCK_SAFETY_ANALYSIS {
+TEST_F(ServerWriterTest, Destruct_RemovesFromServer) {
ServerContextForTest<TestService> context(TestService::method.method());
- { FakeServerWriter writer(context.get()); }
+ {
+ // Note `lock_guard` cannot be used here, because while the constructor
+ // of `FakeServerWriter` requires the lock be held, the destructor acquires
+ // it!
+ rpc_lock().lock();
+ FakeServerWriter writer(context.get().ClaimLocked());
+ rpc_lock().unlock();
+ }
+ std::lock_guard lock(rpc_lock());
EXPECT_EQ(context.server().FindCall(kPacket), nullptr);
}
-TEST(ServerWriter, Finish_RemovesFromServer) PW_NO_LOCK_SAFETY_ANALYSIS {
- ServerContextForTest<TestService> context(TestService::method.method());
- FakeServerWriter writer(context.get());
-
- EXPECT_EQ(OkStatus(), writer.Finish());
-
- EXPECT_EQ(context.server().FindCall(kPacket), nullptr);
+TEST_F(ServerWriterTest, Finish_RemovesFromServer) {
+ EXPECT_EQ(OkStatus(), writer_.Finish());
+ std::lock_guard lock(rpc_lock());
+ EXPECT_EQ(context_.server().FindCall(kPacket), nullptr);
}
-TEST(ServerWriter, Finish_SendsResponse) {
- ServerContextForTest<TestService> context(TestService::method.method());
- FakeServerWriter writer(context.get());
+TEST_F(ServerWriterTest, Finish_SendsResponse) {
+ EXPECT_EQ(OkStatus(), writer_.Finish());
- EXPECT_EQ(OkStatus(), writer.Finish());
-
- ASSERT_EQ(context.output().total_packets(), 1u);
- const Packet& packet = context.output().last_packet();
+ ASSERT_EQ(context_.output().total_packets(), 1u);
+ const Packet& packet = context_.output().last_packet();
EXPECT_EQ(packet.type(), PacketType::RESPONSE);
- EXPECT_EQ(packet.channel_id(), context.channel_id());
- EXPECT_EQ(packet.service_id(), context.service_id());
- EXPECT_EQ(packet.method_id(), context.get().method().id());
+ EXPECT_EQ(packet.channel_id(), context_.channel_id());
+ EXPECT_EQ(packet.service_id(), context_.service_id());
+ EXPECT_EQ(packet.method_id(), context_.get().method().id());
EXPECT_TRUE(packet.payload().empty());
EXPECT_EQ(packet.status(), OkStatus());
}
-TEST(ServerWriter, Finish_ReturnsStatusFromChannelSend) {
- ServerContextForTest<TestService> context(TestService::method.method());
- FakeServerWriter writer(context.get());
- context.output().set_send_status(Status::Unauthenticated());
+TEST_F(ServerWriterTest, Finish_ReturnsStatusFromChannelSend) {
+ context_.output().set_send_status(Status::Unauthenticated());
// All non-OK statuses are remapped to UNKNOWN.
- EXPECT_EQ(Status::Unknown(), writer.Finish());
+ EXPECT_EQ(Status::Unknown(), writer_.Finish());
}
-TEST(ServerWriter, Finish) {
- ServerContextForTest<TestService> context(TestService::method.method());
- FakeServerWriter writer(context.get());
-
- ASSERT_TRUE(writer.active());
- EXPECT_EQ(OkStatus(), writer.Finish());
- EXPECT_FALSE(writer.active());
- EXPECT_EQ(Status::FailedPrecondition(), writer.Finish());
+TEST_F(ServerWriterTest, Finish) {
+ ASSERT_TRUE(writer_.active());
+ EXPECT_EQ(OkStatus(), writer_.Finish());
+ EXPECT_FALSE(writer_.active());
+ EXPECT_EQ(Status::FailedPrecondition(), writer_.Finish());
}
-TEST(ServerWriter, Open_SendsPacketWithPayload) {
- ServerContextForTest<TestService> context(TestService::method.method());
- FakeServerWriter writer(context.get());
-
+TEST_F(ServerWriterTest, Open_SendsPacketWithPayload) {
constexpr byte data[] = {byte{0xf0}, byte{0x0d}};
- ASSERT_EQ(OkStatus(), writer.Write(data));
+ ASSERT_EQ(OkStatus(), writer_.Write(data));
byte encoded[64];
- auto result = context.server_stream(data).Encode(encoded);
+ auto result = context_.server_stream(data).Encode(encoded);
ASSERT_EQ(OkStatus(), result.status());
- ConstByteSpan payload = context.output().last_packet().payload();
+ ConstByteSpan payload = context_.output().last_packet().payload();
EXPECT_EQ(sizeof(data), payload.size());
EXPECT_EQ(0, std::memcmp(data, payload.data(), sizeof(data)));
}
-TEST(ServerWriter, Closed_IgnoresFinish) {
- ServerContextForTest<TestService> context(TestService::method.method());
- FakeServerWriter writer(context.get());
-
- EXPECT_EQ(OkStatus(), writer.Finish());
- EXPECT_EQ(Status::FailedPrecondition(), writer.Finish());
+TEST_F(ServerWriterTest, Closed_IgnoresFinish) {
+ EXPECT_EQ(OkStatus(), writer_.Finish());
+ EXPECT_EQ(Status::FailedPrecondition(), writer_.Finish());
}
-TEST(ServerWriter, DefaultConstructor_NoClientStream) {
+TEST_F(ServerWriterTest, DefaultConstructor_NoClientStream) {
FakeServerWriter writer;
LockGuard lock(rpc_lock());
EXPECT_FALSE(writer.as_server_call().has_client_stream());
EXPECT_FALSE(writer.as_server_call().client_stream_open());
}
-TEST(ServerWriter, Open_NoClientStream) {
- ServerContextForTest<TestService> context(TestService::method.method());
- FakeServerWriter writer(context.get());
-
+TEST_F(ServerWriterTest, Open_NoClientStream) {
LockGuard lock(rpc_lock());
- EXPECT_FALSE(writer.as_server_call().has_client_stream());
- EXPECT_FALSE(writer.as_server_call().client_stream_open());
+ EXPECT_FALSE(writer_.as_server_call().has_client_stream());
+ EXPECT_FALSE(writer_.as_server_call().client_stream_open());
}
-TEST(ServerReader, DefaultConstructor_ClientStreamClosed) {
- test::FakeServerReader reader;
+class ServerReaderTest : public Test {
+ public:
+ ServerReaderTest() : context_(TestService::method.method()) {
+ rpc_lock().lock();
+ FakeServerReader reader_temp(context_.get().ClaimLocked());
+ rpc_lock().unlock();
+ reader_ = std::move(reader_temp);
+ }
+
+ ServerContextForTest<TestService> context_;
+ FakeServerReader reader_;
+};
+
+TEST_F(ServerReaderTest, DefaultConstructor_ClientStreamClosed) {
+ FakeServerReader reader;
EXPECT_FALSE(reader.as_server_call().active());
LockGuard lock(rpc_lock());
EXPECT_FALSE(reader.as_server_call().client_stream_open());
}
-TEST(ServerReader, Open_ClientStreamStartsOpen) {
- ServerContextForTest<TestService> context(TestService::method.method());
- test::FakeServerReader reader(context.get());
-
- LockGuard lock(rpc_lock());
- EXPECT_TRUE(reader.as_server_call().has_client_stream());
- EXPECT_TRUE(reader.as_server_call().client_stream_open());
+TEST_F(ServerReaderTest, Open_ClientStreamStartsOpen) {
+ std::lock_guard lock(rpc_lock());
+ EXPECT_TRUE(reader_.as_server_call().has_client_stream());
+ EXPECT_TRUE(reader_.as_server_call().client_stream_open());
}
-TEST(ServerReader, Close_ClosesClientStream) {
- ServerContextForTest<TestService> context(TestService::method.method());
- test::FakeServerReader reader(context.get());
-
- EXPECT_TRUE(reader.as_server_call().active());
+TEST_F(ServerReaderTest, Close_ClosesClientStream) {
+ EXPECT_TRUE(reader_.as_server_call().active());
rpc_lock().lock();
- EXPECT_TRUE(reader.as_server_call().client_stream_open());
+ EXPECT_TRUE(reader_.as_server_call().client_stream_open());
rpc_lock().unlock();
EXPECT_EQ(OkStatus(),
- reader.as_server_call().CloseAndSendResponse(OkStatus()));
+ reader_.as_server_call().CloseAndSendResponse(OkStatus()));
- EXPECT_FALSE(reader.as_server_call().active());
+ EXPECT_FALSE(reader_.as_server_call().active());
LockGuard lock(rpc_lock());
- EXPECT_FALSE(reader.as_server_call().client_stream_open());
+ EXPECT_FALSE(reader_.as_server_call().client_stream_open());
}
-TEST(ServerReader, EndClientStream_OnlyClosesClientStream) {
- ServerContextForTest<TestService> context(TestService::method.method());
- test::FakeServerReader reader(context.get());
-
- EXPECT_TRUE(reader.active());
+TEST_F(ServerReaderTest, EndClientStream_OnlyClosesClientStream) {
+ EXPECT_TRUE(reader_.active());
rpc_lock().lock();
- EXPECT_TRUE(reader.as_server_call().client_stream_open());
- reader.as_server_call().HandleClientStreamEnd();
+ EXPECT_TRUE(reader_.as_server_call().client_stream_open());
+ reader_.as_server_call().HandleClientStreamEnd();
- EXPECT_TRUE(reader.active());
+ EXPECT_TRUE(reader_.active());
LockGuard lock(rpc_lock());
- EXPECT_FALSE(reader.as_server_call().client_stream_open());
+ EXPECT_FALSE(reader_.as_server_call().client_stream_open());
}
-TEST(ServerReaderWriter, Move_MaintainsClientStream) {
- ServerContextForTest<TestService> context(TestService::method.method());
- test::FakeServerReaderWriter reader_writer(context.get());
- test::FakeServerReaderWriter destination;
+class ServerReaderWriterTest : public Test {
+ public:
+ ServerReaderWriterTest() : context_(TestService::method.method()) {
+ rpc_lock().lock();
+ FakeServerReaderWriter reader_writer_temp(context_.get().ClaimLocked());
+ rpc_lock().unlock();
+ reader_writer_ = std::move(reader_writer_temp);
+ }
+
+ ServerContextForTest<TestService> context_;
+ FakeServerReaderWriter reader_writer_;
+};
+
+TEST_F(ServerReaderWriterTest, Move_MaintainsClientStream) {
+ FakeServerReaderWriter destination;
rpc_lock().lock();
EXPECT_FALSE(destination.as_server_call().client_stream_open());
rpc_lock().unlock();
- destination = std::move(reader_writer);
+ destination = std::move(reader_writer_);
LockGuard lock(rpc_lock());
EXPECT_TRUE(destination.as_server_call().has_client_stream());
EXPECT_TRUE(destination.as_server_call().client_stream_open());
}
-TEST(ServerReaderWriter, Move_MovesCallbacks) {
- ServerContextForTest<TestService> context(TestService::method.method());
- test::FakeServerReaderWriter reader_writer(context.get());
-
+TEST_F(ServerReaderWriterTest, Move_MovesCallbacks) {
int calls = 0;
- reader_writer.set_on_error([&calls](Status) { calls += 1; });
- reader_writer.set_on_next([&calls](ConstByteSpan) { calls += 1; });
+ reader_writer_.set_on_error([&calls](Status) { calls += 1; });
+ reader_writer_.set_on_next([&calls](ConstByteSpan) { calls += 1; });
#if PW_RPC_CLIENT_STREAM_END_CALLBACK
reader_writer.set_on_client_stream_end([&calls]() { calls += 1; });
#endif // PW_RPC_CLIENT_STREAM_END_CALLBACK
- test::FakeServerReaderWriter destination(std::move(reader_writer));
+ FakeServerReaderWriter destination(std::move(reader_writer_));
rpc_lock().lock();
destination.as_server_call().HandlePayload({});
rpc_lock().lock();
diff --git a/pw_rpc/nanopb/method.cc b/pw_rpc/nanopb/method.cc
index ed38f39..acf35fb 100644
--- a/pw_rpc/nanopb/method.cc
+++ b/pw_rpc/nanopb/method.cc
@@ -36,7 +36,7 @@
return;
}
- NanopbServerCall responder(context, MethodType::kUnary);
+ NanopbServerCall responder(context.ClaimLocked(), MethodType::kUnary);
rpc_lock().unlock();
const Status status = function_.synchronous_unary(
context.service(), request_struct, response_struct);
@@ -52,7 +52,7 @@
return;
}
- NanopbServerCall server_writer(context, type);
+ NanopbServerCall server_writer(context.ClaimLocked(), type);
rpc_lock().unlock();
function_.unary_request(context.service(), request_struct, server_writer);
}
diff --git a/pw_rpc/nanopb/public/pw_rpc/nanopb/client_reader_writer.h b/pw_rpc/nanopb/public/pw_rpc/nanopb/client_reader_writer.h
index 61f5a73..b447988 100644
--- a/pw_rpc/nanopb/public/pw_rpc/nanopb/client_reader_writer.h
+++ b/pw_rpc/nanopb/public/pw_rpc/nanopb/client_reader_writer.h
@@ -39,7 +39,8 @@
Function<void(Status)>&& on_error,
const Request&... request) {
rpc_lock().lock();
- CallType call(client, channel_id, service_id, method_id, serde);
+ CallType call(
+ client.ClaimLocked(), channel_id, service_id, method_id, serde);
call.set_on_completed_locked(std::move(on_completed));
call.set_on_error_locked(std::move(on_error));
@@ -55,12 +56,13 @@
protected:
constexpr NanopbUnaryResponseClientCall() = default;
- NanopbUnaryResponseClientCall(internal::Endpoint& client,
+ NanopbUnaryResponseClientCall(internal::LockedEndpoint& client,
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id,
MethodType type,
const NanopbMethodSerde& serde)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: UnaryResponseClientCall(
client, channel_id, service_id, method_id, type),
serde_(&serde) {}
@@ -134,7 +136,8 @@
Function<void(Status)>&& on_error,
const Request&... request) {
rpc_lock().lock();
- CallType call(client, channel_id, service_id, method_id, serde);
+ CallType call(
+ client.ClaimLocked(), channel_id, service_id, method_id, serde);
call.set_on_next_locked(std::move(on_next));
call.set_on_completed_locked(std::move(on_completed));
@@ -165,12 +168,13 @@
return *this;
}
- NanopbStreamResponseClientCall(internal::Endpoint& client,
+ NanopbStreamResponseClientCall(internal::LockedEndpoint& client,
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id,
MethodType type,
const NanopbMethodSerde& serde)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: StreamResponseClientCall(
client, channel_id, service_id, method_id, type),
serde_(&serde) {}
@@ -254,11 +258,12 @@
private:
friend class internal::NanopbStreamResponseClientCall<Response>;
- NanopbClientReaderWriter(internal::Endpoint& client,
+ NanopbClientReaderWriter(internal::LockedEndpoint& client,
uint32_t channel_id_value,
uint32_t service_id,
uint32_t method_id,
const internal::NanopbMethodSerde& serde)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::NanopbStreamResponseClientCall<Response>(
client,
channel_id_value,
@@ -294,11 +299,12 @@
private:
friend class internal::NanopbStreamResponseClientCall<Response>;
- NanopbClientReader(internal::Endpoint& client,
+ NanopbClientReader(internal::LockedEndpoint& client,
uint32_t channel_id_value,
uint32_t service_id,
uint32_t method_id,
const internal::NanopbMethodSerde& serde)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::NanopbStreamResponseClientCall<Response>(
client,
channel_id_value,
@@ -338,11 +344,12 @@
private:
friend class internal::NanopbUnaryResponseClientCall<Response>;
- NanopbClientWriter(internal::Endpoint& client,
+ NanopbClientWriter(internal::LockedEndpoint& client,
uint32_t channel_id_value,
uint32_t service_id,
uint32_t method_id,
const internal::NanopbMethodSerde& serde)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::NanopbUnaryResponseClientCall<Response>(
client,
channel_id_value,
@@ -376,11 +383,12 @@
private:
friend class internal::NanopbUnaryResponseClientCall<Response>;
- NanopbUnaryReceiver(internal::Endpoint& client,
+ NanopbUnaryReceiver(internal::LockedEndpoint& client,
uint32_t channel_id_value,
uint32_t service_id,
uint32_t method_id,
const internal::NanopbMethodSerde& serde)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::NanopbUnaryResponseClientCall<Response>(client,
channel_id_value,
service_id,
diff --git a/pw_rpc/nanopb/public/pw_rpc/nanopb/internal/method.h b/pw_rpc/nanopb/public/pw_rpc/nanopb/internal/method.h
index a4e5a10..ebb35b3 100644
--- a/pw_rpc/nanopb/public/pw_rpc/nanopb/internal/method.h
+++ b/pw_rpc/nanopb/public/pw_rpc/nanopb/internal/method.h
@@ -384,7 +384,7 @@
template <typename Request>
static void ClientStreamingInvoker(const CallContext& context, const Packet&)
PW_UNLOCK_FUNCTION(rpc_lock()) {
- BaseNanopbServerReader<Request> reader(context,
+ BaseNanopbServerReader<Request> reader(context.ClaimLocked(),
MethodType::kClientStreaming);
rpc_lock().unlock();
static_cast<const NanopbMethod&>(context.method())
@@ -397,7 +397,7 @@
const Packet&)
PW_UNLOCK_FUNCTION(rpc_lock()) {
BaseNanopbServerReader<Request> reader_writer(
- context, MethodType::kBidirectionalStreaming);
+ context.ClaimLocked(), MethodType::kBidirectionalStreaming);
rpc_lock().unlock();
static_cast<const NanopbMethod&>(context.method())
.function_.stream_request(context.service(), reader_writer);
diff --git a/pw_rpc/nanopb/public/pw_rpc/nanopb/server_reader_writer.h b/pw_rpc/nanopb/public/pw_rpc/nanopb/server_reader_writer.h
index 01ce663..7de6b88 100644
--- a/pw_rpc/nanopb/public/pw_rpc/nanopb/server_reader_writer.h
+++ b/pw_rpc/nanopb/public/pw_rpc/nanopb/server_reader_writer.h
@@ -43,7 +43,8 @@
public:
constexpr NanopbServerCall() : serde_(nullptr) {}
- NanopbServerCall(const CallContext& context, MethodType type);
+ NanopbServerCall(const LockedCallContext& context, MethodType type)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock());
Status SendUnaryResponse(const void* payload, Status status)
PW_LOCKS_EXCLUDED(rpc_lock()) {
@@ -86,7 +87,9 @@
template <typename Request>
class BaseNanopbServerReader : public NanopbServerCall {
public:
- BaseNanopbServerReader(const internal::CallContext& context, MethodType type)
+ BaseNanopbServerReader(const internal::LockedCallContext& context,
+ MethodType type)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: NanopbServerCall(context, type) {}
protected:
@@ -155,11 +158,13 @@
"The response type of a NanopbServerReaderWriter must match "
"the method.");
internal::LockGuard lock(internal::rpc_lock());
- return {server.OpenContext<kMethod, MethodType::kBidirectionalStreaming>(
- channel_id,
- service,
- internal::MethodLookup::GetNanopbMethod<ServiceImpl,
- Info::kMethodId>())};
+ return {server
+ .OpenContext<kMethod, MethodType::kBidirectionalStreaming>(
+ channel_id,
+ service,
+ internal::MethodLookup::GetNanopbMethod<ServiceImpl,
+ Info::kMethodId>())
+ .ClaimLocked()};
}
constexpr NanopbServerReaderWriter() = default;
@@ -197,7 +202,8 @@
template <typename, typename, uint32_t>
friend class internal::test::InvocationContext;
- NanopbServerReaderWriter(const internal::CallContext& context)
+ NanopbServerReaderWriter(const internal::LockedCallContext& context)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::BaseNanopbServerReader<Request>(
context, MethodType::kBidirectionalStreaming) {}
};
@@ -222,11 +228,13 @@
std::is_same_v<Response, typename Info::Response>,
"The response type of a NanopbServerReader must match the method.");
internal::LockGuard lock(internal::rpc_lock());
- return {server.OpenContext<kMethod, MethodType::kClientStreaming>(
- channel_id,
- service,
- internal::MethodLookup::GetNanopbMethod<ServiceImpl,
- Info::kMethodId>())};
+ return {server
+ .OpenContext<kMethod, MethodType::kClientStreaming>(
+ channel_id,
+ service,
+ internal::MethodLookup::GetNanopbMethod<ServiceImpl,
+ Info::kMethodId>())
+ .ClaimLocked()};
}
// Allow default construction so that users can declare a variable into which
@@ -254,7 +262,8 @@
template <typename, typename, uint32_t>
friend class internal::test::InvocationContext;
- NanopbServerReader(const internal::CallContext& context)
+ NanopbServerReader(const internal::LockedCallContext& context)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::BaseNanopbServerReader<Request>(
context, MethodType::kClientStreaming) {}
};
@@ -276,11 +285,13 @@
std::is_same_v<Response, typename Info::Response>,
"The response type of a NanopbServerWriter must match the method.");
internal::LockGuard lock(internal::rpc_lock());
- return {server.OpenContext<kMethod, MethodType::kServerStreaming>(
- channel_id,
- service,
- internal::MethodLookup::GetNanopbMethod<ServiceImpl,
- Info::kMethodId>())};
+ return {server
+ .OpenContext<kMethod, MethodType::kServerStreaming>(
+ channel_id,
+ service,
+ internal::MethodLookup::GetNanopbMethod<ServiceImpl,
+ Info::kMethodId>())
+ .ClaimLocked()};
}
// Allow default construction so that users can declare a variable into which
@@ -318,7 +329,8 @@
template <typename, typename, uint32_t>
friend class internal::test::InvocationContext;
- NanopbServerWriter(const internal::CallContext& context)
+ NanopbServerWriter(const internal::LockedCallContext& context)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::NanopbServerCall(context, MethodType::kServerStreaming) {}
};
@@ -337,11 +349,13 @@
std::is_same_v<Response, typename Info::Response>,
"The response type of a NanopbUnaryResponder must match the method.");
internal::LockGuard lock(internal::rpc_lock());
- return {server.OpenContext<kMethod, MethodType::kUnary>(
- channel_id,
- service,
- internal::MethodLookup::GetNanopbMethod<ServiceImpl,
- Info::kMethodId>())};
+ return {server
+ .OpenContext<kMethod, MethodType::kUnary>(
+ channel_id,
+ service,
+ internal::MethodLookup::GetNanopbMethod<ServiceImpl,
+ Info::kMethodId>())
+ .ClaimLocked()};
}
// Allow default construction so that users can declare a variable into which
@@ -375,7 +389,8 @@
template <typename, typename, uint32_t>
friend class internal::test::InvocationContext;
- NanopbUnaryResponder(const internal::CallContext& context)
+ NanopbUnaryResponder(const internal::LockedCallContext& context)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::NanopbServerCall(context, MethodType::kUnary) {}
};
diff --git a/pw_rpc/nanopb/server_reader_writer.cc b/pw_rpc/nanopb/server_reader_writer.cc
index f0517bc..3993ca3 100644
--- a/pw_rpc/nanopb/server_reader_writer.cc
+++ b/pw_rpc/nanopb/server_reader_writer.cc
@@ -18,7 +18,9 @@
namespace pw::rpc::internal {
-NanopbServerCall::NanopbServerCall(const CallContext& context, MethodType type)
+NanopbServerCall::NanopbServerCall(const LockedCallContext& context,
+ MethodType type)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock())
: internal::ServerCall(context, type),
serde_(&static_cast<const internal::NanopbMethod&>(context.method())
.serde()) {}
diff --git a/pw_rpc/public/pw_rpc/internal/call.h b/pw_rpc/public/pw_rpc/internal/call.h
index 24abd5e..6e72257 100644
--- a/pw_rpc/public/pw_rpc/internal/call.h
+++ b/pw_rpc/public/pw_rpc/internal/call.h
@@ -37,6 +37,7 @@
namespace internal {
class Endpoint;
+class LockedEndpoint;
class Packet;
// Internal RPC Call class. The Call is used to respond to any type of RPC.
@@ -215,21 +216,15 @@
{}
// Creates an active server-side Call.
- Call(const CallContext& context, MethodType type)
- : Call(context.server(),
- context.call_id(),
- context.channel_id(),
- UnwrapServiceId(context.service().service_id()),
- context.method().id(),
- type,
- kServerCall) {}
+ Call(const LockedCallContext& context, MethodType type)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock());
// Creates an active client-side Call.
- Call(Endpoint& client,
+ Call(LockedEndpoint& client,
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id,
- MethodType type);
+ MethodType type) PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock());
// This call must be in a closed state when this is called.
void MoveFrom(Call& other) PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock());
@@ -295,7 +290,7 @@
enum CallType : bool { kServerCall, kClientCall };
// Common constructor for server & client calls.
- Call(Endpoint& endpoint,
+ Call(LockedEndpoint& endpoint,
uint32_t id,
uint32_t channel_id,
uint32_t service_id,
diff --git a/pw_rpc/public/pw_rpc/internal/call_context.h b/pw_rpc/public/pw_rpc/internal/call_context.h
index cc15cbd..666abf0 100644
--- a/pw_rpc/public/pw_rpc/internal/call_context.h
+++ b/pw_rpc/public/pw_rpc/internal/call_context.h
@@ -17,6 +17,7 @@
#include <cstdint>
#include "pw_rpc/internal/channel.h"
+#include "pw_rpc/internal/lock.h"
namespace pw::rpc {
@@ -25,6 +26,8 @@
namespace internal {
class Endpoint;
+class LockedCallContext;
+class LockedEndpoint;
class Method;
// The Server creates a CallContext object to represent a method invocation. The
@@ -42,6 +45,17 @@
method_(method),
call_id_(call_id) {}
+ // Claims that `rpc_lock()` is held, returning a wrapped context.
+ //
+ // This function should only be called in contexts in which it is clear that
+ // `rpc_lock()` is held. When calling this function from a constructor, the
+ // lock annotation will not result in errors, so care should be taken to
+ // ensure that `rpc_lock()` is held.
+ const LockedCallContext& ClaimLocked() const
+ PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock());
+
+ LockedCallContext& ClaimLocked() PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock());
+
constexpr Endpoint& server() const { return server_; }
constexpr const uint32_t& channel_id() const { return channel_id_; }
@@ -63,5 +77,27 @@
uint32_t call_id_;
};
+// A `CallContext` indicating that `rpc_lock()` is held.
+//
+// This is used as a constructor argument to supplement
+// `PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock())`. Current compilers do not enforce
+// lock annotations on constructors; no warnings or errors are produced when
+// calling an annotated constructor without holding `rpc_lock()`.
+class LockedCallContext : public CallContext {
+ public:
+ friend class CallContext;
+ // No public constructor: this is created only via the `ClaimLocked` method on
+ // `CallContext`.
+ constexpr LockedCallContext() = delete;
+};
+
+inline const LockedCallContext& CallContext::ClaimLocked() const {
+ return *static_cast<const LockedCallContext*>(this);
+}
+
+inline LockedCallContext& CallContext::ClaimLocked() {
+ return *static_cast<LockedCallContext*>(this);
+}
+
} // namespace internal
} // namespace pw::rpc
diff --git a/pw_rpc/public/pw_rpc/internal/client_call.h b/pw_rpc/public/pw_rpc/internal/client_call.h
index 02d46f1..2120933 100644
--- a/pw_rpc/public/pw_rpc/internal/client_call.h
+++ b/pw_rpc/public/pw_rpc/internal/client_call.h
@@ -18,6 +18,7 @@
#include "pw_bytes/span.h"
#include "pw_function/function.h"
#include "pw_rpc/internal/call.h"
+#include "pw_rpc/internal/endpoint.h"
#include "pw_rpc/internal/lock.h"
namespace pw::rpc::internal {
@@ -34,11 +35,11 @@
protected:
constexpr ClientCall() = default;
- ClientCall(Endpoint& client,
+ ClientCall(LockedEndpoint& client,
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id,
- MethodType type)
+ MethodType type) PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock())
: Call(client, channel_id, service_id, method_id, type) {}
// Sends CLIENT_STREAM_END if applicable, releases any held payload buffer,
@@ -65,7 +66,7 @@
Function<void(Status)>&& on_error,
ConstByteSpan request) PW_LOCKS_EXCLUDED(rpc_lock()) {
rpc_lock().lock();
- CallType call(client, channel_id, service_id, method_id);
+ CallType call(client.ClaimLocked(), channel_id, service_id, method_id);
call.set_on_completed_locked(std::move(on_completed));
call.set_on_error_locked(std::move(on_error));
@@ -87,11 +88,12 @@
protected:
constexpr UnaryResponseClientCall() = default;
- UnaryResponseClientCall(Endpoint& client,
+ UnaryResponseClientCall(LockedEndpoint& client,
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id,
MethodType type)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock())
: ClientCall(client, channel_id, service_id, method_id, type) {}
UnaryResponseClientCall(UnaryResponseClientCall&& other) {
@@ -144,7 +146,7 @@
Function<void(Status)>&& on_error,
ConstByteSpan request) {
rpc_lock().lock();
- CallType call(client, channel_id, service_id, method_id);
+ CallType call(client.ClaimLocked(), channel_id, service_id, method_id);
call.set_on_next_locked(std::move(on_next));
call.set_on_completed_locked(std::move(on_completed));
@@ -169,11 +171,12 @@
protected:
constexpr StreamResponseClientCall() = default;
- StreamResponseClientCall(Endpoint& client,
+ StreamResponseClientCall(LockedEndpoint& client,
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id,
MethodType type)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock())
: ClientCall(client, channel_id, service_id, method_id, type) {}
StreamResponseClientCall(StreamResponseClientCall&& other) {
diff --git a/pw_rpc/public/pw_rpc/internal/endpoint.h b/pw_rpc/public/pw_rpc/internal/endpoint.h
index ea44409..c8caca5 100644
--- a/pw_rpc/public/pw_rpc/internal/endpoint.h
+++ b/pw_rpc/public/pw_rpc/internal/endpoint.h
@@ -25,6 +25,8 @@
namespace pw::rpc::internal {
+class LockedEndpoint;
+
// Manages a list of channels and a list of ongoing calls for either a server or
// client.
//
@@ -37,6 +39,14 @@
public:
~Endpoint();
+ // Claims that `rpc_lock()` is held, returning a wrapped endpoint.
+ //
+ // This function should only be called in contexts in which it is clear that
+ // `rpc_lock()` is held. When calling this function from a constructor, the
+ // lock annotation will not result in errors, so care should be taken to
+ // ensure that `rpc_lock()` is held.
+ LockedEndpoint& ClaimLocked() PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock());
+
// Creates a channel with the provided ID and ChannelOutput, if a channel slot
// is available or can be allocated (if PW_RPC_DYNAMIC_ALLOCATION is enabled).
// Returns:
@@ -140,4 +150,22 @@
uint32_t next_call_id_ PW_GUARDED_BY(rpc_lock());
};
+// An `Endpoint` indicating that `rpc_lock()` is held.
+//
+// This is used as a constructor argument to supplement
+// `PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock())`. Current compilers do not enforce
+// lock annotations on constructors; no warnings or errors are produced when
+// calling an annotated constructor without holding `rpc_lock()`.
+class LockedEndpoint : public Endpoint {
+ public:
+ friend class Endpoint;
+ // No public constructor: this is created only via the `ClaimLocked` method on
+ // `Endpoint`.
+ constexpr LockedEndpoint() = delete;
+};
+
+inline LockedEndpoint& Endpoint::ClaimLocked() {
+ return *static_cast<LockedEndpoint*>(this);
+}
+
} // namespace pw::rpc::internal
diff --git a/pw_rpc/public/pw_rpc/internal/server_call.h b/pw_rpc/public/pw_rpc/internal/server_call.h
index e4b2cbf..c80e063 100644
--- a/pw_rpc/public/pw_rpc/internal/server_call.h
+++ b/pw_rpc/public/pw_rpc/internal/server_call.h
@@ -40,7 +40,7 @@
ServerCall(ServerCall&& other) { *this = std::move(other); }
- ~ServerCall() {
+ ~ServerCall() PW_LOCKS_EXCLUDED(rpc_lock()) {
// Any errors are logged in Channel::Send.
CloseAndSendResponse(OkStatus()).IgnoreError();
}
@@ -55,7 +55,8 @@
void MoveServerCallFrom(ServerCall& other)
PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock());
- ServerCall(const CallContext& context, MethodType type)
+ ServerCall(const LockedCallContext& context, MethodType type)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock())
: Call(context, type) {}
// set_on_client_stream_end is templated so that it can be conditionally
diff --git a/pw_rpc/public/pw_rpc/internal/test_method.h b/pw_rpc/public/pw_rpc/internal/test_method.h
index 1406c8a..409936b 100644
--- a/pw_rpc/public/pw_rpc/internal/test_method.h
+++ b/pw_rpc/public/pw_rpc/internal/test_method.h
@@ -34,7 +34,8 @@
class FakeServerCall : public ServerCall {
public:
constexpr FakeServerCall() = default;
- FakeServerCall(const CallContext& context, MethodType type)
+ FakeServerCall(const LockedCallContext& context, MethodType type)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock())
: ServerCall(context, type) {}
FakeServerCall(FakeServerCall&&) = default;
@@ -70,7 +71,7 @@
test_method.invocations_ += 1;
// Create a call object so it registers / unregisters with the server.
- FakeServerCall fake_call(context, kType);
+ FakeServerCall fake_call(context.ClaimLocked(), kType);
rpc_lock().unlock();
diff --git a/pw_rpc/public/pw_rpc/internal/test_method_context.h b/pw_rpc/public/pw_rpc/internal/test_method_context.h
index f96efdf..6b39edb 100644
--- a/pw_rpc/public/pw_rpc/internal/test_method_context.h
+++ b/pw_rpc/public/pw_rpc/internal/test_method_context.h
@@ -17,6 +17,7 @@
#include "pw_assert/assert.h"
#include "pw_rpc/channel.h"
+#include "pw_rpc/internal/call.h"
#include "pw_rpc/internal/fake_channel_output.h"
#include "pw_rpc/internal/method.h"
#include "pw_rpc/internal/packet.h"
@@ -139,7 +140,7 @@
// Invokes the RPC, optionally with a request argument.
template <auto kMethod, typename T, typename... RequestArg>
- void call(RequestArg&&... request) {
+ void call(RequestArg&&... request) PW_LOCKS_EXCLUDED(rpc_lock()) {
static_assert(sizeof...(request) <= 1);
output_.clear();
T responder = GetResponder<T>();
@@ -148,8 +149,9 @@
}
template <typename T>
- T GetResponder() {
- return T(call_context());
+ T GetResponder() PW_LOCKS_EXCLUDED(rpc_lock()) {
+ std::lock_guard lock(rpc_lock());
+ return T(call_context().ClaimLocked());
}
const internal::CallContext& call_context() const { return context_; }
diff --git a/pw_rpc/pw_rpc_private/fake_server_reader_writer.h b/pw_rpc/pw_rpc_private/fake_server_reader_writer.h
index 44c359b..fe9fcd5 100644
--- a/pw_rpc/pw_rpc_private/fake_server_reader_writer.h
+++ b/pw_rpc/pw_rpc_private/fake_server_reader_writer.h
@@ -36,19 +36,24 @@
//
// Call's public API is intended for rpc::Server, so hide the public methods
// with private inheritance.
-class FakeServerReaderWriter : private internal::ServerCall {
+class FakeServerReaderWriter : private ServerCall {
public:
constexpr FakeServerReaderWriter() = default;
// On a real reader/writer, this constructor would not be exposed.
- FakeServerReaderWriter(const CallContext& context,
+ FakeServerReaderWriter(const LockedCallContext& context,
MethodType type = MethodType::kBidirectionalStreaming)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock())
: ServerCall(context, type) {}
FakeServerReaderWriter(FakeServerReaderWriter&&) = default;
FakeServerReaderWriter& operator=(FakeServerReaderWriter&&) = default;
// Pull in protected functions from the hidden Call base as needed.
+ //
+ // Note: these functions all acquire `rpc_lock()`. However, the
+ // `PW_LOCKS_EXCLUDED(rpc_lock())` on their original definitions does not
+ // appear to carry through here.
using Call::active;
using Call::set_on_error;
using Call::set_on_next;
@@ -68,9 +73,11 @@
public:
constexpr FakeServerWriter() = default;
- FakeServerWriter(const CallContext& context)
+ FakeServerWriter(const LockedCallContext& context)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock())
: FakeServerReaderWriter(context, MethodType::kServerStreaming) {}
FakeServerWriter(FakeServerWriter&&) = default;
+ FakeServerWriter& operator=(FakeServerWriter&&) = default;
// Common reader/writer functions.
using FakeServerReaderWriter::active;
@@ -86,10 +93,12 @@
public:
constexpr FakeServerReader() = default;
- FakeServerReader(const CallContext& context)
+ FakeServerReader(const LockedCallContext& context)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock())
: FakeServerReaderWriter(context, MethodType::kClientStreaming) {}
FakeServerReader(FakeServerReader&&) = default;
+ FakeServerReader& operator=(FakeServerReader&&) = default;
using FakeServerReaderWriter::active;
using FakeServerReaderWriter::as_server_call;
diff --git a/pw_rpc/pwpb/public/pw_rpc/pwpb/client_reader_writer.h b/pw_rpc/pwpb/public/pw_rpc/pwpb/client_reader_writer.h
index 6938bd4..ec59044 100644
--- a/pw_rpc/pwpb/public/pw_rpc/pwpb/client_reader_writer.h
+++ b/pw_rpc/pwpb/public/pw_rpc/pwpb/client_reader_writer.h
@@ -44,7 +44,8 @@
Function<void(Status)>&& on_error,
const Request&... request) {
rpc_lock().lock();
- CallType call(client, channel_id, service_id, method_id, serde);
+ CallType call(
+ client.ClaimLocked(), channel_id, service_id, method_id, serde);
call.set_on_completed_locked(std::move(on_completed));
call.set_on_error_locked(std::move(on_error));
@@ -67,12 +68,13 @@
// variable into which to move client reader/writers from RPC calls.
constexpr PwpbUnaryResponseClientCall() = default;
- PwpbUnaryResponseClientCall(internal::Endpoint& client,
+ PwpbUnaryResponseClientCall(internal::LockedEndpoint& client,
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id,
MethodType type,
const PwpbMethodSerde& serde)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: UnaryResponseClientCall(
client, channel_id, service_id, method_id, type),
serde_(&serde) {}
@@ -171,7 +173,8 @@
Function<void(Status)>&& on_error,
const Request&... request) {
rpc_lock().lock();
- CallType call(client, channel_id, service_id, method_id, serde);
+ CallType call(
+ client.ClaimLocked(), channel_id, service_id, method_id, serde);
call.set_on_next_locked(std::move(on_next));
call.set_on_completed_locked(std::move(on_completed));
@@ -194,12 +197,13 @@
// variable into which to move client reader/writers from RPC calls.
constexpr PwpbStreamResponseClientCall() = default;
- PwpbStreamResponseClientCall(internal::Endpoint& client,
+ PwpbStreamResponseClientCall(internal::LockedEndpoint& client,
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id,
MethodType type,
const PwpbMethodSerde& serde)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: StreamResponseClientCall(
client, channel_id, service_id, method_id, type),
serde_(&serde) {}
@@ -320,11 +324,12 @@
protected:
friend class internal::PwpbStreamResponseClientCall<Response>;
- PwpbClientReaderWriter(internal::Endpoint& client,
+ PwpbClientReaderWriter(internal::LockedEndpoint& client,
uint32_t channel_id_v,
uint32_t service_id,
uint32_t method_id,
const internal::PwpbMethodSerde& serde)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::PwpbStreamResponseClientCall<Response>(
client,
channel_id_v,
@@ -363,11 +368,12 @@
private:
friend class internal::PwpbStreamResponseClientCall<Response>;
- PwpbClientReader(internal::Endpoint& client,
+ PwpbClientReader(internal::LockedEndpoint& client,
uint32_t channel_id_v,
uint32_t service_id,
uint32_t method_id,
const internal::PwpbMethodSerde& serde)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::PwpbStreamResponseClientCall<Response>(
client,
channel_id_v,
@@ -419,11 +425,13 @@
private:
friend class internal::PwpbUnaryResponseClientCall<Response>;
- PwpbClientWriter(internal::Endpoint& client,
+ PwpbClientWriter(internal::LockedEndpoint& client,
uint32_t channel_id_v,
uint32_t service_id,
uint32_t method_id,
const internal::PwpbMethodSerde& serde)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
+
: internal::PwpbUnaryResponseClientCall<Response>(
client,
channel_id_v,
@@ -461,11 +469,12 @@
private:
friend class internal::PwpbUnaryResponseClientCall<Response>;
- PwpbUnaryReceiver(internal::Endpoint& client,
+ PwpbUnaryReceiver(internal::LockedEndpoint& client,
uint32_t channel_id_v,
uint32_t service_id,
uint32_t method_id,
const internal::PwpbMethodSerde& serde)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::PwpbUnaryResponseClientCall<Response>(client,
channel_id_v,
service_id,
diff --git a/pw_rpc/pwpb/public/pw_rpc/pwpb/internal/method.h b/pw_rpc/pwpb/public/pw_rpc/pwpb/internal/method.h
index 1a23b81..4e46358 100644
--- a/pw_rpc/pwpb/public/pw_rpc/pwpb/internal/method.h
+++ b/pw_rpc/pwpb/public/pw_rpc/pwpb/internal/method.h
@@ -251,7 +251,8 @@
return;
}
- internal::PwpbServerCall responder(context, MethodType::kUnary);
+ internal::PwpbServerCall responder(context.ClaimLocked(),
+ MethodType::kUnary);
rpc_lock().unlock();
const Status status = function_.synchronous_unary(
context.service(), &request_struct, &response_struct);
@@ -269,7 +270,7 @@
return;
}
- internal::PwpbServerCall server_writer(context, method_type);
+ internal::PwpbServerCall server_writer(context.ClaimLocked(), method_type);
rpc_lock().unlock();
function_.unary_request(context.service(), &request_struct, server_writer);
}
@@ -337,7 +338,7 @@
static void ClientStreamingInvoker(const CallContext& context, const Packet&)
PW_UNLOCK_FUNCTION(rpc_lock()) {
internal::BasePwpbServerReader<Request> reader(
- context, MethodType::kClientStreaming);
+ context.ClaimLocked(), MethodType::kClientStreaming);
rpc_lock().unlock();
static_cast<const PwpbMethod&>(context.method())
.function_.stream_request(context.service(), reader);
@@ -349,7 +350,7 @@
const Packet&)
PW_UNLOCK_FUNCTION(rpc_lock()) {
internal::BasePwpbServerReader<Request> reader_writer(
- context, MethodType::kBidirectionalStreaming);
+ context.ClaimLocked(), MethodType::kBidirectionalStreaming);
rpc_lock().unlock();
static_cast<const PwpbMethod&>(context.method())
.function_.stream_request(context.service(), reader_writer);
diff --git a/pw_rpc/pwpb/public/pw_rpc/pwpb/server_reader_writer.h b/pw_rpc/pwpb/public/pw_rpc/pwpb/server_reader_writer.h
index 47b4802..8a5ff1d 100644
--- a/pw_rpc/pwpb/public/pw_rpc/pwpb/server_reader_writer.h
+++ b/pw_rpc/pwpb/public/pw_rpc/pwpb/server_reader_writer.h
@@ -45,7 +45,8 @@
public:
// Allow construction using a call context and method type which creates
// a working server call.
- PwpbServerCall(const CallContext& context, MethodType type);
+ PwpbServerCall(const LockedCallContext& context, MethodType type)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock());
// Sends a unary response.
// Returns the following Status codes:
@@ -125,7 +126,8 @@
template <typename Request>
class BasePwpbServerReader : public PwpbServerCall {
public:
- BasePwpbServerReader(const CallContext& context, MethodType type)
+ BasePwpbServerReader(const LockedCallContext& context, MethodType type)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock())
: PwpbServerCall(context, type) {}
protected:
@@ -204,11 +206,14 @@
"The response type of a PwpbServerReaderWriter must match "
"the method.");
internal::LockGuard lock(internal::rpc_lock());
- return {server.OpenContext<kMethod, MethodType::kBidirectionalStreaming>(
- channel_id,
- service,
- internal::MethodLookup::GetPwpbMethod<ServiceImpl,
- MethodInfo::kMethodId>())};
+ return {
+ server
+ .OpenContext<kMethod, MethodType::kBidirectionalStreaming>(
+ channel_id,
+ service,
+ internal::MethodLookup::GetPwpbMethod<ServiceImpl,
+ MethodInfo::kMethodId>())
+ .ClaimLocked()};
}
// Allow default construction so that users can declare a variable into
@@ -248,8 +253,9 @@
friend class internal::PwpbMethod;
- PwpbServerReaderWriter(const internal::CallContext& context,
+ PwpbServerReaderWriter(const internal::LockedCallContext& context,
MethodType type = MethodType::kBidirectionalStreaming)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::BasePwpbServerReader<Request>(context, type) {}
};
@@ -276,11 +282,14 @@
"The response type of a PwpbServerReader must match "
"the method.");
internal::LockGuard lock(internal::rpc_lock());
- return {server.OpenContext<kMethod, MethodType::kClientStreaming>(
- channel_id,
- service,
- internal::MethodLookup::GetPwpbMethod<ServiceImpl,
- MethodInfo::kMethodId>())};
+ return {
+ server
+ .OpenContext<kMethod, MethodType::kClientStreaming>(
+ channel_id,
+ service,
+ internal::MethodLookup::GetPwpbMethod<ServiceImpl,
+ MethodInfo::kMethodId>())
+ .ClaimLocked()};
}
// Allow default construction so that users can declare a variable into
@@ -316,7 +325,8 @@
friend class internal::PwpbMethod;
- PwpbServerReader(const internal::CallContext& context)
+ PwpbServerReader(const internal::LockedCallContext& context)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::BasePwpbServerReader<Request>(context,
MethodType::kClientStreaming) {}
};
@@ -341,11 +351,14 @@
"The response type of a PwpbServerWriter must match "
"the method.");
internal::LockGuard lock(internal::rpc_lock());
- return {server.OpenContext<kMethod, MethodType::kServerStreaming>(
- channel_id,
- service,
- internal::MethodLookup::GetPwpbMethod<ServiceImpl,
- MethodInfo::kMethodId>())};
+ return {
+ server
+ .OpenContext<kMethod, MethodType::kServerStreaming>(
+ channel_id,
+ service,
+ internal::MethodLookup::GetPwpbMethod<ServiceImpl,
+ MethodInfo::kMethodId>())
+ .ClaimLocked()};
}
// Allow default construction so that users can declare a variable into
@@ -384,7 +397,8 @@
friend class internal::PwpbMethod;
- PwpbServerWriter(const internal::CallContext& context)
+ PwpbServerWriter(const internal::LockedCallContext& context)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::PwpbServerCall(context, MethodType::kServerStreaming) {}
};
@@ -408,11 +422,14 @@
"The response type of a PwpbUnaryResponder must match "
"the method.");
internal::LockGuard lock(internal::rpc_lock());
- return {server.OpenContext<kMethod, MethodType::kUnary>(
- channel_id,
- service,
- internal::MethodLookup::GetPwpbMethod<ServiceImpl,
- MethodInfo::kMethodId>())};
+ return {
+ server
+ .OpenContext<kMethod, MethodType::kUnary>(
+ channel_id,
+ service,
+ internal::MethodLookup::GetPwpbMethod<ServiceImpl,
+ MethodInfo::kMethodId>())
+ .ClaimLocked()};
}
// Allow default construction so that users can declare a variable into
@@ -447,7 +464,8 @@
friend class internal::PwpbMethod;
- PwpbUnaryResponder(const internal::CallContext& context)
+ PwpbUnaryResponder(const internal::LockedCallContext& context)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::PwpbServerCall(context, MethodType::kUnary) {}
};
diff --git a/pw_rpc/pwpb/server_reader_writer.cc b/pw_rpc/pwpb/server_reader_writer.cc
index 311a058..3469862 100644
--- a/pw_rpc/pwpb/server_reader_writer.cc
+++ b/pw_rpc/pwpb/server_reader_writer.cc
@@ -18,7 +18,8 @@
namespace pw::rpc::internal {
-PwpbServerCall::PwpbServerCall(const CallContext& context, MethodType type)
+PwpbServerCall::PwpbServerCall(const LockedCallContext& context,
+ MethodType type)
: internal::ServerCall(context, type),
serde_(&static_cast<const PwpbMethod&>(context.method()).serde()) {}
diff --git a/pw_rpc/raw/client_test.cc b/pw_rpc/raw/client_test.cc
index a814a3c..9dcf32d 100644
--- a/pw_rpc/raw/client_test.cc
+++ b/pw_rpc/raw/client_test.cc
@@ -43,8 +43,9 @@
namespace {
template <auto kMethod, typename Call, typename Context>
-Call MakeCall(Context& context) {
- return Call(context.client(),
+Call MakeCall(Context& context)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock()) {
+ return Call(context.client().ClaimLocked(),
context.channel().id(),
internal::MethodInfo<kMethod>::kServiceId,
internal::MethodInfo<kMethod>::kMethodId,
@@ -53,19 +54,20 @@
class TestStreamCall : public internal::StreamResponseClientCall {
public:
- TestStreamCall(Client& client,
+ TestStreamCall(internal::LockedEndpoint& client,
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id,
MethodType type)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: StreamResponseClientCall(
client, channel_id, service_id, method_id, type),
payload(nullptr) {
- set_on_next([this](ConstByteSpan string) {
+ set_on_next_locked([this](ConstByteSpan string) {
payload = reinterpret_cast<const char*>(string.data());
});
- set_on_completed([this](Status status) { completed = status; });
- set_on_error([this](Status status) { error = status; });
+ set_on_completed_locked([this](Status status) { completed = status; });
+ set_on_error_locked([this](Status status) { error = status; });
}
const char* payload;
@@ -75,19 +77,20 @@
class TestUnaryCall : public internal::UnaryResponseClientCall {
public:
- TestUnaryCall(Client& client,
+ TestUnaryCall(internal::LockedEndpoint& client,
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id,
MethodType type)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: UnaryResponseClientCall(
client, channel_id, service_id, method_id, type),
payload(nullptr) {
- set_on_completed([this](ConstByteSpan string, Status status) {
+ set_on_completed_locked([this](ConstByteSpan string, Status status) {
payload = reinterpret_cast<const char*>(string.data());
completed = status;
});
- set_on_error([this](Status status) { error = status; });
+ set_on_error_locked([this](Status status) { error = status; });
}
const char* payload;
@@ -97,8 +100,8 @@
TEST(Client, ProcessPacket_InvokesUnaryCallbacks) {
RawClientTestContext context;
- TestUnaryCall call = MakeCall<UnaryMethod, TestUnaryCall>(context);
internal::rpc_lock().lock();
+ TestUnaryCall call = MakeCall<UnaryMethod, TestUnaryCall>(context);
call.SendInitialClientRequest({});
ASSERT_NE(call.completed, OkStatus());
@@ -113,8 +116,8 @@
TEST(Client, ProcessPacket_InvokesStreamCallbacks) {
RawClientTestContext context;
- auto call = MakeCall<BidirectionalStreamMethod, TestStreamCall>(context);
internal::rpc_lock().lock();
+ auto call = MakeCall<BidirectionalStreamMethod, TestStreamCall>(context);
call.SendInitialClientRequest({});
context.server().SendServerStream<BidirectionalStreamMethod>(
@@ -130,8 +133,8 @@
TEST(Client, ProcessPacket_InvokesErrorCallback) {
RawClientTestContext context;
- auto call = MakeCall<BidirectionalStreamMethod, TestStreamCall>(context);
internal::rpc_lock().lock();
+ auto call = MakeCall<BidirectionalStreamMethod, TestStreamCall>(context);
call.SendInitialClientRequest({});
context.server().SendServerError<BidirectionalStreamMethod>(
@@ -198,8 +201,8 @@
TEST(Client, CloseChannel_CallsErrorCallback) {
RawClientTestContext ctx;
- TestUnaryCall call = MakeCall<UnaryMethod, TestUnaryCall>(ctx);
internal::rpc_lock().lock();
+ TestUnaryCall call = MakeCall<UnaryMethod, TestUnaryCall>(ctx);
call.SendInitialClientRequest({});
ASSERT_NE(call.completed, OkStatus());
diff --git a/pw_rpc/raw/method.cc b/pw_rpc/raw/method.cc
index 2b976f4..0bff460 100644
--- a/pw_rpc/raw/method.cc
+++ b/pw_rpc/raw/method.cc
@@ -23,7 +23,7 @@
void RawMethod::SynchronousUnaryInvoker(const CallContext& context,
const Packet& request) {
- RawUnaryResponder responder(context);
+ RawUnaryResponder responder(context.ClaimLocked());
rpc_lock().unlock();
// TODO(hepler): Remove support for raw synchronous unary methods. Unlike
// synchronous Nanopb methods, they provide little value compared to
@@ -42,7 +42,7 @@
void RawMethod::AsynchronousUnaryInvoker(const CallContext& context,
const Packet& request) {
- RawUnaryResponder responder(context);
+ RawUnaryResponder responder(context.ClaimLocked());
rpc_lock().unlock();
static_cast<const RawMethod&>(context.method())
.function_.asynchronous_unary(
@@ -51,7 +51,7 @@
void RawMethod::ServerStreamingInvoker(const CallContext& context,
const Packet& request) {
- RawServerWriter server_writer(context);
+ RawServerWriter server_writer(context.ClaimLocked());
rpc_lock().unlock();
static_cast<const RawMethod&>(context.method())
.function_.server_streaming(
@@ -60,7 +60,7 @@
void RawMethod::ClientStreamingInvoker(const CallContext& context,
const Packet&) {
- RawServerReader reader(context);
+ RawServerReader reader(context.ClaimLocked());
rpc_lock().unlock();
static_cast<const RawMethod&>(context.method())
.function_.stream_request(context.service(), reader);
@@ -68,7 +68,7 @@
void RawMethod::BidirectionalStreamingInvoker(const CallContext& context,
const Packet&) {
- RawServerReaderWriter reader_writer(context);
+ RawServerReaderWriter reader_writer(context.ClaimLocked());
rpc_lock().unlock();
static_cast<const RawMethod&>(context.method())
.function_.stream_request(context.service(), reader_writer);
diff --git a/pw_rpc/raw/public/pw_rpc/raw/client_reader_writer.h b/pw_rpc/raw/public/pw_rpc/raw/client_reader_writer.h
index 5e21a59..1dbae7d 100644
--- a/pw_rpc/raw/public/pw_rpc/raw/client_reader_writer.h
+++ b/pw_rpc/raw/public/pw_rpc/raw/client_reader_writer.h
@@ -59,11 +59,12 @@
protected:
friend class internal::StreamResponseClientCall;
- RawClientReaderWriter(internal::Endpoint& client,
+ RawClientReaderWriter(internal::LockedEndpoint& client,
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id,
MethodType type = MethodType::kBidirectionalStreaming)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: StreamResponseClientCall(
client, channel_id, service_id, method_id, type) {}
};
@@ -88,10 +89,11 @@
private:
friend class internal::StreamResponseClientCall;
- RawClientReader(internal::Endpoint& client,
+ RawClientReader(internal::LockedEndpoint& client,
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: StreamResponseClientCall(client,
channel_id,
service_id,
@@ -124,10 +126,11 @@
private:
friend class internal::UnaryResponseClientCall;
- RawClientWriter(internal::Endpoint& client,
+ RawClientWriter(internal::LockedEndpoint& client,
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: UnaryResponseClientCall(client,
channel_id,
service_id,
@@ -154,10 +157,11 @@
private:
friend class internal::UnaryResponseClientCall;
- RawUnaryReceiver(internal::Endpoint& client,
+ RawUnaryReceiver(internal::LockedEndpoint& client,
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: UnaryResponseClientCall(
client, channel_id, service_id, method_id, MethodType::kUnary) {}
};
diff --git a/pw_rpc/raw/public/pw_rpc/raw/server_reader_writer.h b/pw_rpc/raw/public/pw_rpc/raw/server_reader_writer.h
index 8ff018e..9312757 100644
--- a/pw_rpc/raw/public/pw_rpc/raw/server_reader_writer.h
+++ b/pw_rpc/raw/public/pw_rpc/raw/server_reader_writer.h
@@ -59,12 +59,14 @@
uint32_t channel_id,
ServiceImpl& service) {
internal::LockGuard lock(internal::rpc_lock());
- return {server.OpenContext<kMethod, MethodType::kBidirectionalStreaming>(
- channel_id,
- service,
- internal::MethodLookup::GetRawMethod<
- ServiceImpl,
- internal::MethodInfo<kMethod>::kMethodId>())};
+ return {server
+ .OpenContext<kMethod, MethodType::kBidirectionalStreaming>(
+ channel_id,
+ service,
+ internal::MethodLookup::GetRawMethod<
+ ServiceImpl,
+ internal::MethodInfo<kMethod>::kMethodId>())
+ .ClaimLocked()};
}
using internal::Call::active;
@@ -87,8 +89,9 @@
using internal::Call::operator const Writer&;
protected:
- RawServerReaderWriter(const internal::CallContext& context,
+ RawServerReaderWriter(const internal::LockedCallContext& context,
MethodType type = MethodType::kBidirectionalStreaming)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: internal::ServerCall(context, type) {}
using internal::Call::CloseAndSendResponse;
@@ -112,12 +115,14 @@
uint32_t channel_id,
ServiceImpl& service) {
internal::LockGuard lock(internal::rpc_lock());
- return {server.OpenContext<kMethod, MethodType::kClientStreaming>(
- channel_id,
- service,
- internal::MethodLookup::GetRawMethod<
- ServiceImpl,
- internal::MethodInfo<kMethod>::kMethodId>())};
+ return {server
+ .OpenContext<kMethod, MethodType::kClientStreaming>(
+ channel_id,
+ service,
+ internal::MethodLookup::GetRawMethod<
+ ServiceImpl,
+ internal::MethodInfo<kMethod>::kMethodId>())
+ .ClaimLocked()};
}
constexpr RawServerReader() = default;
@@ -142,7 +147,8 @@
template <typename, typename, uint32_t>
friend class internal::test::InvocationContext;
- RawServerReader(const internal::CallContext& context)
+ RawServerReader(const internal::LockedCallContext& context)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: RawServerReaderWriter(context, MethodType::kClientStreaming) {}
};
@@ -157,12 +163,14 @@
uint32_t channel_id,
ServiceImpl& service) {
internal::LockGuard lock(internal::rpc_lock());
- return {server.OpenContext<kMethod, MethodType::kServerStreaming>(
- channel_id,
- service,
- internal::MethodLookup::GetRawMethod<
- ServiceImpl,
- internal::MethodInfo<kMethod>::kMethodId>())};
+ return {server
+ .OpenContext<kMethod, MethodType::kServerStreaming>(
+ channel_id,
+ service,
+ internal::MethodLookup::GetRawMethod<
+ ServiceImpl,
+ internal::MethodInfo<kMethod>::kMethodId>())
+ .ClaimLocked()};
}
constexpr RawServerWriter() = default;
@@ -188,7 +196,8 @@
friend class internal::RawMethod;
- RawServerWriter(const internal::CallContext& context)
+ RawServerWriter(const internal::LockedCallContext& context)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: RawServerReaderWriter(context, MethodType::kServerStreaming) {}
};
@@ -203,12 +212,14 @@
uint32_t channel_id,
ServiceImpl& service) {
internal::LockGuard lock(internal::rpc_lock());
- return {server.OpenContext<kMethod, MethodType::kUnary>(
- channel_id,
- service,
- internal::MethodLookup::GetRawMethod<
- ServiceImpl,
- internal::MethodInfo<kMethod>::kMethodId>())};
+ return {server
+ .OpenContext<kMethod, MethodType::kUnary>(
+ channel_id,
+ service,
+ internal::MethodLookup::GetRawMethod<
+ ServiceImpl,
+ internal::MethodInfo<kMethod>::kMethodId>())
+ .ClaimLocked()};
}
constexpr RawUnaryResponder() = default;
@@ -231,7 +242,8 @@
friend class internal::RawMethod;
- RawUnaryResponder(const internal::CallContext& context)
+ RawUnaryResponder(const internal::LockedCallContext& context)
+ PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
: RawServerReaderWriter(context, MethodType::kUnary) {}
};
diff --git a/pw_rpc/server_test.cc b/pw_rpc/server_test.cc
index d69fa30..2814d40 100644
--- a/pw_rpc/server_test.cc
+++ b/pw_rpc/server_test.cc
@@ -367,12 +367,16 @@
class BidiMethod : public BasicServer {
protected:
- BidiMethod()
- : responder_(internal::CallContext(server_,
- channels_[0].id(),
- service_42_,
- service_42_.method(100),
- 0)) {
+ BidiMethod() {
+ internal::rpc_lock().lock();
+ internal::CallContext context(
+ server_, channels_[0].id(), service_42_, service_42_.method(100), 0);
+ // A local temporary is required since the constructor requires a lock,
+ // but the *move* constructor takes out the lock.
+ internal::test::FakeServerReaderWriter responder_temp(
+ context.ClaimLocked());
+ internal::rpc_lock().unlock();
+ responder_ = std::move(responder_temp);
PW_CHECK(responder_.active());
}
@@ -513,17 +517,16 @@
class ServerStreamingMethod : public BasicServer {
protected:
- ServerStreamingMethod()
- : call_(server_,
- channels_[0].id(),
- service_42_,
- service_42_.method(100),
- 0),
- responder_(call_) {
+ ServerStreamingMethod() {
+ internal::CallContext context(
+ server_, channels_[0].id(), service_42_, service_42_.method(100), 0);
+ internal::rpc_lock().lock();
+ internal::test::FakeServerWriter responder_temp(context.ClaimLocked());
+ internal::rpc_lock().unlock();
+ responder_ = std::move(responder_temp);
PW_CHECK(responder_.active());
}
- internal::CallContext call_;
internal::test::FakeServerWriter responder_;
};