pw_rpc: Support testing services with mixed protobuf libraries

- Replace the public internal method lookup generated functions with a
  private array of method IDs. Create a MethodLookup class that is
  friended by generated services to access this array.
- Use the generated method ID array to support method lookup in services
  with multiple different protobuf implementations (Nanopb and raw).
- Ensure that the correct method implementation is selected when looking
  up a method at compile time. This prevents accidentally accessing a
  Nanopb method as a raw method or vice versa in a test method context.

Change-Id: I0ea334c02ae7c29e50099aa5bcfa975df625d15e
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/25740
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Keir Mierle <keir@google.com>
diff --git a/pw_rpc/BUILD b/pw_rpc/BUILD
index fe34d2d..cdd1084 100644
--- a/pw_rpc/BUILD
+++ b/pw_rpc/BUILD
@@ -45,6 +45,7 @@
         "public/pw_rpc/internal/call.h",
         "public/pw_rpc/internal/hash.h",
         "public/pw_rpc/internal/method.h",
+        "public/pw_rpc/internal/method_lookup.h",
         "public/pw_rpc/internal/method_union.h",
         "public/pw_rpc/internal/server.h",
         "server.cc",
@@ -103,6 +104,7 @@
     srcs = [
         "nanopb/codegen_test.cc",
         "nanopb/echo_service_test.cc",
+        "nanopb/method_lookup_test.cc",
         "nanopb/nanopb_client_call.cc",
         "nanopb/nanopb_client_call_test.cc",
         "nanopb/nanopb_common.cc",
@@ -116,9 +118,6 @@
         "nanopb/public/pw_rpc/nanopb_client_call.h",
         "nanopb/public/pw_rpc/nanopb_test_method_context.h",
         "nanopb/pw_rpc_nanopb_private/internal_test_utils.h",
-        "nanopb/test.pb.c",
-        "nanopb/test.pb.h",
-        "nanopb/test_rpc.pb.h",
     ],
 )
 
diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn
index 1dee6ca..4b6d0ea 100644
--- a/pw_rpc/BUILD.gn
+++ b/pw_rpc/BUILD.gn
@@ -41,6 +41,7 @@
     "public/pw_rpc/internal/call.h",
     "public/pw_rpc/internal/hash.h",
     "public/pw_rpc/internal/method.h",
+    "public/pw_rpc/internal/method_lookup.h",
     "public/pw_rpc/internal/method_union.h",
     "public/pw_rpc/internal/server.h",
     "server.cc",
diff --git a/pw_rpc/nanopb/BUILD.gn b/pw_rpc/nanopb/BUILD.gn
index b3ee46a..baf2805 100644
--- a/pw_rpc/nanopb/BUILD.gn
+++ b/pw_rpc/nanopb/BUILD.gn
@@ -98,6 +98,7 @@
     ":client_call_test",
     ":codegen_test",
     ":echo_service_test",
+    ":method_lookup_test",
     ":nanopb_method_test",
     ":nanopb_method_union_test",
   ]
@@ -139,6 +140,18 @@
   enable_if = dir_pw_third_party_nanopb != ""
 }
 
