pw_rpc: Easier callback_client RPC invocations

- Accept request proto arguments directly in callback-based RPC
  invocations through the request_args argument.
- Provide the request and response types directly on the method client
  as request and response.

Change-Id: I22df15c2aaf3390cf749ef8b5658be66461e2667
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/55502
Pigweed-Auto-Submit: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
Commit-Queue: Auto-Submit <auto-submit@pigweed.google.com.iam.gserviceaccount.com>
diff --git a/pw_rpc/py/pw_rpc/callback_client/__init__.py b/pw_rpc/py/pw_rpc/callback_client/__init__.py
index 5bb91ce..6e7b75a 100644
--- a/pw_rpc/py/pw_rpc/callback_client/__init__.py
+++ b/pw_rpc/py/pw_rpc/callback_client/__init__.py
@@ -44,7 +44,7 @@
 means to wait indefinitely for a response.
 
 Asynchronous invocations immediately return a call object. Callbacks may be
-provided for the three types of RPC events:
+provided for the three RPC events:
 
   * on_next(call_object, response) - called for each response
   * on_completed(call_object, status) - called when the RPC completes
@@ -52,15 +52,21 @@
 
 If no callbacks are provided, the events are simply logged.
 
+Unary and client streaming RPCs are invoked asynchronously by calling invoke on
+the method object. invoke takes the callbacks. The request may be provided
+either as a constructed protobuf or as a dict of proto fields in the
+request_args parameter.
+
 .. code-block:: python
 
-  # Custom callbacks are called for each event.
-  call = client.channel(1).call.MyService.MyServerStreaming.invoke(
-      the_request, on_next_cb, on_completed_cb, on_error_cb):
+  # Pass the request as a protobuf and provide callbacks for all RPC events.
+  rpc = client.channel(1).call.MyService.MyServerStreaming
+  call = rpc.invoke(rpc.request(a=1), on_next_cb, on_completed_cb, on_error_cb)
 
-  # A callback is called for the response. Other events are logged.
+  # Create the request from the provided keyword args. Provide a callback for
+  # responses, but simply log the other RPC events.
   call = client.channel(1).call.MyServer.MyUnary.invoke(
-      the_request, lambda _, reply: process_this(reply))
+      request_args=dict(a=1, b=2), on_next=lambda _, reply: process_this(reply))
 
 For client and bidirectional streaming RPCs, requests are sent with the send
 method. The finish_and_wait method finishes the client stream. It optionally
@@ -69,7 +75,7 @@
 .. code-block:: python
 
   # Start the call using callbacks.
-  call = client.channel(1).call.MyServer.MyClientStream.invoke(on_error=err_cb)
+  call = client.channel(1).rpcs.MyServer.MyClientStream.invoke(on_error=err_cb)
 
   # Send a single client stream request.
   call.send(some_field=123)
diff --git a/pw_rpc/py/pw_rpc/callback_client/impl.py b/pw_rpc/py/pw_rpc/callback_client/impl.py
index 268f993..fc0e44f 100644
--- a/pw_rpc/py/pw_rpc/callback_client/impl.py
+++ b/pw_rpc/py/pw_rpc/callback_client/impl.py
@@ -16,7 +16,7 @@
 import inspect
 import logging
 import textwrap
-from typing import Callable, Iterable, Optional, Type
+from typing import Any, Callable, Dict, Iterable, Optional, Type
 
 from pw_status import Status
 from google.protobuf.message import Message
@@ -66,6 +66,16 @@
     def service(self) -> Service:
         return self._rpc.service
 
+    @property
+    def request(self) -> type:
+        """Returns the request proto class."""
+        return self.method.request_type
+
+    @property
+    def response(self) -> type:
+        """Returns the response proto class."""
+        return self.method.response_type
+
     def __repr__(self) -> str:
         return self.help()
 
@@ -161,10 +171,12 @@
                on_completed: OnCompletedCallback = None,
                on_error: OnErrorCallback = None,
                *,
+               request_args: Dict[str, Any] = None,
                timeout_s: OptionalTimeout = UseDefault.VALUE) -> UnaryCall:
         """Invokes the unary RPC and returns a call object."""
-        return self._start_call(UnaryCall, request, timeout_s, on_next,
-                                on_completed, on_error)
+        return self._start_call(UnaryCall,
+                                self.method.get_request(request, request_args),
+                                timeout_s, on_next, on_completed, on_error)
 
 
 class _ServerStreamingMethodClient(_MethodClient):
