pw_rpc: Check Nanopb request/response types

Ensure that Nanopb method implementations use the correct request and
response types. Using the wrong struct type will be a compilation error.

Change-Id: I9d300a108d76d014cb7699a72b7526b5b3e2d82e
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/33701
Commit-Queue: Auto-Submit <auto-submit@pigweed.google.com.iam.gserviceaccount.com>
Pigweed-Auto-Submit: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
diff --git a/pw_rpc/nanopb/nanopb_method_test.cc b/pw_rpc/nanopb/nanopb_method_test.cc
index a9055bb..ca2653c 100644
--- a/pw_rpc/nanopb/nanopb_method_test.cc
+++ b/pw_rpc/nanopb/nanopb_method_test.cc
@@ -61,6 +61,64 @@
                                               ServerWriter<FakePb>&) {}
 };
 
+// Test that the matches() function matches valid signatures.
+static_assert(NanopbMethod::template matches<&TestNanopbService::Unary,
+                                             FakePb,
+                                             FakePb>());
+static_assert(
+    NanopbMethod::template matches<&TestNanopbService::ServerStreaming,
+                                   FakePb,
+                                   FakePb>());
+static_assert(NanopbMethod::template matches<&TestNanopbService::StaticUnary,
+                                             FakePb,
+                                             FakePb>());
+static_assert(
+    NanopbMethod::template matches<&TestNanopbService::StaticServerStreaming,
+                                   FakePb,
+                                   FakePb>());
+
+// Test that the matches() function does not match the wrong method type.
+static_assert(!NanopbMethod::template matches<&TestNanopbService::UnaryWrongArg,
+                                              FakePb,
+                                              FakePb>());
+static_assert(
+    !NanopbMethod::template matches<&TestNanopbService::StaticUnaryVoidReturn,
+                                    FakePb,
+                                    FakePb>());
+static_assert(!NanopbMethod::template matches<
+              &TestNanopbService::ServerStreamingBadReturn,
+              FakePb,
+              FakePb>());
+static_assert(!NanopbMethod::template matches<
+              &TestNanopbService::StaticServerStreamingMissingArg,
+              FakePb,
+              FakePb>());
+
+struct WrongPb;
+
+// Test matches() rejects incorrect request/response types.
+static_assert(!NanopbMethod::template matches<&TestNanopbService::Unary,
+                                              WrongPb,
+                                              FakePb>());
+static_assert(!NanopbMethod::template matches<&TestNanopbService::Unary,
+                                              FakePb,
+                                              WrongPb>());
+static_assert(!NanopbMethod::template matches<&TestNanopbService::Unary,
+                                              WrongPb,
+                                              WrongPb>());
+static_assert(
+    !NanopbMethod::template matches<&TestNanopbService::ServerStreaming,
+                                    WrongPb,
+                                    FakePb>());
+
+static_assert(!NanopbMethod::template matches<&TestNanopbService::StaticUnary,
+                                              FakePb,
+                                              WrongPb>());
+static_assert(
+    !NanopbMethod::template matches<&TestNanopbService::StaticServerStreaming,
+                                    FakePb,
+                                    WrongPb>());
+
 TEST(MethodImplTester, NanopbMethod) {
   constexpr MethodImplTester<NanopbMethod, TestNanopbService, nullptr, nullptr>
       method_tester;
diff --git a/pw_rpc/nanopb/nanopb_method_union_test.cc b/pw_rpc/nanopb/nanopb_method_union_test.cc
index ab754e4..d1bfd7f 100644
--- a/pw_rpc/nanopb/nanopb_method_union_test.cc
+++ b/pw_rpc/nanopb/nanopb_method_union_test.cc
@@ -32,15 +32,25 @@
   constexpr FakeGeneratedService(uint32_t id) : Service(id, kMethods) {}
 
   static constexpr std::array<NanopbMethodUnion, 4> kMethods = {
-      GetNanopbOrRawMethodFor<&Implementation::DoNothing, MethodType::kUnary>(
+      GetNanopbOrRawMethodFor<&Implementation::DoNothing,
+                              MethodType::kUnary,
+                              pw_rpc_test_Empty,
+                              pw_rpc_test_Empty>(
           10u, pw_rpc_test_Empty_fields, pw_rpc_test_Empty_fields),
       GetNanopbOrRawMethodFor<&Implementation::RawStream,
-                              MethodType::kServerStreaming>(
+                              MethodType::kServerStreaming,
+                              pw_rpc_test_TestRequest,
+                              pw_rpc_test_TestResponse>(
           11u, pw_rpc_test_TestRequest_fields, pw_rpc_test_TestResponse_fields),
-      GetNanopbOrRawMethodFor<&Implementation::AddFive, MethodType::kUnary>(
+      GetNanopbOrRawMethodFor<&Implementation::AddFive,
+                              MethodType::kUnary,
+                              pw_rpc_test_TestRequest,
+                              pw_rpc_test_TestResponse>(
           12u, pw_rpc_test_TestRequest_fields, pw_rpc_test_TestResponse_fields),
       GetNanopbOrRawMethodFor<&Implementation::StartStream,
-                              MethodType::kServerStreaming>(
+                              MethodType::kServerStreaming,
+                              pw_rpc_test_TestRequest,
+                              pw_rpc_test_TestResponse>(
           13u, pw_rpc_test_TestRequest_fields, pw_rpc_test_TestResponse_fields),
   };
 };
diff --git a/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method.h b/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method.h
index 2a20600..dc4d45b 100644
--- a/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method.h
+++ b/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method.h
@@ -119,9 +119,11 @@
 // structs.
 class NanopbMethod : public Method {
  public:
-  template <auto method>
+  template <auto method, typename RequestType, typename ResponseType>
   static constexpr bool matches() {
-    return std::is_same_v<MethodImplementation<method>, NanopbMethod>;
+    return std::is_same_v<MethodImplementation<method>, NanopbMethod> &&
+           std::is_same_v<RequestType, Request<method>> &&
+           std::is_same_v<ResponseType, Response<method>>;
   }
 
   // Creates a NanopbMethod for a unary RPC.
diff --git a/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method_union.h b/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method_union.h
index d73813b..da4ec30 100644
--- a/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method_union.h
+++ b/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method_union.h
@@ -42,14 +42,14 @@
 
 // Returns either a raw or nanopb method object, depending on the implemented
 // function's signature.
-template <auto method, MethodType type>
+template <auto method, MethodType type, typename Request, typename Response>
 constexpr auto GetNanopbOrRawMethodFor(
     uint32_t id,
     [[maybe_unused]] NanopbMessageDescriptor request_fields,
     [[maybe_unused]] NanopbMessageDescriptor response_fields) {
   if constexpr (RawMethod::matches<method>()) {
     return GetMethodFor<method, RawMethod, type>(id);
-  } else if constexpr (NanopbMethod::matches<method>()) {
+  } else if constexpr (NanopbMethod::matches<method, Request, Response>()) {
     return GetMethodFor<method, NanopbMethod, type>(
         id, request_fields, response_fields);
   } else {
diff --git a/pw_rpc/public/pw_rpc/internal/method.h b/pw_rpc/public/pw_rpc/internal/method.h
index aae02df..fec137c 100644
--- a/pw_rpc/public/pw_rpc/internal/method.h
+++ b/pw_rpc/public/pw_rpc/internal/method.h
@@ -91,9 +91,14 @@
   // Specializations must set kType to the MethodType.
   // static constexpr MethodType kType = (method type);
 
-  // Specializations for member function types must set Service as an alias to
-  // the implemented service class.
+  // Specializations for member function types must set Service to an alias to
+  // for the implemented service class.
   using Service = rpc::Service;
+
+  // Specializations may provide the C++ types of the requests and responses if
+  // relevant.
+  using Request = void;
+  using Response = void;
 };
 
 template <auto method>
diff --git a/pw_rpc/pw_rpc_private/method_impl_tester.h b/pw_rpc/pw_rpc_private/method_impl_tester.h
index 13ef880..8140d0b 100644
--- a/pw_rpc/pw_rpc_private/method_impl_tester.h
+++ b/pw_rpc/pw_rpc_private/method_impl_tester.h
@@ -45,22 +45,6 @@
 //
 template <typename MethodImpl, typename TestService, auto... extra_method_args>
 struct MethodImplTester {
-  // Test that the matches() function matches valid signatures.
-  static_assert(MethodImpl::template matches<&TestService::Unary>());
-  static_assert(MethodImpl::template matches<&TestService::ServerStreaming>());
-  static_assert(MethodImpl::template matches<&TestService::StaticUnary>());
-  static_assert(
-      MethodImpl::template matches<&TestService::StaticServerStreaming>());
-
-  // Test that the matches() function doesn't match invalid signatures.
-  static_assert(!MethodImpl::template matches<&TestService::UnaryWrongArg>());
-  static_assert(
-      !MethodImpl::template matches<&TestService::StaticUnaryVoidReturn>());
-  static_assert(
-      !MethodImpl::template matches<&TestService::ServerStreamingBadReturn>());
-  static_assert(!MethodImpl::template matches<
-                &TestService::StaticServerStreamingMissingArg>());
-
   // Test the MethodTraits::kType member.
   static_assert(MethodTraits<decltype(&TestService::Unary)>::kType ==
                 MethodType::kUnary);
diff --git a/pw_rpc/py/pw_rpc/codegen_nanopb.py b/pw_rpc/py/pw_rpc/codegen_nanopb.py
index c888365..4127062 100644
--- a/pw_rpc/py/pw_rpc/codegen_nanopb.py
+++ b/pw_rpc/py/pw_rpc/codegen_nanopb.py
@@ -49,7 +49,9 @@
 
     output.write_line(
         f'{RPC_NAMESPACE}::internal::GetNanopbOrRawMethodFor<{impl_method}, '
-        f'{method.type().cc_enum()}>(')
+        f'{method.type().cc_enum()}, '
+        f'{method.request_type().nanopb_name()}, '
+        f'{method.response_type().nanopb_name()}>(')
     with output.indent(4):
         output.write_line(f'0x{method_id:08x},  // Hash of "{method.name()}"')
         output.write_line(f'{req_fields},')
diff --git a/pw_rpc/raw/raw_method_test.cc b/pw_rpc/raw/raw_method_test.cc
index f3f0d4e..c8c9c44 100644
--- a/pw_rpc/raw/raw_method_test.cc
+++ b/pw_rpc/raw/raw_method_test.cc
@@ -63,6 +63,22 @@
   }
 };
 
+// Test that the matches() function matches valid signatures.
+static_assert(RawMethod::template matches<&TestRawService::Unary>());
+static_assert(RawMethod::template matches<&TestRawService::ServerStreaming>());
+static_assert(RawMethod::template matches<&TestRawService::StaticUnary>());
+static_assert(
+    RawMethod::template matches<&TestRawService::StaticServerStreaming>());
+
+// Test that the matches() function does not match the wrong method type.
+static_assert(!RawMethod::template matches<&TestRawService::UnaryWrongArg>());
+static_assert(
+    !RawMethod::template matches<&TestRawService::StaticUnaryVoidReturn>());
+static_assert(
+    !RawMethod::template matches<&TestRawService::ServerStreamingBadReturn>());
+static_assert(!RawMethod::template matches<
+              &TestRawService::StaticServerStreamingMissingArg>());
+
 TEST(MethodImplTester, RawMethod) {
   constexpr MethodImplTester<RawMethod, TestRawService> method_tester;
   EXPECT_TRUE(method_tester.MethodImplIsValid());