pw_rpc: Python client improvements
- Add separate methods for response, completion, and error to the
ClientImpl interface. This allows explicit handling of errors by the
client.
- In the callback_client impl, raise exceptions for server errors.
- Simplify the callback_client's invoke method.
- Add pw_rpc_timeout_s to RPC method tab completion.
Change-Id: I1a2214057147c0d1b2547973da1b775f76e36d60
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/31582
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
diff --git a/pw_rpc/py/callback_client_test.py b/pw_rpc/py/callback_client_test.py
index 425fd0b..ed41b3e 100755
--- a/pw_rpc/py/callback_client_test.py
+++ b/pw_rpc/py/callback_client_test.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# Copyright 2020 The Pigweed Authors
+# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
@@ -62,6 +62,7 @@
"""Tests the callback_client as used within a pw_rpc Client."""
def setUp(self):
self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
+ self._request = self._protos.packages.pw.test1.SomeMessage
self._client = client.Client.from_modules(
callback_client.Impl(), [client.Channel(1, self._handle_request)],
@@ -114,6 +115,19 @@
status=status.value).SerializeToString(),
process_status))
+ def _enqueue_error(self,
+ channel_id: int,
+ method,
+ status: Status,
+ process_status=Status.OK):
+ self._next_packets.append(
+ (packet_pb2.RpcPacket(type=packet_pb2.PacketType.SERVER_ERROR,
+ channel_id=channel_id,
+ service_id=method.service.id,
+ method_id=method.id,
+ status=status.value).SerializeToString(),
+ process_status))
+
def _handle_request(self, data: bytes):
# Disable this method to prevent infinite recursion if processing the
# packet happens to send another packet.
@@ -143,7 +157,8 @@
self._enqueue_response(1, method, Status.ABORTED,
method.response_type(payload='0_o'))
- status, response = self._service.SomeUnary(magic_number=6)
+ status, response = self._service.SomeUnary(
+ method.request_type(magic_number=6))
self.assertEqual(
6,
@@ -160,24 +175,39 @@
method.response_type(payload='0_o'))
callback = mock.Mock()
- self._service.SomeUnary.invoke(callback, magic_number=5)
+ self._service.SomeUnary.invoke(self._request(magic_number=5),
+ callback, callback)
- callback.assert_called_once_with(
- _rpc(self._service.SomeUnary), Status.ABORTED,
- method.response_type(payload='0_o'))
+ callback.assert_has_calls([
+ mock.call(_rpc(self._service.SomeUnary),
+ method.response_type(payload='0_o')),
+ mock.call(_rpc(self._service.SomeUnary), Status.ABORTED)
+ ])
self.assertEqual(
5,
self._sent_payload(method.request_type).magic_number)
- def test_invoke_unary_rpc_callback_errors_suppressed(self):
+ def test_unary_rpc_server_error(self):
+ method = self._service.SomeUnary.method
+
+ for _ in range(3):
+ self._enqueue_error(1, method, Status.NOT_FOUND)
+
+ with self.assertRaises(callback_client.RpcError) as context:
+ self._service.SomeUnary(method.request_type(magic_number=6))
+
+ self.assertIs(context.exception.status, Status.NOT_FOUND)
+
+ def test_invoke_unary_rpc_callback_exceptions_suppressed(self):
stub = self._service.SomeUnary
self._enqueue_response(1, stub.method)
exception_msg = 'YOU BROKE IT O-]-<'
with self.assertLogs(callback_client.__name__, 'ERROR') as logs:
- stub.invoke(mock.Mock(side_effect=Exception(exception_msg)))
+ stub.invoke(self._request(),
+ mock.Mock(side_effect=Exception(exception_msg)))
self.assertIn(exception_msg, ''.join(logs.output))
@@ -190,14 +220,18 @@
callback = mock.Mock()
for _ in range(3):
- call = self._service.SomeUnary.invoke(callback, magic_number=55)
+ call = self._service.SomeUnary.invoke(
+ self._request(magic_number=55), callback)
self.assertIsNotNone(self._last_request)
self._last_request = None
- # Try to invoke the RPC again before cancelling, which is an error.
+ # Try to invoke the RPC again before cancelling, without overriding
+ # pending RPCs.
with self.assertRaises(client.Error):
- self._service.SomeUnary.invoke(callback, magic_number=56)
+ self._service.SomeUnary.invoke(self._request(magic_number=56),
+ callback,
+ override_pending=False)
self.assertTrue(call.cancel())
self.assertFalse(call.cancel()) # Already cancelled, returns False
@@ -208,12 +242,10 @@
callback.assert_not_called()
def test_reinvoke_unary_rpc(self):
- callback = mock.Mock()
-
- # The reinvoke method ignores pending rpcs, so can be called repeatedly.
for _ in range(3):
self._last_request = None
- self._service.SomeUnary.reinvoke(callback, magic_number=55)
+ self._service.SomeUnary.invoke(self._request(magic_number=55),
+ override_pending=True)
self.assertEqual(self._last_request.type,
packet_pb2.PacketType.REQUEST)
@@ -236,7 +268,7 @@
4,
self._sent_payload(method.request_type).magic_number)
- def test_invoke_server_streaming_with_callback(self):
+ def test_invoke_server_streaming_with_callbacks(self):
method = self._service.SomeServerStreaming.method
rep1 = method.response_type(payload='!!!')
@@ -248,13 +280,14 @@
self._enqueue_stream_end(1, method, Status.ABORTED)
callback = mock.Mock()
- self._service.SomeServerStreaming.invoke(callback, magic_number=3)
+ self._service.SomeServerStreaming.invoke(
+ self._request(magic_number=3), callback, callback)
rpc = _rpc(self._service.SomeServerStreaming)
callback.assert_has_calls([
- mock.call(rpc, None, method.response_type(payload='!!!')),
- mock.call(rpc, None, method.response_type(payload='?')),
- mock.call(rpc, Status.ABORTED, None),
+ mock.call(rpc, method.response_type(payload='!!!')),
+ mock.call(rpc, method.response_type(payload='?')),
+ mock.call(rpc, Status.ABORTED),
])
self.assertEqual(
@@ -268,9 +301,9 @@
self._enqueue_response(1, stub.method, response=resp)
callback = mock.Mock()
- call = stub.invoke(callback, magic_number=3)
+ call = stub.invoke(self._request(magic_number=3), callback)
callback.assert_called_once_with(
- _rpc(stub), None, stub.method.response_type(payload='!!!'))
+ _rpc(stub), stub.method.response_type(payload='!!!'))
callback.reset_mock()
@@ -283,17 +316,15 @@
self._enqueue_response(1, stub.method, response=resp)
self._enqueue_stream_end(1, stub.method, Status.OK)
- call = stub.invoke(callback, magic_number=3)
+ call = stub.invoke(self._request(magic_number=3), callback, callback)
- rpc = _rpc(stub)
callback.assert_has_calls([
- mock.call(rpc, None, stub.method.response_type(payload='!!!')),
- mock.call(rpc, Status.OK, None),
+ mock.call(_rpc(stub), stub.method.response_type(payload='!!!')),
+ mock.call(_rpc(stub), Status.OK),
])
def test_ignore_bad_packets_with_pending_rpc(self):
- rpcs = self._client.channel(1).rpcs
- method = rpcs.pw.test1.PublicService.SomeUnary.method
+ method = self._service.SomeUnary.method
service_id = method.service.id
# Unknown channel
@@ -307,22 +338,19 @@
ids=(service_id, 999),
process_status=Status.OK)
# For RPC not pending (is Status.OK because the packet is processed)
- self._enqueue_response(
- 1,
- ids=(service_id,
- rpcs.pw.test1.PublicService.SomeBidiStreaming.method.id),
- process_status=Status.OK)
+ self._enqueue_response(1,
+ ids=(service_id,
+ self._service.SomeBidiStreaming.method.id),
+ process_status=Status.OK)
self._enqueue_response(1, method, process_status=Status.OK)
- status, response = rpcs.pw.test1.PublicService.SomeUnary(
- magic_number=6)
+ status, response = self._service.SomeUnary(magic_number=6)
self.assertIs(Status.OK, status)
self.assertEqual('', response.payload)
def test_pass_none_if_payload_fails_to_decode(self):
- rpcs = self._client.channel(1).rpcs
- method = rpcs.pw.test1.PublicService.SomeUnary.method
+ method = self._service.SomeUnary.method
self._enqueue_response(1,
method,
@@ -330,13 +358,12 @@
b'INVALID DATA!!!',
process_status=Status.OK)
- status, response = rpcs.pw.test1.PublicService.SomeUnary(
- magic_number=6)
+ status, response = self._service.SomeUnary(magic_number=6)
self.assertIs(status, Status.OK)
self.assertIsNone(response)
def test_rpc_help_contains_method_name(self):
- rpc = self._client.channel(1).rpcs.pw.test1.PublicService.SomeUnary
+ rpc = self._service.SomeUnary
self.assertIn(rpc.method.full_name, rpc.help())
def test_default_timeouts_set_on_impl(self):
diff --git a/pw_rpc/py/pw_rpc/callback_client.py b/pw_rpc/py/pw_rpc/callback_client.py
index a332458..057bae2 100644
--- a/pw_rpc/py/pw_rpc/callback_client.py
+++ b/pw_rpc/py/pw_rpc/callback_client.py
@@ -1,4 +1,4 @@
-# Copyright 2020 The Pigweed Authors
+# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
@@ -11,7 +11,7 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
-"""Defines a callback-based RPC ClientImpl to use with pw_rpc.client.Client.
+"""Defines a callback-based RPC ClientImpl to use with pw_rpc.Client.
callback_client.Impl supports invoking RPCs synchronously or asynchronously.
Asynchronous invocations use a callback.
@@ -45,12 +45,14 @@
import logging
import queue
import textwrap
-from typing import Any, Callable, NamedTuple, Union, Optional
-
-from pw_status import Status
+import threading
+from typing import Any, Callable, Iterator, NamedTuple, Union, Optional
from pw_protobuf_compiler.python_protos import proto_repr
+from pw_status import Status
+
from pw_rpc import client, descriptors
+from pw_rpc.client import PendingRpc, PendingRpcs
from pw_rpc.descriptors import Channel, Method, Service
_LOG = logging.getLogger(__name__)
@@ -63,17 +65,37 @@
OptionalTimeout = Union[UseDefault, float, None]
-Callback = Callable[[client.PendingRpc, Optional[Status], Any], Any]
+ResponseCallback = Callable[[PendingRpc, Any], Any]
+CompletionCallback = Callable[[PendingRpc, Any], Any]
+ErrorCallback = Callable[[PendingRpc, Status], Any]
+
+
+class _Callbacks(NamedTuple):
+ response: ResponseCallback
+ completion: CompletionCallback
+ error: ErrorCallback
+
+
+def _default_response(rpc: PendingRpc, response: Any) -> None:
+ _LOG.info('%s response: %s', rpc, response)
+
+
+def _default_completion(rpc: PendingRpc, status: Status) -> None:
+ _LOG.info('%s finished: %s', rpc, status)
+
+
+def _default_error(rpc: PendingRpc, status: Status) -> None:
+ _LOG.error('%s error: %s', rpc, status)
class _MethodClient:
"""A method that can be invoked for a particular channel."""
- def __init__(self, client_impl: 'Impl', rpcs: client.PendingRpcs,
+ def __init__(self, client_impl: 'Impl', rpcs: PendingRpcs,
channel: Channel, method: Method,
default_timeout_s: Optional[float]):
self._impl = client_impl
self._rpcs = rpcs
- self._rpc = client.PendingRpc(channel, method.service, method)
+ self._rpc = PendingRpc(channel, method.service, method)
self.default_timeout_s: Optional[float] = default_timeout_s
@property
@@ -88,22 +110,17 @@
def service(self) -> Service:
return self._rpc.service
- def invoke(self, callback: Callback, _request=None, **request_fields):
- """Invokes an RPC with a callback."""
+ def invoke(self,
+ request: Any,
+ response: ResponseCallback = _default_response,
+ completion: CompletionCallback = _default_completion,
+ error: ErrorCallback = _default_error,
+ override_pending: bool = True) -> '_AsyncCall':
+ """Invokes an RPC with callbacks."""
self._rpcs.send_request(self._rpc,
- self.method.get_request(
- _request, request_fields),
- callback,
- override_pending=False)
- return _AsyncCall(self._rpcs, self._rpc)
-
- def reinvoke(self, callback: Callback, _request=None, **request_fields):
- """Invokes an RPC with a callback, overriding any pending requests."""
- self._rpcs.send_request(self._rpc,
- self.method.get_request(
- _request, request_fields),
- callback,
- override_pending=True)
+ request,
+ _Callbacks(response, completion, error),
+ override_pending=override_pending)
return _AsyncCall(self._rpcs, self._rpc)
def __repr__(self) -> str:
@@ -132,24 +149,36 @@
class RpcTimeout(Exception):
- def __init__(self, rpc: client.PendingRpc, timeout: Optional[float]):
+ def __init__(self, rpc: PendingRpc, timeout: Optional[float]):
super().__init__(
f'No response received for {rpc.method} after {timeout} s')
self.rpc = rpc
self.timeout = timeout
+class RpcError(Exception):
+ def __init__(self, rpc: PendingRpc, status: Status):
+ if status is Status.NOT_FOUND:
+ msg = ': the RPC server does not support this RPC'
+ else:
+ msg = ''
+
+ super().__init__(f'{rpc.method} failed with error {status}{msg}')
+ self.rpc = rpc
+ self.status = status
+
+
class _AsyncCall:
"""Represents an ongoing callback-based call."""
# TODO(hepler): Consider alternatives (futures) and/or expand functionality.
- def __init__(self, rpcs: client.PendingRpcs, rpc: client.PendingRpc):
- self.rpc = rpc
+ def __init__(self, rpcs: PendingRpcs, rpc: PendingRpc):
+ self._rpc = rpc
self._rpcs = rpcs
def cancel(self) -> bool:
- return self._rpcs.send_cancel(self.rpc)
+ return self._rpcs.send_cancel(self._rpc)
def __enter__(self) -> '_AsyncCall':
return self
@@ -179,7 +208,7 @@
def responses(self,
*,
block: bool = True,
- timeout_s: OptionalTimeout = UseDefault.VALUE):
+ timeout_s: OptionalTimeout = UseDefault.VALUE) -> Iterator:
"""Returns an iterator of stream responses.
Args:
@@ -190,8 +219,13 @@
try:
while True:
- self.status, response = self._queue.get(block, timeout_s)
- if self.status is not None:
+ response = self._queue.get(block, timeout_s)
+
+ if isinstance(response, Exception):
+ raise response
+
+ if isinstance(response, Status):
+ self.status = response
return
yield response
@@ -212,7 +246,7 @@
Class that invokes the {method.full_name} {method.type.sentence_name()} RPC.
Calling this directly invokes the RPC synchronously. The RPC can be invoked
-asynchronously using the invoke or reinvoke methods.
+asynchronously using the invoke method.
'''
@@ -238,6 +272,8 @@
params = [next(iter(sig.parameters.values()))] # Get the "self" parameter
params += method.request_parameters()
+ params.append(
+ inspect.Parameter('pw_rpc_timeout_s', inspect.Parameter.KEYWORD_ONLY))
function.__signature__ = sig.replace( # type: ignore[attr-defined]
parameters=params)
@@ -251,32 +287,56 @@
return f'({self.status}, {proto_repr(self.response)})'
-def unary_method_client(client_impl: 'Impl', rpcs: client.PendingRpcs,
- channel: Channel, method: Method,
- default_timeout: Optional[float]) -> _MethodClient:
+class _UnaryResponseHandler:
+ """Tracks the state of an ongoing synchronous unary RPC call."""
+ def __init__(self, rpc: PendingRpc):
+ self._rpc = rpc
+ self._response: Any = None
+ self._status: Optional[Status] = None
+ self._error: Optional[RpcError] = None
+ self._event = threading.Event()
+
+ def on_response(self, _: PendingRpc, response: Any) -> None:
+ self._response = response
+
+ def on_completion(self, _: PendingRpc, status: Status) -> None:
+ self._status = status
+ self._event.set()
+
+ def on_error(self, _: PendingRpc, status: Status) -> None:
+ self._error = RpcError(self._rpc, status)
+ self._event.set()
+
+ def wait(self, timeout_s: Optional[float]) -> UnaryResponse:
+ if not self._event.wait(timeout_s):
+ raise RpcTimeout(self._rpc, timeout_s)
+
+ if self._error is not None:
+ raise self._error
+
+ assert self._status is not None
+ return UnaryResponse(self._status, self._response)
+
+
+def _unary_method_client(client_impl: 'Impl', rpcs: PendingRpcs,
+ channel: Channel, method: Method,
+ default_timeout: Optional[float]) -> _MethodClient:
"""Creates an object used to call a unary method."""
def call(self: _MethodClient,
_rpc_request_proto=None,
*,
pw_rpc_timeout_s=UseDefault.VALUE,
**request_fields) -> UnaryResponse:
- responses: queue.SimpleQueue = queue.SimpleQueue()
- def enqueue_response(_, status: Optional[Status], payload: Any):
- assert isinstance(status, Status)
- responses.put(UnaryResponse(status, payload))
-
- self.reinvoke(enqueue_response, _rpc_request_proto, **request_fields)
+ handler = _UnaryResponseHandler(self._rpc) # pylint: disable=protected-access
+ self.invoke(
+ self.method.get_request(_rpc_request_proto, request_fields),
+ handler.on_response, handler.on_completion, handler.on_error)
if pw_rpc_timeout_s is UseDefault.VALUE:
pw_rpc_timeout_s = self.default_timeout_s
- try:
- return responses.get(timeout=pw_rpc_timeout_s)
- except queue.Empty:
- pass
-
- raise RpcTimeout(self._rpc, pw_rpc_timeout_s) # pylint: disable=protected-access
+ return handler.wait(pw_rpc_timeout_s)
_update_function_signature(method, call)
@@ -289,10 +349,9 @@
default_timeout)
-def server_streaming_method_client(client_impl: 'Impl',
- rpcs: client.PendingRpcs, channel: Channel,
- method: Method,
- default_timeout: Optional[float]):
+def _server_streaming_method_client(client_impl: 'Impl', rpcs: PendingRpcs,
+ channel: Channel, method: Method,
+ default_timeout: Optional[float]):
"""Creates an object used to call a server streaming method."""
def call(self: _MethodClient,
_rpc_request_proto=None,
@@ -300,9 +359,11 @@
pw_rpc_timeout_s=UseDefault.VALUE,
**request_fields) -> StreamingResponses:
responses: queue.SimpleQueue = queue.SimpleQueue()
- self.reinvoke(
- lambda _, status, payload: responses.put((status, payload)),
- _rpc_request_proto, **request_fields)
+ self.invoke(
+ self.method.get_request(_rpc_request_proto, request_fields),
+ lambda _, response: responses.put(response),
+ lambda _, status: responses.put(status),
+ lambda rpc, status: responses.put(RpcError(rpc, status)))
return StreamingResponses(self, responses, pw_rpc_timeout_s)
_update_function_signature(method, call)
@@ -320,7 +381,12 @@
def __call__(self):
raise NotImplementedError
- def invoke(self, callback: Callback, _request=None, **request_fields):
+ def invoke(self,
+ request: Any,
+ response: ResponseCallback = _default_response,
+ completion: CompletionCallback = _default_completion,
+ error: ErrorCallback = _default_error,
+ override_pending: bool = True) -> _AsyncCall:
raise NotImplementedError
@@ -328,15 +394,21 @@
def __call__(self):
raise NotImplementedError
- def invoke(self, callback: Callback, _request=None, **request_fields):
+ def invoke(self,
+ request: Any,
+ response: ResponseCallback = _default_response,
+ completion: CompletionCallback = _default_completion,
+ error: ErrorCallback = _default_error,
+ override_pending: bool = True) -> _AsyncCall:
raise NotImplementedError
class Impl(client.ClientImpl):
- """Callback-based client.ClientImpl."""
+ """Callback-based ClientImpl."""
def __init__(self,
default_unary_timeout_s: Optional[float] = 1.0,
default_stream_timeout_s: Optional[float] = 1.0):
+ super().__init__()
self._default_unary_timeout_s = default_unary_timeout_s
self._default_stream_timeout_s = default_stream_timeout_s
@@ -348,37 +420,37 @@
def default_stream_timeout_s(self) -> Optional[float]:
return self._default_stream_timeout_s
- def method_client(self, rpcs: client.PendingRpcs, channel: Channel,
- method: Method) -> _MethodClient:
+ def method_client(self, channel: Channel, method: Method) -> _MethodClient:
"""Returns an object that invokes a method using the given chanel."""
if method.type is Method.Type.UNARY:
- return unary_method_client(self, rpcs, channel, method,
- self.default_unary_timeout_s)
+ return _unary_method_client(self, self.rpcs, channel, method,
+ self.default_unary_timeout_s)
if method.type is Method.Type.SERVER_STREAMING:
- return server_streaming_method_client(
- self, rpcs, channel, method, self.default_stream_timeout_s)
+ return _server_streaming_method_client(
+ self, self.rpcs, channel, method,
+ self.default_stream_timeout_s)
if method.type is Method.Type.CLIENT_STREAMING:
- return ClientStreamingMethodClient(self, rpcs, channel, method,
+ return ClientStreamingMethodClient(self, self.rpcs, channel,
+ method,
self.default_unary_timeout_s)
if method.type is Method.Type.BIDIRECTIONAL_STREAMING:
return BidirectionalStreamingMethodClient(
- self, rpcs, channel, method, self.default_stream_timeout_s)
+ self, self.rpcs, channel, method,
+ self.default_stream_timeout_s)
raise AssertionError(f'Unknown method type {method.type}')
- def process_response(self,
- rpcs: client.PendingRpcs,
- rpc: client.PendingRpc,
- context,
- status: Optional[Status],
- payload,
- *,
- args: tuple = (),
- kwargs: dict = None) -> None:
+ def handle_response(self,
+ rpc: PendingRpc,
+ context,
+ payload,
+ *,
+ args: tuple = (),
+ kwargs: dict = None) -> None:
"""Invokes the callback associated with this RPC.
Any additional positional and keyword args passed through
@@ -388,7 +460,40 @@
kwargs = {}
try:
- context(rpc, status, payload, *args, **kwargs)
+ context.response(rpc, payload, *args, **kwargs)
except: # pylint: disable=bare-except
- rpcs.send_cancel(rpc)
- _LOG.exception('Callback %s for %s raised exception', context, rpc)
+ self.rpcs.send_cancel(rpc)
+ _LOG.exception('Response callback %s for %s raised exception',
+ context.response, rpc)
+
+ def handle_completion(self,
+ rpc: PendingRpc,
+ context,
+ status: Status,
+ *,
+ args: tuple = (),
+ kwargs: dict = None):
+ if kwargs is None:
+ kwargs = {}
+
+ try:
+ context.completion(rpc, status, *args, **kwargs)
+ except: # pylint: disable=bare-except
+ _LOG.exception('Completion callback %s for %s raised exception',
+ context.completion, rpc)
+
+ def handle_error(self,
+ rpc: PendingRpc,
+ context,
+ status: Status,
+ *,
+ args: tuple = (),
+ kwargs: dict = None) -> None:
+ if kwargs is None:
+ kwargs = {}
+
+ try:
+ context.error(rpc, status, *args, **kwargs)
+ except: # pylint: disable=bare-except
+ _LOG.exception('Error callback %s for %s raised exception',
+ context.error, rpc)
diff --git a/pw_rpc/py/pw_rpc/client.py b/pw_rpc/py/pw_rpc/client.py
index 2fd7ee1..d82a4de 100644
--- a/pw_rpc/py/pw_rpc/client.py
+++ b/pw_rpc/py/pw_rpc/client.py
@@ -16,8 +16,8 @@
import abc
from dataclasses import dataclass
import logging
-from typing import Collection, Dict, Iterable, Iterator, List, NamedTuple
-from typing import Optional
+from typing import (Any, Collection, Dict, Iterable, Iterator, List,
+ NamedTuple, Optional)
from google.protobuf.message import DecodeError
from pw_rpc_protos.internal.packet_pb2 import PacketType, RpcPacket
@@ -122,39 +122,73 @@
This interface defines the semantics for invoking an RPC on a particular
client.
"""
+ def __init__(self):
+ self.client: 'Client' = None
+ self.rpcs: PendingRpcs = None
+
@abc.abstractmethod
- def method_client(self, rpcs: PendingRpcs, channel: Channel,
- method: Method):
+ def method_client(self, channel: Channel, method: Method) -> Any:
"""Returns an object that invokes a method using the given channel."""
@abc.abstractmethod
- def process_response(self,
- rpcs: PendingRpcs,
- rpc: PendingRpc,
- context,
- status: Optional[Status],
- payload,
- *,
- args: tuple = (),
- kwargs: dict = None) -> None:
- """Processes a response from the RPC server.
+ def handle_response(self,
+ rpc: PendingRpc,
+ context: Any,
+ payload: Any,
+ *,
+ args: tuple = (),
+ kwargs: dict = None) -> Any:
+ """Handles a response from the RPC server.
Args:
- rpcs: The PendingRpcs object used by the client.
rpc: Information about the pending RPC
context: Arbitrary context object associated with the pending RPC
- status: If set, this is the last packet for this RPC. None otherwise.
- payload: A protobuf message, if present. None otherwise.
+ payload: A protobuf message
+ args, kwargs: Arbitrary arguments passed to the ClientImpl
+ """
+
+ @abc.abstractmethod
+ def handle_completion(self,
+ rpc: PendingRpc,
+ context: Any,
+ status: Status,
+ *,
+ args: tuple = (),
+ kwargs: dict = None) -> Any:
+ """Handles the successful completion of an RPC.
+
+ Args:
+ rpc: Information about the pending RPC
+ context: Arbitrary context object associated with the pending RPC
+ status: Status returned from the RPC
+ args, kwargs: Arbitrary arguments passed to the ClientImpl
+ """
+
+ @abc.abstractmethod
+ def handle_error(self,
+ rpc: PendingRpc,
+ context,
+ status: Status,
+ *,
+ args: tuple = (),
+ kwargs: dict = None):
+ """Handles the abnormal termination of an RPC.
+
+ args:
+ rpc: Information about the pending RPC
+ context: Arbitrary context object associated with the pending RPC
+ status: which error occurred
+ args, kwargs: Arbitrary arguments passed to the ClientImpl
"""
class ServiceClient(descriptors.ServiceAccessor):
"""Navigates the methods in a service provided by a ChannelClient."""
- def __init__(self, rpcs: PendingRpcs, client_impl: ClientImpl,
- channel: Channel, service: Service):
+ def __init__(self, client_impl: ClientImpl, channel: Channel,
+ service: Service):
super().__init__(
{
- method: client_impl.method_client(rpcs, channel, method)
+ method: client_impl.method_client(channel, method)
for method in service.methods
},
as_attrs='members')
@@ -173,13 +207,11 @@
class Services(descriptors.ServiceAccessor[ServiceClient]):
"""Navigates the services provided by a ChannelClient."""
- def __init__(self, rpcs: PendingRpcs, client_impl, channel: Channel,
+ def __init__(self, client_impl, channel: Channel,
services: Collection[Service]):
super().__init__(
- {
- s: ServiceClient(rpcs, client_impl, channel, s)
- for s in services
- },
+ {s: ServiceClient(client_impl, channel, s)
+ for s in services},
as_attrs='packages')
self._channel = channel
@@ -286,15 +318,15 @@
def __init__(self, impl: ClientImpl, channels: Iterable[Channel],
services: Iterable[Service]):
self._impl = impl
+ self._impl.client = self
+ self._impl.rpcs = PendingRpcs()
self.services = descriptors.Services(services)
- self._rpcs = PendingRpcs()
-
self._channels_by_id = {
- channel.id: ChannelClient(
- self, channel,
- Services(self._rpcs, self._impl, channel, self.services))
+ channel.id:
+ ChannelClient(self, channel,
+ Services(self._impl, channel, self.services))
for channel in channels
}
@@ -374,7 +406,7 @@
payload = _decode_payload(rpc, packet)
try:
- context = self._rpcs.get_pending(rpc, status)
+ context = self._impl.rpcs.get_pending(rpc, status)
except KeyError:
channel_client.channel.output( # type: ignore
packets.encode_client_error(packet,
@@ -383,18 +415,28 @@
return Status.OK
if packet.type == PacketType.SERVER_ERROR:
+ assert status is not None and not status.ok()
_LOG.warning('%s: invocation failed with %s', rpc, status)
-
- # Do not return yet -- call process_response so the ClientImpl can
- # do any necessary cleanup.
-
- self._impl.process_response(self._rpcs,
- rpc,
+ self._impl.handle_error(rpc,
context,
status,
- payload,
args=impl_args,
kwargs=impl_kwargs)
+ return Status.OK
+
+ if payload is not None:
+ self._impl.handle_response(rpc,
+ context,
+ payload,
+ args=impl_args,
+ kwargs=impl_kwargs)
+ if status is not None:
+ self._impl.handle_completion(rpc,
+ context,
+ status,
+ args=impl_args,
+ kwargs=impl_kwargs)
+
return Status.OK
def _look_up_service_and_method(