@@ -175,11 +187,13 @@
             on_completed: OnCompletedCallback = None,
             on_error: OnErrorCallback = None,
             *,
+            request_args: Dict[str, Any] = None,
             timeout_s: OptionalTimeout = UseDefault.VALUE
     ) -> ServerStreamingCall:
         """Invokes the server streaming RPC and returns a call object."""
-        return self._start_call(ServerStreamingCall, request, timeout_s,
-                                on_next, on_completed, on_error)
+        return self._start_call(ServerStreamingCall,
+                                self.method.get_request(request, request_args),
+                                timeout_s, on_next, on_completed, on_error)
 
 
 class _ClientStreamingMethodClient(_MethodClient):
diff --git a/pw_rpc/py/pw_rpc/descriptors.py b/pw_rpc/py/pw_rpc/descriptors.py
index ab0c798..fdae732 100644
--- a/pw_rpc/py/pw_rpc/descriptors.py
+++ b/pw_rpc/py/pw_rpc/descriptors.py
@@ -214,13 +214,16 @@
         return self.Type.UNARY
 
     def get_request(self, proto: Optional[Message],
-                    proto_kwargs: Dict[str, Any]) -> Message:
+                    proto_kwargs: Optional[Dict[str, Any]]) -> Message:
         """Returns a request_type protobuf message.
 
         The client implementation may use this to support providing a request
         as either a message object or as keyword arguments for the message's
         fields (but not both).
         """
+        if proto_kwargs is None:
+            proto_kwargs = {}
+
         if proto and proto_kwargs:
             proto_str = repr(proto).strip() or "''"
             raise TypeError(
diff --git a/pw_rpc/py/tests/callback_client_test.py b/pw_rpc/py/tests/callback_client_test.py
index 7ee37ca..1345693 100755
--- a/pw_rpc/py/tests/callback_client_test.py
+++ b/pw_rpc/py/tests/callback_client_test.py
@@ -237,6 +237,14 @@
             rpcs.pw.test1.PublicService.SomeBidiStreaming.default_timeout_s,
             100)
 
+    def test_rpc_provides_request_type(self) -> None:
+        self.assertIs(self._service.SomeUnary.request,
+                      self._service.SomeUnary.method.request_type)
+
+    def test_rpc_provides_response_type(self) -> None:
+        self.assertIs(self._service.SomeUnary.request,
+                      self._service.SomeUnary.method.request_type)
+
 
 class UnaryTest(_CallbackClientImplTestBase):
     """Tests for invoking a unary RPC."""
@@ -306,6 +314,11 @@
 
         callback.assert_not_called()
 
+    def test_nonblocking_with_request_args(self) -> None:
+        self.rpc.invoke(request_args=dict(magic_number=1138))
+        self.assertEqual(
+            self._sent_payload(self.rpc.request).magic_number, 1138)
+
     def test_blocking_timeout_as_argument(self) -> None:
         with self.assertRaises(callback_client.RpcTimeout):
             self._service.SomeUnary(pw_rpc_timeout_s=0.0001)
@@ -419,6 +432,11 @@
             mock.call(call, Status.OK),
         ])
 
+    def test_nonblocking_with_request_args(self) -> None:
+        self.rpc.invoke(request_args=dict(magic_number=1138))
+        self.assertEqual(
+            self._sent_payload(self.rpc.request).magic_number, 1138)
+
     def test_blocking_timeout(self) -> None:
         with self.assertRaises(callback_client.RpcTimeout):
             self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)
diff --git a/pw_rpc/py/tests/descriptors_test.py b/pw_rpc/py/tests/descriptors_test.py
index 516d5e3..9b3bf93 100644
--- a/pw_rpc/py/tests/descriptors_test.py
+++ b/pw_rpc/py/tests/descriptors_test.py
@@ -1,4 +1,4 @@
-# Copyright 2020 The Pigweed Authors
+# Copyright 2021 The Pigweed Authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
 # use this file except in compliance with the License. You may obtain a copy of
@@ -62,6 +62,10 @@
             self._method.get_request(self._method.request_type(),
                                      {'magic_number': 1})
 
+    def test_get_request_neither_message_nor_kwargs(self):
+        self.assertEqual(self._method.request_type(),
+                         self._method.get_request(None, None))
+
     def test_get_request_with_wrong_type(self):
         with self.assertRaisesRegex(TypeError, r'pw\.test1\.SomeMessage'):
             self._method.get_request('a str!', {})