pw_rpc: Easier callback_client RPC invocations
- Accept request proto arguments directly in callback-based RPC
invocations through the request_args argument.
- Provide the request and response types directly on the method client
as request and response.
Change-Id: I22df15c2aaf3390cf749ef8b5658be66461e2667
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/55502
Pigweed-Auto-Submit: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
Commit-Queue: Auto-Submit <auto-submit@pigweed.google.com.iam.gserviceaccount.com>
diff --git a/pw_rpc/py/pw_rpc/callback_client/__init__.py b/pw_rpc/py/pw_rpc/callback_client/__init__.py
index 5bb91ce..6e7b75a 100644
--- a/pw_rpc/py/pw_rpc/callback_client/__init__.py
+++ b/pw_rpc/py/pw_rpc/callback_client/__init__.py
@@ -44,7 +44,7 @@
means to wait indefinitely for a response.
Asynchronous invocations immediately return a call object. Callbacks may be
-provided for the three types of RPC events:
+provided for the three RPC events:
* on_next(call_object, response) - called for each response
* on_completed(call_object, status) - called when the RPC completes
@@ -52,15 +52,21 @@
If no callbacks are provided, the events are simply logged.
+Unary and client streaming RPCs are invoked asynchronously by calling invoke on
+the method object. invoke takes the callbacks. The request may be provided
+either as a constructed protobuf or as a dict of proto fields in the
+request_args parameter.
+
.. code-block:: python
- # Custom callbacks are called for each event.
- call = client.channel(1).call.MyService.MyServerStreaming.invoke(
- the_request, on_next_cb, on_completed_cb, on_error_cb):
+ # Pass the request as a protobuf and provide callbacks for all RPC events.
+ rpc = client.channel(1).call.MyService.MyServerStreaming
+ call = rpc.invoke(rpc.request(a=1), on_next_cb, on_completed_cb, on_error_cb)
- # A callback is called for the response. Other events are logged.
+ # Create the request from the provided keyword args. Provide a callback for
+ # responses, but simply log the other RPC events.
call = client.channel(1).call.MyServer.MyUnary.invoke(
- the_request, lambda _, reply: process_this(reply))
+ request_args=dict(a=1, b=2), on_next=lambda _, reply: process_this(reply))
For client and bidirectional streaming RPCs, requests are sent with the send
method. The finish_and_wait method finishes the client stream. It optionally
@@ -69,7 +75,7 @@
.. code-block:: python
# Start the call using callbacks.
- call = client.channel(1).call.MyServer.MyClientStream.invoke(on_error=err_cb)
+ call = client.channel(1).rpcs.MyServer.MyClientStream.invoke(on_error=err_cb)
# Send a single client stream request.
call.send(some_field=123)
diff --git a/pw_rpc/py/pw_rpc/callback_client/impl.py b/pw_rpc/py/pw_rpc/callback_client/impl.py
index 268f993..fc0e44f 100644
--- a/pw_rpc/py/pw_rpc/callback_client/impl.py
+++ b/pw_rpc/py/pw_rpc/callback_client/impl.py
@@ -16,7 +16,7 @@
import inspect
import logging
import textwrap
-from typing import Callable, Iterable, Optional, Type
+from typing import Any, Callable, Dict, Iterable, Optional, Type
from pw_status import Status
from google.protobuf.message import Message
@@ -66,6 +66,16 @@
def service(self) -> Service:
return self._rpc.service
+ @property
+ def request(self) -> type:
+ """Returns the request proto class."""
+ return self.method.request_type
+
+ @property
+ def response(self) -> type:
+ """Returns the response proto class."""
+ return self.method.response_type
+
def __repr__(self) -> str:
return self.help()
@@ -161,10 +171,12 @@
on_completed: OnCompletedCallback = None,
on_error: OnErrorCallback = None,
*,
+ request_args: Dict[str, Any] = None,
timeout_s: OptionalTimeout = UseDefault.VALUE) -> UnaryCall:
"""Invokes the unary RPC and returns a call object."""
- return self._start_call(UnaryCall, request, timeout_s, on_next,
- on_completed, on_error)
+ return self._start_call(UnaryCall,
+ self.method.get_request(request, request_args),
+ timeout_s, on_next, on_completed, on_error)
class _ServerStreamingMethodClient(_MethodClient):
@@ -175,11 +187,13 @@
on_completed: OnCompletedCallback = None,
on_error: OnErrorCallback = None,
*,
+ request_args: Dict[str, Any] = None,
timeout_s: OptionalTimeout = UseDefault.VALUE
) -> ServerStreamingCall:
"""Invokes the server streaming RPC and returns a call object."""
- return self._start_call(ServerStreamingCall, request, timeout_s,
- on_next, on_completed, on_error)
+ return self._start_call(ServerStreamingCall,
+ self.method.get_request(request, request_args),
+ timeout_s, on_next, on_completed, on_error)
class _ClientStreamingMethodClient(_MethodClient):
diff --git a/pw_rpc/py/pw_rpc/descriptors.py b/pw_rpc/py/pw_rpc/descriptors.py
index ab0c798..fdae732 100644
--- a/pw_rpc/py/pw_rpc/descriptors.py
+++ b/pw_rpc/py/pw_rpc/descriptors.py
@@ -214,13 +214,16 @@
return self.Type.UNARY
def get_request(self, proto: Optional[Message],
- proto_kwargs: Dict[str, Any]) -> Message:
+ proto_kwargs: Optional[Dict[str, Any]]) -> Message:
"""Returns a request_type protobuf message.
The client implementation may use this to support providing a request
as either a message object or as keyword arguments for the message's
fields (but not both).
"""
+ if proto_kwargs is None:
+ proto_kwargs = {}
+
if proto and proto_kwargs:
proto_str = repr(proto).strip() or "''"
raise TypeError(
diff --git a/pw_rpc/py/tests/callback_client_test.py b/pw_rpc/py/tests/callback_client_test.py
index 7ee37ca..1345693 100755
--- a/pw_rpc/py/tests/callback_client_test.py
+++ b/pw_rpc/py/tests/callback_client_test.py
@@ -237,6 +237,14 @@
rpcs.pw.test1.PublicService.SomeBidiStreaming.default_timeout_s,
100)
+ def test_rpc_provides_request_type(self) -> None:
+ self.assertIs(self._service.SomeUnary.request,
+ self._service.SomeUnary.method.request_type)
+
+ def test_rpc_provides_response_type(self) -> None:
+ self.assertIs(self._service.SomeUnary.request,
+ self._service.SomeUnary.method.request_type)
+
class UnaryTest(_CallbackClientImplTestBase):
"""Tests for invoking a unary RPC."""
@@ -306,6 +314,11 @@
callback.assert_not_called()
+ def test_nonblocking_with_request_args(self) -> None:
+ self.rpc.invoke(request_args=dict(magic_number=1138))
+ self.assertEqual(
+ self._sent_payload(self.rpc.request).magic_number, 1138)
+
def test_blocking_timeout_as_argument(self) -> None:
with self.assertRaises(callback_client.RpcTimeout):
self._service.SomeUnary(pw_rpc_timeout_s=0.0001)
@@ -419,6 +432,11 @@
mock.call(call, Status.OK),
])
+ def test_nonblocking_with_request_args(self) -> None:
+ self.rpc.invoke(request_args=dict(magic_number=1138))
+ self.assertEqual(
+ self._sent_payload(self.rpc.request).magic_number, 1138)
+
def test_blocking_timeout(self) -> None:
with self.assertRaises(callback_client.RpcTimeout):
self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)
diff --git a/pw_rpc/py/tests/descriptors_test.py b/pw_rpc/py/tests/descriptors_test.py
index 516d5e3..9b3bf93 100644
--- a/pw_rpc/py/tests/descriptors_test.py
+++ b/pw_rpc/py/tests/descriptors_test.py
@@ -1,4 +1,4 @@
-# Copyright 2020 The Pigweed Authors
+# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
@@ -62,6 +62,10 @@
self._method.get_request(self._method.request_type(),
{'magic_number': 1})
+ def test_get_request_neither_message_nor_kwargs(self):
+ self.assertEqual(self._method.request_type(),
+ self._method.get_request(None, None))
+
def test_get_request_with_wrong_type(self):
with self.assertRaisesRegex(TypeError, r'pw\.test1\.SomeMessage'):
self._method.get_request('a str!', {})