+pw_test("method_lookup_test") {
+  deps = [
+    ":method",
+    ":test_method_context",
+    "..:test_protos.nanopb_rpc",
+    "..:test_utils",
+    "../raw:test_method_context",
+  ]
+  sources = [ "method_lookup_test.cc" ]
+  enable_if = dir_pw_third_party_nanopb != ""
+}
+
 pw_test("nanopb_method_union_test") {
   deps = [
     ":internal_test_utils",
diff --git a/pw_rpc/nanopb/method_lookup_test.cc b/pw_rpc/nanopb/method_lookup_test.cc
new file mode 100644
index 0000000..02137c8
--- /dev/null
+++ b/pw_rpc/nanopb/method_lookup_test.cc
@@ -0,0 +1,81 @@
+// Copyright 2020 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+
+#include "gtest/gtest.h"
+#include "pw_rpc/nanopb_test_method_context.h"
+#include "pw_rpc/raw_test_method_context.h"
+#include "pw_rpc_test_protos/test.rpc.pb.h"
+
+namespace pw::rpc {
+namespace {
+
+class MixedService1 : public test::generated::TestService<MixedService1> {
+ public:
+  StatusWithSize TestRpc(ServerContext&, ConstByteSpan, ByteSpan) {
+    return StatusWithSize(123);
+  }
+
+  void TestStreamRpc(ServerContext&,
+                     const pw_rpc_test_TestRequest&,
+                     ServerWriter<pw_rpc_test_TestStreamResponse>&) {
+    called_streaming_method = true;
+  }
+
+  bool called_streaming_method = false;
+};
+
+class MixedService2 : public test::generated::TestService<MixedService2> {
+ public:
+  Status TestRpc(ServerContext&,
+                 const pw_rpc_test_TestRequest&,
+                 pw_rpc_test_TestResponse&) {
+    return Status::Unauthenticated();
+  }
+
+  void TestStreamRpc(ServerContext&, ConstByteSpan, RawServerWriter&) {
+    called_streaming_method = true;
+  }
+
+  bool called_streaming_method = false;
+};
+
+TEST(MixedService1, CallRawMethod) {
+  PW_RAW_TEST_METHOD_CONTEXT(MixedService1, TestRpc) context;
+  StatusWithSize sws = context.call({});
+  EXPECT_TRUE(sws.ok());
+  EXPECT_EQ(123u, sws.size());
+}
+
+TEST(MixedService1, CallNanopbMethod) {
+  PW_NANOPB_TEST_METHOD_CONTEXT(MixedService1, TestStreamRpc) context;
+  ASSERT_FALSE(context.service().called_streaming_method);
+  context.call({});
+  EXPECT_TRUE(context.service().called_streaming_method);
+}
+
+TEST(MixedService2, CallNanopbMethod) {
+  PW_NANOPB_TEST_METHOD_CONTEXT(MixedService2, TestRpc) context;
+  Status status = context.call({});
+  EXPECT_EQ(Status::Unauthenticated(), status);
+}
+
+TEST(MixedService2, CallRawMethod) {
+  PW_RAW_TEST_METHOD_CONTEXT(MixedService2, TestStreamRpc) context;
+  ASSERT_FALSE(context.service().called_streaming_method);
+  context.call({});
+  EXPECT_TRUE(context.service().called_streaming_method);
+}
+
+}  // namespace
+}  // namespace pw::rpc
diff --git a/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h b/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h
index 0ee76fe..08572e6 100644
--- a/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h
+++ b/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h
@@ -21,6 +21,7 @@
 #include "pw_preprocessor/arguments.h"
 #include "pw_rpc/channel.h"
 #include "pw_rpc/internal/hash.h"
+#include "pw_rpc/internal/method_lookup.h"
 #include "pw_rpc/internal/nanopb_method.h"
 #include "pw_rpc/internal/packet.h"
 #include "pw_rpc/internal/server.h"
@@ -121,15 +122,6 @@
   Status last_status_;
 };
 
-template <typename Service, uint32_t method_id>
-constexpr const internal::NanopbMethod& GetNanopbMethod() {
-  constexpr const internal::NanopbMethod* nanopb_method =
-      GeneratedService<Service>::NanopbMethodFor(method_id);
-  static_assert(nanopb_method != nullptr,
-                "The selected function is not an RPC service method");
-  return *nanopb_method;
-}
-
 // Collects everything needed to invoke a particular RPC.
 template <typename Service,
           auto method,
@@ -142,14 +134,16 @@
 
   template <typename... Args>
   InvocationContext(Args&&... args)
-      : output(GetNanopbMethod<Service, method_id>(), responses, buffer),
+      : output(MethodLookup::GetNanopbMethod<Service, method_id>(),
+               responses,
+               buffer),
         channel(Channel::Create<123>(&output)),
         server(std::span(&channel, 1)),
         service(std::forward<Args>(args)...),
         call(static_cast<internal::Server&>(server),
              static_cast<internal::Channel&>(channel),
              service,
-             GetNanopbMethod<Service, method_id>()) {}
+             MethodLookup::GetNanopbMethod<Service, method_id>()) {}
 
   MessageOutput<Response> output;
 
