pw_rpc: Provide detailed RPC help in RPC client
Create detailed help strings for methods in pw_rpc.callback_client.
Change-Id: I257ee56c25970923c708d58a3315aeb100473197
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/26980
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
diff --git a/pw_rpc/py/pw_rpc/callback_client.py b/pw_rpc/py/pw_rpc/callback_client.py
index 6995ef7..732d424 100644
--- a/pw_rpc/py/pw_rpc/callback_client.py
+++ b/pw_rpc/py/pw_rpc/callback_client.py
@@ -43,11 +43,12 @@
import inspect
import logging
import queue
+import textwrap
from typing import Any, Callable, Optional, Tuple
from pw_status import Status
-from pw_rpc import client
+from pw_rpc import client, descriptors
from pw_rpc.descriptors import Channel, Method, Service
_LOG = logging.getLogger(__name__)
@@ -94,7 +95,28 @@
return _AsyncCall(self._rpcs, self._rpc)
def __repr__(self) -> str:
- return repr(self.method)
+ return self.help()
+
+ def __call__(self):
+ raise NotImplementedError('Implemented by derived classes')
+
+ def help(self) -> str:
+ """Returns a help message about this RPC."""
+ function_call = self.method.request_type.DESCRIPTOR.full_name + '('
+
+ docstring = inspect.getdoc(self.__call__)
+ assert docstring is not None
+
+ annotation = inspect.Signature.from_callable(self).return_annotation
+ if isinstance(annotation, type):
+ annotation = annotation.__name__
+
+ arg_sep = f',\n{" " * len(function_call)}'
+ return (
+ f'{function_call}'
+ f'{arg_sep.join(descriptors.field_help(self.method.request_type))})'
+ f'\n\n{textwrap.indent(docstring, " ")}\n\n'
+ f' Returns {annotation}.')
class _AsyncCall:
@@ -116,9 +138,10 @@
self.cancel()
-class _StreamingResponses:
+class StreamingResponses:
"""Used to iterate over a queue.SimpleQueue."""
- def __init__(self, responses: queue.SimpleQueue):
+ def __init__(self, method: Method, responses: queue.SimpleQueue):
+ self.method = method
self._queue = responses
self.status: Optional[Status] = None
@@ -133,6 +156,9 @@
def __iter__(self):
return self.get()
+ def __repr__(self) -> str:
+ return f'{type(self).__name__}({self.method})'
+
def _method_client_docstring(method: Method) -> str:
return f'''\
@@ -147,8 +173,8 @@
return f'''\
Invokes the {method.full_name} {method.type.sentence_name()} RPC.
-This function accepts either a request protobuf as a single positional argument
-or the request fields as keyword arguments.
+This function accepts either the request protobuf fields as keyword arguments or
+a request protobuf as a positional argument.
'''
@@ -197,12 +223,12 @@
"""Creates an object used to call a server streaming method."""
def call(self,
_rpc_request_proto=None,
- **request_fields) -> _StreamingResponses:
+ **request_fields) -> StreamingResponses:
responses: queue.SimpleQueue = queue.SimpleQueue()
self.reinvoke(
lambda _, status, payload: responses.put((status, payload)),
_rpc_request_proto, **request_fields)
- return _StreamingResponses(responses)
+ return StreamingResponses(method, responses)
_update_function_signature(method, call)
diff --git a/pw_rpc/py/pw_rpc/descriptors.py b/pw_rpc/py/pw_rpc/descriptors.py
index 7ac7af1..79a1943 100644
--- a/pw_rpc/py/pw_rpc/descriptors.py
+++ b/pw_rpc/py/pw_rpc/descriptors.py
@@ -104,13 +104,32 @@
def _field_type_annotation(field: FieldDescriptor):
"""Creates a field type annotation to use in the help message only."""
- annotation = _PROTO_FIELD_TYPES.get(field.type, Parameter.empty)
+ if field.type == FieldDescriptor.TYPE_MESSAGE:
+ annotation = message_factory.MessageFactory(
+ field.message_type.file.pool).GetPrototype(field.message_type)
+ else:
+ annotation = _PROTO_FIELD_TYPES.get(field.type, Parameter.empty)
+
if field.label == FieldDescriptor.LABEL_REPEATED:
return Iterable[annotation] # type: ignore[valid-type]
return annotation
+def field_help(proto_message) -> Iterator[str]:
+ """Yields argument strings for proto fields for use in a help message."""
+ for field in proto_message.DESCRIPTOR.fields:
+ if field.type == FieldDescriptor.TYPE_ENUM:
+ value = field.enum_type.values_by_number[field.default_value].name
+ type_name = field.enum_type.full_name
+ value = f'{type_name.rsplit(".", 1)[0]}.{value}'
+ else:
+ type_name = _PROTO_FIELD_TYPES[field.type].__name__
+ value = repr(field.default_value)
+
+ yield f'{field.name}: {type_name} = {value}'
+
+
@dataclass(frozen=True, eq=False)
class Method:
"""Describes a method in a service."""