pw_rpc: Check Nanopb request/response types
Ensure that Nanopb method implementations use the correct request and
response types. Using the wrong struct type will be a compilation error.
Change-Id: I9d300a108d76d014cb7699a72b7526b5b3e2d82e
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/33701
Commit-Queue: Auto-Submit <auto-submit@pigweed.google.com.iam.gserviceaccount.com>
Pigweed-Auto-Submit: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
diff --git a/pw_rpc/nanopb/nanopb_method_test.cc b/pw_rpc/nanopb/nanopb_method_test.cc
index a9055bb..ca2653c 100644
--- a/pw_rpc/nanopb/nanopb_method_test.cc
+++ b/pw_rpc/nanopb/nanopb_method_test.cc
@@ -61,6 +61,64 @@
ServerWriter<FakePb>&) {}
};
+// Test that the matches() function matches valid signatures.
+static_assert(NanopbMethod::template matches<&TestNanopbService::Unary,
+ FakePb,
+ FakePb>());
+static_assert(
+ NanopbMethod::template matches<&TestNanopbService::ServerStreaming,
+ FakePb,
+ FakePb>());
+static_assert(NanopbMethod::template matches<&TestNanopbService::StaticUnary,
+ FakePb,
+ FakePb>());
+static_assert(
+ NanopbMethod::template matches<&TestNanopbService::StaticServerStreaming,
+ FakePb,
+ FakePb>());
+
+// Test that the matches() function does not match the wrong method type.
+static_assert(!NanopbMethod::template matches<&TestNanopbService::UnaryWrongArg,
+ FakePb,
+ FakePb>());
+static_assert(
+ !NanopbMethod::template matches<&TestNanopbService::StaticUnaryVoidReturn,
+ FakePb,
+ FakePb>());
+static_assert(!NanopbMethod::template matches<
+ &TestNanopbService::ServerStreamingBadReturn,
+ FakePb,
+ FakePb>());
+static_assert(!NanopbMethod::template matches<
+ &TestNanopbService::StaticServerStreamingMissingArg,
+ FakePb,
+ FakePb>());
+
+struct WrongPb;
+
+// Test matches() rejects incorrect request/response types.
+static_assert(!NanopbMethod::template matches<&TestNanopbService::Unary,
+ WrongPb,
+ FakePb>());
+static_assert(!NanopbMethod::template matches<&TestNanopbService::Unary,
+ FakePb,
+ WrongPb>());
+static_assert(!NanopbMethod::template matches<&TestNanopbService::Unary,
+ WrongPb,
+ WrongPb>());
+static_assert(
+ !NanopbMethod::template matches<&TestNanopbService::ServerStreaming,
+ WrongPb,
+ FakePb>());
+
+static_assert(!NanopbMethod::template matches<&TestNanopbService::StaticUnary,
+ FakePb,
+ WrongPb>());
+static_assert(
+ !NanopbMethod::template matches<&TestNanopbService::StaticServerStreaming,
+ FakePb,
+ WrongPb>());
+
TEST(MethodImplTester, NanopbMethod) {
constexpr MethodImplTester<NanopbMethod, TestNanopbService, nullptr, nullptr>
method_tester;
diff --git a/pw_rpc/nanopb/nanopb_method_union_test.cc b/pw_rpc/nanopb/nanopb_method_union_test.cc
index ab754e4..d1bfd7f 100644
--- a/pw_rpc/nanopb/nanopb_method_union_test.cc
+++ b/pw_rpc/nanopb/nanopb_method_union_test.cc
@@ -32,15 +32,25 @@
constexpr FakeGeneratedService(uint32_t id) : Service(id, kMethods) {}
static constexpr std::array<NanopbMethodUnion, 4> kMethods = {
- GetNanopbOrRawMethodFor<&Implementation::DoNothing, MethodType::kUnary>(
+ GetNanopbOrRawMethodFor<&Implementation::DoNothing,
+ MethodType::kUnary,
+ pw_rpc_test_Empty,
+ pw_rpc_test_Empty>(
10u, pw_rpc_test_Empty_fields, pw_rpc_test_Empty_fields),
GetNanopbOrRawMethodFor<&Implementation::RawStream,
- MethodType::kServerStreaming>(
+ MethodType::kServerStreaming,
+ pw_rpc_test_TestRequest,
+ pw_rpc_test_TestResponse>(
11u, pw_rpc_test_TestRequest_fields, pw_rpc_test_TestResponse_fields),
- GetNanopbOrRawMethodFor<&Implementation::AddFive, MethodType::kUnary>(
+ GetNanopbOrRawMethodFor<&Implementation::AddFive,
+ MethodType::kUnary,
+ pw_rpc_test_TestRequest,
+ pw_rpc_test_TestResponse>(
12u, pw_rpc_test_TestRequest_fields, pw_rpc_test_TestResponse_fields),
GetNanopbOrRawMethodFor<&Implementation::StartStream,
- MethodType::kServerStreaming>(
+ MethodType::kServerStreaming,
+ pw_rpc_test_TestRequest,
+ pw_rpc_test_TestResponse>(
13u, pw_rpc_test_TestRequest_fields, pw_rpc_test_TestResponse_fields),
};
};
diff --git a/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method.h b/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method.h
index 2a20600..dc4d45b 100644
--- a/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method.h
+++ b/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method.h
@@ -119,9 +119,11 @@
// structs.
class NanopbMethod : public Method {
public:
- template <auto method>
+ template <auto method, typename RequestType, typename ResponseType>
static constexpr bool matches() {
- return std::is_same_v<MethodImplementation<method>, NanopbMethod>;
+ return std::is_same_v<MethodImplementation<method>, NanopbMethod> &&
+ std::is_same_v<RequestType, Request<method>> &&
+ std::is_same_v<ResponseType, Response<method>>;
}
// Creates a NanopbMethod for a unary RPC.
diff --git a/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method_union.h b/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method_union.h
index d73813b..da4ec30 100644
--- a/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method_union.h
+++ b/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method_union.h
@@ -42,14 +42,14 @@
// Returns either a raw or nanopb method object, depending on the implemented
// function's signature.
-template <auto method, MethodType type>
+template <auto method, MethodType type, typename Request, typename Response>
constexpr auto GetNanopbOrRawMethodFor(
uint32_t id,
[[maybe_unused]] NanopbMessageDescriptor request_fields,
[[maybe_unused]] NanopbMessageDescriptor response_fields) {
if constexpr (RawMethod::matches<method>()) {
return GetMethodFor<method, RawMethod, type>(id);
- } else if constexpr (NanopbMethod::matches<method>()) {
+ } else if constexpr (NanopbMethod::matches<method, Request, Response>()) {
return GetMethodFor<method, NanopbMethod, type>(
id, request_fields, response_fields);
} else {
diff --git a/pw_rpc/public/pw_rpc/internal/method.h b/pw_rpc/public/pw_rpc/internal/method.h
index aae02df..fec137c 100644
--- a/pw_rpc/public/pw_rpc/internal/method.h
+++ b/pw_rpc/public/pw_rpc/internal/method.h
@@ -91,9 +91,14 @@
// Specializations must set kType to the MethodType.
// static constexpr MethodType kType = (method type);
- // Specializations for member function types must set Service as an alias to
- // the implemented service class.
+ // Specializations for member function types must set Service to an alias to
+ // for the implemented service class.
using Service = rpc::Service;
+
+ // Specializations may provide the C++ types of the requests and responses if
+ // relevant.
+ using Request = void;
+ using Response = void;
};
template <auto method>
diff --git a/pw_rpc/pw_rpc_private/method_impl_tester.h b/pw_rpc/pw_rpc_private/method_impl_tester.h
index 13ef880..8140d0b 100644
--- a/pw_rpc/pw_rpc_private/method_impl_tester.h
+++ b/pw_rpc/pw_rpc_private/method_impl_tester.h
@@ -45,22 +45,6 @@
//
template <typename MethodImpl, typename TestService, auto... extra_method_args>
struct MethodImplTester {
- // Test that the matches() function matches valid signatures.
- static_assert(MethodImpl::template matches<&TestService::Unary>());
- static_assert(MethodImpl::template matches<&TestService::ServerStreaming>());
- static_assert(MethodImpl::template matches<&TestService::StaticUnary>());
- static_assert(
- MethodImpl::template matches<&TestService::StaticServerStreaming>());
-
- // Test that the matches() function doesn't match invalid signatures.
- static_assert(!MethodImpl::template matches<&TestService::UnaryWrongArg>());
- static_assert(
- !MethodImpl::template matches<&TestService::StaticUnaryVoidReturn>());
- static_assert(
- !MethodImpl::template matches<&TestService::ServerStreamingBadReturn>());
- static_assert(!MethodImpl::template matches<
- &TestService::StaticServerStreamingMissingArg>());
-
// Test the MethodTraits::kType member.
static_assert(MethodTraits<decltype(&TestService::Unary)>::kType ==
MethodType::kUnary);
diff --git a/pw_rpc/py/pw_rpc/codegen_nanopb.py b/pw_rpc/py/pw_rpc/codegen_nanopb.py
index c888365..4127062 100644
--- a/pw_rpc/py/pw_rpc/codegen_nanopb.py
+++ b/pw_rpc/py/pw_rpc/codegen_nanopb.py
@@ -49,7 +49,9 @@
output.write_line(
f'{RPC_NAMESPACE}::internal::GetNanopbOrRawMethodFor<{impl_method}, '
- f'{method.type().cc_enum()}>(')
+ f'{method.type().cc_enum()}, '
+ f'{method.request_type().nanopb_name()}, '
+ f'{method.response_type().nanopb_name()}>(')
with output.indent(4):
output.write_line(f'0x{method_id:08x}, // Hash of "{method.name()}"')
output.write_line(f'{req_fields},')
diff --git a/pw_rpc/raw/raw_method_test.cc b/pw_rpc/raw/raw_method_test.cc
index f3f0d4e..c8c9c44 100644
--- a/pw_rpc/raw/raw_method_test.cc
+++ b/pw_rpc/raw/raw_method_test.cc
@@ -63,6 +63,22 @@
}
};
+// Test that the matches() function matches valid signatures.
+static_assert(RawMethod::template matches<&TestRawService::Unary>());
+static_assert(RawMethod::template matches<&TestRawService::ServerStreaming>());
+static_assert(RawMethod::template matches<&TestRawService::StaticUnary>());
+static_assert(
+ RawMethod::template matches<&TestRawService::StaticServerStreaming>());
+
+// Test that the matches() function does not match the wrong method type.
+static_assert(!RawMethod::template matches<&TestRawService::UnaryWrongArg>());
+static_assert(
+ !RawMethod::template matches<&TestRawService::StaticUnaryVoidReturn>());
+static_assert(
+ !RawMethod::template matches<&TestRawService::ServerStreamingBadReturn>());
+static_assert(!RawMethod::template matches<
+ &TestRawService::StaticServerStreamingMissingArg>());
+
TEST(MethodImplTester, RawMethod) {
constexpr MethodImplTester<RawMethod, TestRawService> method_tester;
EXPECT_TRUE(method_tester.MethodImplIsValid());