pw_rpc: Python client usability improvements
- Update __repr__ to make it easier to navigate RPCs interactively.
- Make proto package members attributes instead of relying on
__getattr__ so tab completion works.
Change-Id: Ief573a139203f51b420f56e67c6c36ddb1b262d6
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/18323
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Keir Mierle <keir@google.com>
diff --git a/pw_protobuf_compiler/py/pw_protobuf_compiler/python_protos.py b/pw_protobuf_compiler/py/pw_protobuf_compiler/python_protos.py
index 5f22802..4b7f75a 100644
--- a/pw_protobuf_compiler/py/pw_protobuf_compiler/python_protos.py
+++ b/pw_protobuf_compiler/py/pw_protobuf_compiler/python_protos.py
@@ -148,14 +148,21 @@
self._items: List[T] = []
self._package = package
- def __getattr__(self, attr: str):
- """Descends into subpackages or access proto entities in a package."""
- if attr in self._packages:
- return self._packages[attr]
+ def _add_package(self, subpackage: str, package: '_NestedPackage') -> None:
+ self._packages[subpackage] = package
+ setattr(self, subpackage, package)
- for module in self._items:
- if hasattr(module, attr):
- return getattr(module, attr)
+ def _add_item(self, item) -> None:
+ self._items.append(item)
+ for attr, value in vars(item).items():
+ if not attr.startswith('_'):
+ setattr(self, attr, value)
+
+ def __getattr__(self, attr: str):
+ # Fall back to item attributes, which includes private attributes.
+ for item in self._items:
+ if hasattr(item, attr):
+ return getattr(item, attr)
raise AttributeError(
f'Proto package "{self._package}" does not contain "{attr}"')
@@ -164,7 +171,19 @@
return iter(self._packages.values())
def __repr__(self) -> str:
- return f'_NestedPackage({self._package!r})'
+ msg = [f'ProtoPackage({self._package!r}']
+
+ public_members = [
+ i for i in vars(self)
+ if i not in self._packages and not i.startswith('_')
+ ]
+ if public_members:
+ msg.append(f'members={str(public_members)}')
+
+ if self._packages:
+ msg.append(f'subpackages={str(list(self._packages))}')
+
+ return ', '.join(msg) + ')'
def __str__(self) -> str:
return self._package
@@ -196,12 +215,12 @@
# pylint: disable=protected-access
for i, subpackage in enumerate(subpackages, 1):
if subpackage not in entry._packages:
- entry._packages[subpackage] = _NestedPackage('.'.join(
- subpackages[:i]))
+ entry._add_package(subpackage,
+ _NestedPackage('.'.join(subpackages[:i])))
entry = entry._packages[subpackage]
- entry._items.append(item)
+ entry._add_item(item)
# pylint: enable=protected-access
return packages
diff --git a/pw_rpc/py/client_test.py b/pw_rpc/py/client_test.py
index 90dc402..cb78992 100755
--- a/pw_rpc/py/client_test.py
+++ b/pw_rpc/py/client_test.py
@@ -114,7 +114,7 @@
'Unary')
self.assertEqual(
self._channel_client.rpcs.pw.test2.Alpha.Unary.method.full_name,
- 'pw.test2.Alpha/Unary')
+ 'pw.test2.Alpha.Unary')
def test_iterate_over_all_methods(self):
channel_client = self._channel_client
diff --git a/pw_rpc/py/pw_rpc/__init__.py b/pw_rpc/py/pw_rpc/__init__.py
index e69de29..d7a70c7 100644
--- a/pw_rpc/py/pw_rpc/__init__.py
+++ b/pw_rpc/py/pw_rpc/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Package for calling Pigweed RPCs from Python."""
+
+from pw_rpc.client import Client
diff --git a/pw_rpc/py/pw_rpc/callback_client.py b/pw_rpc/py/pw_rpc/callback_client.py
index da98a42..33e7d51 100644
--- a/pw_rpc/py/pw_rpc/callback_client.py
+++ b/pw_rpc/py/pw_rpc/callback_client.py
@@ -88,6 +88,9 @@
override_pending=True)
return _AsyncCall(self._rpcs, self._rpc)
+ def __repr__(self) -> str:
+ return repr(self.method)
+
class _AsyncCall:
"""Represents an ongoing callback-based call."""
@@ -129,7 +132,7 @@
class UnaryMethodClient(_MethodClient):
def __call__(self, _request=None, **request_fields) -> Tuple[Status, Any]:
responses: queue.SimpleQueue = queue.SimpleQueue()
- self.invoke(
+ self.reinvoke(
lambda _, status, payload: responses.put((status, payload)),
_request, **request_fields)
return responses.get()
@@ -138,7 +141,7 @@
class ServerStreamingMethodClient(_MethodClient):
def __call__(self, _request=None, **request_fields) -> _StreamingResponses:
responses: queue.SimpleQueue = queue.SimpleQueue()
- self.invoke(
+ self.reinvoke(
lambda _, status, payload: responses.put((status, payload)),
_request, **request_fields)
return _StreamingResponses(responses)
diff --git a/pw_rpc/py/pw_rpc/client.py b/pw_rpc/py/pw_rpc/client.py
index e009398..a7c2920 100644
--- a/pw_rpc/py/pw_rpc/client.py
+++ b/pw_rpc/py/pw_rpc/client.py
@@ -139,14 +139,25 @@
class _MethodClients(descriptors.ServiceAccessor):
"""Navigates the methods in a service provided by a ChannelClient."""
def __init__(self, rpcs: PendingRpcs, client_impl: ClientImpl,
- channel: Channel, methods: Collection[Method]):
+ channel: Channel, service: Service):
super().__init__(
{
method: client_impl.method_client(rpcs, channel, method)
- for method in methods
+ for method in service.methods
},
as_attrs='members')
+ self._channel = channel
+ self._service = service
+
+ def __repr__(self) -> str:
+ return (f'Service({self._service.full_name!r}, '
+ f'methods={[m.name for m in self._service.methods]}, '
+ f'channel={self._channel.id})')
+
+ def __str__(self) -> str:
+ return str(self._service)
+
class _ServiceClients(descriptors.ServiceAccessor[_MethodClients]):
"""Navigates the services provided by a ChannelClient."""
@@ -154,11 +165,18 @@
services: Collection[Service]):
super().__init__(
{
- s: _MethodClients(rpcs, client_impl, channel, s.methods)
+ s: _MethodClients(rpcs, client_impl, channel, s)
for s in services
},
as_attrs='packages')
+ self._channel = channel
+ self._services = services
+
+ def __repr__(self) -> str:
+ return (f'Services(channel={self._channel.id}, '
+ f'services={[s.full_name for s in self._services]})')
+
def _decode_status(rpc: PendingRpc, packet) -> Optional[Status]:
# Server streaming RPC packets never have a status; all other packets do.
@@ -227,11 +245,18 @@
"""
return descriptors.get_method(self.rpcs, method_name)
+ def services(self) -> Iterator:
+ return iter(self.rpcs)
+
def methods(self) -> Iterator:
"""Iterates over all method clients in this ChannelClient."""
for service_client in self.rpcs:
yield from service_client
+ def __repr__(self) -> str:
+ return (f'ChannelClient(channel={self.channel.id}, '
+ f'services={[str(s) for s in self.services()]})')
+
class Client:
"""Sends requests and handles responses for a set of channels.
@@ -370,3 +395,7 @@
f'No method ID {packet.method_id} in service {service.name}')
return PendingRpc(channel_client.channel, service, method)
+
+ def __repr__(self) -> str:
+ return (f'pw_rpc.Client(channels={list(self._channels_by_id)}, '
+ f'services={[s.full_name for s in self.services]})')
diff --git a/pw_rpc/py/pw_rpc/descriptors.py b/pw_rpc/py/pw_rpc/descriptors.py
index 1713960..ce94b32 100644
--- a/pw_rpc/py/pw_rpc/descriptors.py
+++ b/pw_rpc/py/pw_rpc/descriptors.py
@@ -107,7 +107,7 @@
@property
def full_name(self) -> str:
- return f'{self.service.full_name}/{self.name}'
+ return f'{self.service.full_name}.{self.name}'
@property
def type(self) -> 'Method.Type':
@@ -151,7 +151,18 @@
return proto
def __repr__(self) -> str:
- return f'Method({self.name!r})'
+ req = self._method_parameter(self.request_type, self.client_streaming)
+ res = self._method_parameter(self.response_type, self.server_streaming)
+ return f'<{self.full_name}({req}) returns ({res})>'
+
+ def _method_parameter(self, proto, streaming: bool) -> str:
+ """Returns a description of the method's request or response type."""
+ stream = 'stream ' if streaming else ''
+
+ if proto.DESCRIPTOR.file.package == self.service.package:
+ return stream + proto.DESCRIPTOR.name
+
+ return stream + proto.DESCRIPTOR.full_name
def __str__(self) -> str:
return self.full_name