@@ -176,6 +170,8 @@
   template <typename... Args>
   UnaryContext(Args&&... args) : ctx_(std::forward<Args>(args)...) {}
 
+  Service& service() { return ctx_.service; }
+
   // Invokes the RPC with the provided request. Returns the status.
   Status call(const Request& request) {
     ctx_.output.clear();
@@ -210,6 +206,8 @@
   template <typename... Args>
   ServerStreamingContext(Args&&... args) : ctx_(std::forward<Args>(args)...) {}
 
+  Service& service() { return ctx_.service; }
+
   // Invokes the RPC with the provided request.
   void call(const Request& request) {
     ctx_.output.clear();
diff --git a/pw_rpc/public/pw_rpc/internal/method_lookup.h b/pw_rpc/public/pw_rpc/internal/method_lookup.h
new file mode 100644
index 0000000..cd7ddcb
--- /dev/null
+++ b/pw_rpc/public/pw_rpc/internal/method_lookup.h
@@ -0,0 +1,64 @@
+// Copyright 2020 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+#pragma once
+
+#include "pw_rpc/internal/method.h"
+
+namespace pw::rpc::internal {
+
+// Gets a Method object from a generated RPC service class. Getter functions are
+// provided for each supported method implementation. The
+//
+// To ensure the MethodUnion actually holds the requested method type, the
+// method ID is accessed in a static_assert. It is invalid to access an unset
+// union member in a constant expression, so this results in a compiler error.
+class MethodLookup {
+ public:
+  template <typename Service, uint32_t method_id>
+  static constexpr const auto& GetRawMethod() {
+    const auto& method = GetMethodUnion<Service, method_id>().raw_method();
+    static_assert(method.id() == method_id, "Incorrect method implementation");
+    return method;
+  }
+
+  template <typename Service, uint32_t method_id>
+  static constexpr const auto& GetNanopbMethod() {
+    const auto& method = GetMethodUnion<Service, method_id>().nanopb_method();
+    static_assert(method.id() == method_id, "Incorrect method implementation");
+    return method;
+  }
+
+ private:
+  template <typename Service, uint32_t method_id>
+  static constexpr const auto& GetMethodUnion() {
+    constexpr auto method = GetMethodUnionPointer<Service>(method_id);
+    static_assert(method != nullptr,
+                  "The selected function is not an RPC service method");
+    return *method;
+  }
+
+  template <typename Service>
+  static constexpr
+      typename decltype(GeneratedService<Service>::kMethods)::const_pointer
+      GetMethodUnionPointer(uint32_t method_id) {
+    for (size_t i = 0; i < GeneratedService<Service>::kMethodIds.size(); ++i) {
+      if (GeneratedService<Service>::kMethodIds[i] == method_id) {
+        return &GeneratedService<Service>::kMethods[i];
+      }
+    }
+    return nullptr;
+  }
+};
+
+}  // namespace pw::rpc::internal
diff --git a/pw_rpc/py/BUILD.gn b/pw_rpc/py/BUILD.gn
index 56ca19d..03f7b29 100644
--- a/pw_rpc/py/BUILD.gn
+++ b/pw_rpc/py/BUILD.gn
@@ -22,6 +22,7 @@
     "pw_rpc/__init__.py",
     "pw_rpc/callback_client.py",
     "pw_rpc/client.py",
+    "pw_rpc/codegen.py",
     "pw_rpc/codegen_nanopb.py",
     "pw_rpc/codegen_raw.py",
     "pw_rpc/descriptors.py",
diff --git a/pw_rpc/py/pw_rpc/codegen.py b/pw_rpc/py/pw_rpc/codegen.py
new file mode 100644
index 0000000..7dd5451
--- /dev/null
+++ b/pw_rpc/py/pw_rpc/codegen.py
@@ -0,0 +1,33 @@
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Common RPC codegen utilities."""
+
+from pw_protobuf.output_file import OutputFile
+from pw_protobuf.proto_tree import ProtoService
+
+import pw_rpc.ids
+
+
+def method_lookup_table(service: ProtoService, output: OutputFile) -> None:
+    """Generates array of method IDs for looking up methods at compile time."""
+    output.write_line('static constexpr std::array<uint32_t, '
+                      f'{len(service.methods())}> kMethodIds = {{')
+
+    with output.indent(4):
+        for method in service.methods():
+            method_id = pw_rpc.ids.calculate(method.name())
+            output.write_line(
+                f'0x{method_id:08x},  // Hash of {method.name()}')
+
+    output.write_line('};\n')
diff --git a/pw_rpc/py/pw_rpc/codegen_nanopb.py b/pw_rpc/py/pw_rpc/codegen_nanopb.py
index 0e33498..2d0bd7d 100644
--- a/pw_rpc/py/pw_rpc/codegen_nanopb.py
+++ b/pw_rpc/py/pw_rpc/codegen_nanopb.py
@@ -21,6 +21,7 @@
 from pw_protobuf.proto_tree import ProtoNode, ProtoService, ProtoServiceMethod
 from pw_protobuf.proto_tree import build_node_tree
 import pw_rpc.ids
+from pw_rpc import codegen
 
 PLUGIN_NAME = 'pw_rpc_codegen'
 PLUGIN_VERSION = '0.1.0'
@@ -61,29 +62,6 @@
         output.write_line(f'{res_fields}),')
 
 
-def _generate_method_lookup_function(output: OutputFile):
-    """Generates a function that gets a Method object from its ID."""
-    nanopb_method = f'{RPC_NAMESPACE}::internal::NanopbMethod'
-
-    output.write_line(
-        f'static constexpr const {nanopb_method}* NanopbMethodFor(')
-    output.write_line('    uint32_t id) {')
-
-    with output.indent():
-        output.write_line('for (auto& method : kMethods) {')
-        with output.indent():
-            output.write_line('if (method.nanopb_method().id() == id) {')
-            output.write_line(
-                f'  return &static_cast<const {nanopb_method}&>(')
-            output.write_line('    method.nanopb_method());')
-            output.write_line('}')
-        output.write_line('}')
-
-        output.write_line('return nullptr;')
-
-    output.write_line('}')
-
-
 def _generate_code_for_service(service: ProtoService, root: ProtoNode,
                                output: OutputFile) -> None:
     """Generates a C++ derived class for a nanopb RPC service."""
@@ -123,13 +101,11 @@
         output.write_line(
             'constexpr void _PwRpcInternalGeneratedBase() const {}')
 
-        output.write_line()
-        _generate_method_lookup_function(output)
-
     service_name_hash = pw_rpc.ids.calculate(service.proto_path())
     output.write_line('\n private:')
 
     with output.indent():
+        output.write_line('friend class ::pw::rpc::internal::MethodLookup;\n')
         output.write_line(f'// Hash of "{service.proto_path()}".')
         output.write_line(
             f'static constexpr uint32_t kServiceId = 0x{service_name_hash:08x};'
@@ -146,7 +122,10 @@
             for method in service.methods():
                 _generate_method_descriptor(method, output)
 
-        output.write_line('};')
+        output.write_line('};\n')
+
+        # Generate the method lookup table
+        codegen.method_lookup_table(service, output)
 
     output.write_line('};')
 
@@ -242,6 +221,7 @@
     output.write_line('#include <cstddef>')
     output.write_line('#include <cstdint>')
     output.write_line('#include <type_traits>\n')
+    output.write_line('#include "pw_rpc/internal/method_lookup.h"')
     output.write_line('#include "pw_rpc/internal/nanopb_method_union.h"')
     output.write_line('#include "pw_rpc/nanopb_client_call.h"')
     output.write_line('#include "pw_rpc/server_context.h"')
diff --git a/pw_rpc/py/pw_rpc/codegen_raw.py b/pw_rpc/py/pw_rpc/codegen_raw.py
index 6a271cf..9b25bfe 100644
--- a/pw_rpc/py/pw_rpc/codegen_raw.py
+++ b/pw_rpc/py/pw_rpc/codegen_raw.py
@@ -21,6 +21,7 @@
 from pw_protobuf.proto_tree import ProtoNode, ProtoService, ProtoServiceMethod
 from pw_protobuf.proto_tree import build_node_tree
 import pw_rpc.ids
+from pw_rpc import codegen
 
 PLUGIN_NAME = 'pw_rpc_codegen'
 PLUGIN_VERSION = '0.1.0'
@@ -49,27 +50,6 @@
     output.write_line(f'    0x{method_id:08x}),  // Hash of "{method.name()}"')
 
 
-def _generate_method_lookup_function(output: OutputFile):
-    """Generates a function that gets a Method object from its ID."""
-    raw_method = f'{RPC_NAMESPACE}::internal::RawMethod'
-
-    output.write_line(f'static constexpr const {raw_method}* RawMethodFor(')
-    output.write_line('    uint32_t id) {')
-
-    with output.indent():
-        output.write_line('for (auto& method : kMethods) {')
-        with output.indent():
-            output.write_line('if (method.raw_method().id() == id) {')
-            output.write_line(f'  return &static_cast<const {raw_method}&>(')
-            output.write_line('    method.raw_method());')
-            output.write_line('}')
-        output.write_line('}')
-
-        output.write_line('return nullptr;')
-
-    output.write_line('}')
-
-
 def _generate_code_for_service(service: ProtoService, root: ProtoNode,
                                output: OutputFile) -> None:
     """Generates a C++ base class for a raw RPC service."""
@@ -106,13 +86,11 @@
         output.write_line(
             'constexpr void _PwRpcInternalGeneratedBase() const {}')
 
-        output.write_line()
-        _generate_method_lookup_function(output)
-
     service_name_hash = pw_rpc.ids.calculate(service.proto_path())
     output.write_line('\n private:')
 
     with output.indent():
+        output.write_line('friend class ::pw::rpc::internal::MethodLookup;\n')
         output.write_line(f'// Hash of "{service.proto_path()}".')
         output.write_line(
             f'static constexpr uint32_t kServiceId = 0x{service_name_hash:08x};'
@@ -131,6 +109,9 @@
 
         output.write_line('};')
 
+        # Generate the method lookup table
+        codegen.method_lookup_table(service, output)
+
     output.write_line('};')
 
 
@@ -147,6 +128,7 @@
     output.write_line('#include <cstddef>')
     output.write_line('#include <cstdint>')
     output.write_line('#include <type_traits>\n')
+    output.write_line('#include "pw_rpc/internal/method_lookup.h"')
     output.write_line('#include "pw_rpc/internal/raw_method_union.h"')
     output.write_line('#include "pw_rpc/server_context.h"')
     output.write_line('#include "pw_rpc/service.h"\n')
diff --git a/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h b/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h
index 371e0f5..5e13a95 100644
--- a/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h
+++ b/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h
@@ -20,6 +20,7 @@
 #include "pw_containers/vector.h"
 #include "pw_rpc/channel.h"
 #include "pw_rpc/internal/hash.h"
+#include "pw_rpc/internal/method_lookup.h"
 #include "pw_rpc/internal/packet.h"
 #include "pw_rpc/internal/raw_method.h"
 #include "pw_rpc/internal/server.h"
@@ -131,15 +132,6 @@
   Status last_status_;
 };
 
-template <typename Service, uint32_t method_id>
-constexpr const internal::RawMethod& GetRawMethod() {
-  constexpr const internal::RawMethod* raw_method =
-      GeneratedService<Service>::RawMethodFor(method_id);
-  static_assert(raw_method != nullptr,
-                "The selected function is not an RPC service method");
-  return *raw_method;
-}
-
 // Collects everything needed to invoke a particular RPC.
 template <typename Service,
           uint32_t method_id,
@@ -155,7 +147,7 @@
         call(static_cast<internal::Server&>(server),
              static_cast<internal::Channel&>(channel),
              service,
-             GetRawMethod<Service, method_id>()) {}
+             MethodLookup::GetRawMethod<Service, method_id>()) {}
 
   using ResponseBuffer = std::array<std::byte, output_size>;