pw_rpc: Python client improvements

- Add separate methods for response, completion, and error to the
  ClientImpl interface. This allows explicit handling of errors by the
  client.
- In the callback_client impl, raise exceptions for server errors.
- Simplify the callback_client's invoke method.
- Add pw_rpc_timeout_s to RPC method tab completion.

Change-Id: I1a2214057147c0d1b2547973da1b775f76e36d60
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/31582
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
diff --git a/pw_rpc/py/callback_client_test.py b/pw_rpc/py/callback_client_test.py
index 425fd0b..ed41b3e 100755
--- a/pw_rpc/py/callback_client_test.py
+++ b/pw_rpc/py/callback_client_test.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2020 The Pigweed Authors
+# Copyright 2021 The Pigweed Authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
 # use this file except in compliance with the License. You may obtain a copy of
@@ -62,6 +62,7 @@
     """Tests the callback_client as used within a pw_rpc Client."""
     def setUp(self):
         self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
+        self._request = self._protos.packages.pw.test1.SomeMessage
 
         self._client = client.Client.from_modules(
             callback_client.Impl(), [client.Channel(1, self._handle_request)],
@@ -114,6 +115,19 @@
                                   status=status.value).SerializeToString(),
              process_status))
 
+    def _enqueue_error(self,
+                       channel_id: int,
+                       method,
+                       status: Status,
+                       process_status=Status.OK):
+        self._next_packets.append(
+            (packet_pb2.RpcPacket(type=packet_pb2.PacketType.SERVER_ERROR,
+                                  channel_id=channel_id,
+                                  service_id=method.service.id,
+                                  method_id=method.id,
+                                  status=status.value).SerializeToString(),
+             process_status))
+
     def _handle_request(self, data: bytes):
         # Disable this method to prevent infinite recursion if processing the
         # packet happens to send another packet.
@@ -143,7 +157,8 @@
             self._enqueue_response(1, method, Status.ABORTED,
                                    method.response_type(payload='0_o'))
 
-            status, response = self._service.SomeUnary(magic_number=6)
+            status, response = self._service.SomeUnary(
+                method.request_type(magic_number=6))
 
             self.assertEqual(
                 6,
@@ -160,24 +175,39 @@
                                    method.response_type(payload='0_o'))
 
             callback = mock.Mock()
-            self._service.SomeUnary.invoke(callback, magic_number=5)
+            self._service.SomeUnary.invoke(self._request(magic_number=5),
+                                           callback, callback)
 
-            callback.assert_called_once_with(
-                _rpc(self._service.SomeUnary), Status.ABORTED,
-                method.response_type(payload='0_o'))
+            callback.assert_has_calls([
+                mock.call(_rpc(self._service.SomeUnary),
+                          method.response_type(payload='0_o')),
+                mock.call(_rpc(self._service.SomeUnary), Status.ABORTED)
+            ])
 
             self.assertEqual(
                 5,
                 self._sent_payload(method.request_type).magic_number)
 
-    def test_invoke_unary_rpc_callback_errors_suppressed(self):
+    def test_unary_rpc_server_error(self):
+        method = self._service.SomeUnary.method
+
+        for _ in range(3):
+            self._enqueue_error(1, method, Status.NOT_FOUND)
+
+            with self.assertRaises(callback_client.RpcError) as context:
+                self._service.SomeUnary(method.request_type(magic_number=6))
+
+            self.assertIs(context.exception.status, Status.NOT_FOUND)
+
+    def test_invoke_unary_rpc_callback_exceptions_suppressed(self):
         stub = self._service.SomeUnary
 
         self._enqueue_response(1, stub.method)
         exception_msg = 'YOU BROKE IT O-]-<'
 
         with self.assertLogs(callback_client.__name__, 'ERROR') as logs:
-            stub.invoke(mock.Mock(side_effect=Exception(exception_msg)))
+            stub.invoke(self._request(),
+                        mock.Mock(side_effect=Exception(exception_msg)))
 
         self.assertIn(exception_msg, ''.join(logs.output))
 
@@ -190,14 +220,18 @@
         callback = mock.Mock()
 
         for _ in range(3):
-            call = self._service.SomeUnary.invoke(callback, magic_number=55)
+            call = self._service.SomeUnary.invoke(
+                self._request(magic_number=55), callback)
 
             self.assertIsNotNone(self._last_request)
             self._last_request = None
 
-            # Try to invoke the RPC again before cancelling, which is an error.
+            # Try to invoke the RPC again before cancelling, without overriding
+            # pending RPCs.
             with self.assertRaises(client.Error):
-                self._service.SomeUnary.invoke(callback, magic_number=56)
+                self._service.SomeUnary.invoke(self._request(magic_number=56),
+                                               callback,
+                                               override_pending=False)
 
             self.assertTrue(call.cancel())
             self.assertFalse(call.cancel())  # Already cancelled, returns False
@@ -208,12 +242,10 @@
         callback.assert_not_called()
 
     def test_reinvoke_unary_rpc(self):
-        callback = mock.Mock()
-
-        # The reinvoke method ignores pending rpcs, so can be called repeatedly.
         for _ in range(3):
             self._last_request = None
-            self._service.SomeUnary.reinvoke(callback, magic_number=55)
+            self._service.SomeUnary.invoke(self._request(magic_number=55),
+                                           override_pending=True)
             self.assertEqual(self._last_request.type,
                              packet_pb2.PacketType.REQUEST)
 
@@ -236,7 +268,7 @@
                 4,
                 self._sent_payload(method.request_type).magic_number)
 
-    def test_invoke_server_streaming_with_callback(self):
+    def test_invoke_server_streaming_with_callbacks(self):
         method = self._service.SomeServerStreaming.method
 
         rep1 = method.response_type(payload='!!!')
@@ -248,13 +280,14 @@
             self._enqueue_stream_end(1, method, Status.ABORTED)
 
             callback = mock.Mock()
-            self._service.SomeServerStreaming.invoke(callback, magic_number=3)
+            self._service.SomeServerStreaming.invoke(
+                self._request(magic_number=3), callback, callback)
 
             rpc = _rpc(self._service.SomeServerStreaming)
             callback.assert_has_calls([
-                mock.call(rpc, None, method.response_type(payload='!!!')),
-                mock.call(rpc, None, method.response_type(payload='?')),
-                mock.call(rpc, Status.ABORTED, None),
+                mock.call(rpc, method.response_type(payload='!!!')),
+                mock.call(rpc, method.response_type(payload='?')),
+                mock.call(rpc, Status.ABORTED),
             ])
 
             self.assertEqual(
@@ -268,9 +301,9 @@
         self._enqueue_response(1, stub.method, response=resp)
 
         callback = mock.Mock()
-        call = stub.invoke(callback, magic_number=3)
+        call = stub.invoke(self._request(magic_number=3), callback)
         callback.assert_called_once_with(
-            _rpc(stub), None, stub.method.response_type(payload='!!!'))
+            _rpc(stub), stub.method.response_type(payload='!!!'))
 
         callback.reset_mock()
 
@@ -283,17 +316,15 @@
         self._enqueue_response(1, stub.method, response=resp)
         self._enqueue_stream_end(1, stub.method, Status.OK)
 
-        call = stub.invoke(callback, magic_number=3)
+        call = stub.invoke(self._request(magic_number=3), callback, callback)
 
-        rpc = _rpc(stub)
         callback.assert_has_calls([
-            mock.call(rpc, None, stub.method.response_type(payload='!!!')),
-            mock.call(rpc, Status.OK, None),
+            mock.call(_rpc(stub), stub.method.response_type(payload='!!!')),
+            mock.call(_rpc(stub), Status.OK),
         ])
 
     def test_ignore_bad_packets_with_pending_rpc(self):
-        rpcs = self._client.channel(1).rpcs
-        method = rpcs.pw.test1.PublicService.SomeUnary.method
+        method = self._service.SomeUnary.method
         service_id = method.service.id
 
         # Unknown channel
@@ -307,22 +338,19 @@
                                ids=(service_id, 999),
                                process_status=Status.OK)
         # For RPC not pending (is Status.OK because the packet is processed)
-        self._enqueue_response(
-            1,
-            ids=(service_id,
-                 rpcs.pw.test1.PublicService.SomeBidiStreaming.method.id),
-            process_status=Status.OK)
+        self._enqueue_response(1,
+                               ids=(service_id,
+                                    self._service.SomeBidiStreaming.method.id),
+                               process_status=Status.OK)
 
         self._enqueue_response(1, method, process_status=Status.OK)
 
-        status, response = rpcs.pw.test1.PublicService.SomeUnary(
-            magic_number=6)
+        status, response = self._service.SomeUnary(magic_number=6)
         self.assertIs(Status.OK, status)
         self.assertEqual('', response.payload)
 
     def test_pass_none_if_payload_fails_to_decode(self):
-        rpcs = self._client.channel(1).rpcs
-        method = rpcs.pw.test1.PublicService.SomeUnary.method
+        method = self._service.SomeUnary.method
 
         self._enqueue_response(1,
                                method,
@@ -330,13 +358,12 @@
                                b'INVALID DATA!!!',
                                process_status=Status.OK)
 
-        status, response = rpcs.pw.test1.PublicService.SomeUnary(
-            magic_number=6)
+        status, response = self._service.SomeUnary(magic_number=6)
         self.assertIs(status, Status.OK)
         self.assertIsNone(response)
 
     def test_rpc_help_contains_method_name(self):
-        rpc = self._client.channel(1).rpcs.pw.test1.PublicService.SomeUnary
+        rpc = self._service.SomeUnary
         self.assertIn(rpc.method.full_name, rpc.help())
 
     def test_default_timeouts_set_on_impl(self):
diff --git a/pw_rpc/py/pw_rpc/callback_client.py b/pw_rpc/py/pw_rpc/callback_client.py
index a332458..057bae2 100644
--- a/pw_rpc/py/pw_rpc/callback_client.py
+++ b/pw_rpc/py/pw_rpc/callback_client.py
@@ -1,4 +1,4 @@
-# Copyright 2020 The Pigweed Authors
+# Copyright 2021 The Pigweed Authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
 # use this file except in compliance with the License. You may obtain a copy of
@@ -11,7 +11,7 @@
 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 # License for the specific language governing permissions and limitations under
 # the License.
-"""Defines a callback-based RPC ClientImpl to use with pw_rpc.client.Client.
+"""Defines a callback-based RPC ClientImpl to use with pw_rpc.Client.
 
 callback_client.Impl supports invoking RPCs synchronously or asynchronously.
 Asynchronous invocations use a callback.
@@ -45,12 +45,14 @@
 import logging
 import queue
 import textwrap
-from typing import Any, Callable, NamedTuple, Union, Optional
-
-from pw_status import Status
+import threading
+from typing import Any, Callable, Iterator, NamedTuple, Union, Optional
 
 from pw_protobuf_compiler.python_protos import proto_repr
+from pw_status import Status
+
 from pw_rpc import client, descriptors
+from pw_rpc.client import PendingRpc, PendingRpcs
 from pw_rpc.descriptors import Channel, Method, Service
 
 _LOG = logging.getLogger(__name__)
@@ -63,17 +65,37 @@
 
 OptionalTimeout = Union[UseDefault, float, None]
 
-Callback = Callable[[client.PendingRpc, Optional[Status], Any], Any]
+ResponseCallback = Callable[[PendingRpc, Any], Any]
+CompletionCallback = Callable[[PendingRpc, Any], Any]
+ErrorCallback = Callable[[PendingRpc, Status], Any]
+
+
+class _Callbacks(NamedTuple):
+    response: ResponseCallback
+    completion: CompletionCallback
+    error: ErrorCallback
+
+
+def _default_response(rpc: PendingRpc, response: Any) -> None:
+    _LOG.info('%s response: %s', rpc, response)
+
+
+def _default_completion(rpc: PendingRpc, status: Status) -> None:
+    _LOG.info('%s finished: %s', rpc, status)
+
+
+def _default_error(rpc: PendingRpc, status: Status) -> None:
+    _LOG.error('%s error: %s', rpc, status)
 
 
 class _MethodClient:
     """A method that can be invoked for a particular channel."""
-    def __init__(self, client_impl: 'Impl', rpcs: client.PendingRpcs,
+    def __init__(self, client_impl: 'Impl', rpcs: PendingRpcs,
                  channel: Channel, method: Method,
                  default_timeout_s: Optional[float]):
         self._impl = client_impl
         self._rpcs = rpcs
-        self._rpc = client.PendingRpc(channel, method.service, method)
+        self._rpc = PendingRpc(channel, method.service, method)
         self.default_timeout_s: Optional[float] = default_timeout_s
 
     @property
@@ -88,22 +110,17 @@
     def service(self) -> Service:
         return self._rpc.service
 
-    def invoke(self, callback: Callback, _request=None, **request_fields):
-        """Invokes an RPC with a callback."""
+    def invoke(self,
+               request: Any,
+               response: ResponseCallback = _default_response,
+               completion: CompletionCallback = _default_completion,
+               error: ErrorCallback = _default_error,
+               override_pending: bool = True) -> '_AsyncCall':
+        """Invokes an RPC with callbacks."""
         self._rpcs.send_request(self._rpc,
-                                self.method.get_request(
-                                    _request, request_fields),
-                                callback,
-                                override_pending=False)
-        return _AsyncCall(self._rpcs, self._rpc)
-
-    def reinvoke(self, callback: Callback, _request=None, **request_fields):
-        """Invokes an RPC with a callback, overriding any pending requests."""
-        self._rpcs.send_request(self._rpc,
-                                self.method.get_request(
-                                    _request, request_fields),
-                                callback,
-                                override_pending=True)
+                                request,
+                                _Callbacks(response, completion, error),
+                                override_pending=override_pending)
         return _AsyncCall(self._rpcs, self._rpc)
 
     def __repr__(self) -> str:
@@ -132,24 +149,36 @@
 
 
 class RpcTimeout(Exception):
-    def __init__(self, rpc: client.PendingRpc, timeout: Optional[float]):
+    def __init__(self, rpc: PendingRpc, timeout: Optional[float]):
         super().__init__(
             f'No response received for {rpc.method} after {timeout} s')
         self.rpc = rpc
         self.timeout = timeout
 
 
+class RpcError(Exception):
+    def __init__(self, rpc: PendingRpc, status: Status):
+        if status is Status.NOT_FOUND:
+            msg = ': the RPC server does not support this RPC'
+        else:
+            msg = ''
+
+        super().__init__(f'{rpc.method} failed with error {status}{msg}')
+        self.rpc = rpc
+        self.status = status
+
+
 class _AsyncCall:
     """Represents an ongoing callback-based call."""
 
     # TODO(hepler): Consider alternatives (futures) and/or expand functionality.
 
-    def __init__(self, rpcs: client.PendingRpcs, rpc: client.PendingRpc):
-        self.rpc = rpc
+    def __init__(self, rpcs: PendingRpcs, rpc: PendingRpc):
+        self._rpc = rpc
         self._rpcs = rpcs
 
     def cancel(self) -> bool:
-        return self._rpcs.send_cancel(self.rpc)
+        return self._rpcs.send_cancel(self._rpc)
 
     def __enter__(self) -> '_AsyncCall':
         return self
@@ -179,7 +208,7 @@
     def responses(self,
                   *,
                   block: bool = True,
-                  timeout_s: OptionalTimeout = UseDefault.VALUE):
+                  timeout_s: OptionalTimeout = UseDefault.VALUE) -> Iterator:
         """Returns an iterator of stream responses.
 
         Args:
@@ -190,8 +219,13 @@
 
         try:
             while True:
-                self.status, response = self._queue.get(block, timeout_s)
-                if self.status is not None:
+                response = self._queue.get(block, timeout_s)
+
+                if isinstance(response, Exception):
+                    raise response
+
+                if isinstance(response, Status):
+                    self.status = response
                     return
 
                 yield response
@@ -212,7 +246,7 @@
 Class that invokes the {method.full_name} {method.type.sentence_name()} RPC.
 
 Calling this directly invokes the RPC synchronously. The RPC can be invoked
-asynchronously using the invoke or reinvoke methods.
+asynchronously using the invoke method.
 '''
 
 
@@ -238,6 +272,8 @@
 
     params = [next(iter(sig.parameters.values()))]  # Get the "self" parameter
     params += method.request_parameters()
+    params.append(
+        inspect.Parameter('pw_rpc_timeout_s', inspect.Parameter.KEYWORD_ONLY))
     function.__signature__ = sig.replace(  # type: ignore[attr-defined]
         parameters=params)
 
@@ -251,32 +287,56 @@
         return f'({self.status}, {proto_repr(self.response)})'
 
 
-def unary_method_client(client_impl: 'Impl', rpcs: client.PendingRpcs,
-                        channel: Channel, method: Method,
-                        default_timeout: Optional[float]) -> _MethodClient:
+class _UnaryResponseHandler:
+    """Tracks the state of an ongoing synchronous unary RPC call."""
+    def __init__(self, rpc: PendingRpc):
+        self._rpc = rpc
+        self._response: Any = None
+        self._status: Optional[Status] = None
+        self._error: Optional[RpcError] = None
+        self._event = threading.Event()
+
+    def on_response(self, _: PendingRpc, response: Any) -> None:
+        self._response = response
+
+    def on_completion(self, _: PendingRpc, status: Status) -> None:
+        self._status = status
+        self._event.set()
+
+    def on_error(self, _: PendingRpc, status: Status) -> None:
+        self._error = RpcError(self._rpc, status)
+        self._event.set()
+
+    def wait(self, timeout_s: Optional[float]) -> UnaryResponse:
+        if not self._event.wait(timeout_s):
+            raise RpcTimeout(self._rpc, timeout_s)
+
+        if self._error is not None:
+            raise self._error
+
+        assert self._status is not None
+        return UnaryResponse(self._status, self._response)
+
+
+def _unary_method_client(client_impl: 'Impl', rpcs: PendingRpcs,
+                         channel: Channel, method: Method,
+                         default_timeout: Optional[float]) -> _MethodClient:
     """Creates an object used to call a unary method."""
     def call(self: _MethodClient,
              _rpc_request_proto=None,
              *,
              pw_rpc_timeout_s=UseDefault.VALUE,
              **request_fields) -> UnaryResponse:
-        responses: queue.SimpleQueue = queue.SimpleQueue()
 
-        def enqueue_response(_, status: Optional[Status], payload: Any):
-            assert isinstance(status, Status)
-            responses.put(UnaryResponse(status, payload))
-
-        self.reinvoke(enqueue_response, _rpc_request_proto, **request_fields)
+        handler = _UnaryResponseHandler(self._rpc)  # pylint: disable=protected-access
+        self.invoke(
+            self.method.get_request(_rpc_request_proto, request_fields),
+            handler.on_response, handler.on_completion, handler.on_error)
 
         if pw_rpc_timeout_s is UseDefault.VALUE:
             pw_rpc_timeout_s = self.default_timeout_s
 
-        try:
-            return responses.get(timeout=pw_rpc_timeout_s)
-        except queue.Empty:
-            pass
-
-        raise RpcTimeout(self._rpc, pw_rpc_timeout_s)  # pylint: disable=protected-access
+        return handler.wait(pw_rpc_timeout_s)
 
     _update_function_signature(method, call)
 
@@ -289,10 +349,9 @@
                               default_timeout)
 
 
-def server_streaming_method_client(client_impl: 'Impl',
-                                   rpcs: client.PendingRpcs, channel: Channel,
-                                   method: Method,
-                                   default_timeout: Optional[float]):
+def _server_streaming_method_client(client_impl: 'Impl', rpcs: PendingRpcs,
+                                    channel: Channel, method: Method,
+                                    default_timeout: Optional[float]):
     """Creates an object used to call a server streaming method."""
     def call(self: _MethodClient,
              _rpc_request_proto=None,
@@ -300,9 +359,11 @@
              pw_rpc_timeout_s=UseDefault.VALUE,
              **request_fields) -> StreamingResponses:
         responses: queue.SimpleQueue = queue.SimpleQueue()
-        self.reinvoke(
-            lambda _, status, payload: responses.put((status, payload)),
-            _rpc_request_proto, **request_fields)
+        self.invoke(
+            self.method.get_request(_rpc_request_proto, request_fields),
+            lambda _, response: responses.put(response),
+            lambda _, status: responses.put(status),
+            lambda rpc, status: responses.put(RpcError(rpc, status)))
         return StreamingResponses(self, responses, pw_rpc_timeout_s)
 
     _update_function_signature(method, call)
@@ -320,7 +381,12 @@
     def __call__(self):
         raise NotImplementedError
 
-    def invoke(self, callback: Callback, _request=None, **request_fields):
+    def invoke(self,
+               request: Any,
+               response: ResponseCallback = _default_response,
+               completion: CompletionCallback = _default_completion,
+               error: ErrorCallback = _default_error,
+               override_pending: bool = True) -> _AsyncCall:
         raise NotImplementedError
 
 
@@ -328,15 +394,21 @@
     def __call__(self):
         raise NotImplementedError
 
-    def invoke(self, callback: Callback, _request=None, **request_fields):
+    def invoke(self,
+               request: Any,
+               response: ResponseCallback = _default_response,
+               completion: CompletionCallback = _default_completion,
+               error: ErrorCallback = _default_error,
+               override_pending: bool = True) -> _AsyncCall:
         raise NotImplementedError
 
 
 class Impl(client.ClientImpl):
-    """Callback-based client.ClientImpl."""
+    """Callback-based ClientImpl."""
     def __init__(self,
                  default_unary_timeout_s: Optional[float] = 1.0,
                  default_stream_timeout_s: Optional[float] = 1.0):
+        super().__init__()
         self._default_unary_timeout_s = default_unary_timeout_s
         self._default_stream_timeout_s = default_stream_timeout_s
 
@@ -348,37 +420,37 @@
     def default_stream_timeout_s(self) -> Optional[float]:
         return self._default_stream_timeout_s
 
-    def method_client(self, rpcs: client.PendingRpcs, channel: Channel,
-                      method: Method) -> _MethodClient:
+    def method_client(self, channel: Channel, method: Method) -> _MethodClient:
         """Returns an object that invokes a method using the given chanel."""
 
         if method.type is Method.Type.UNARY:
-            return unary_method_client(self, rpcs, channel, method,
-                                       self.default_unary_timeout_s)
+            return _unary_method_client(self, self.rpcs, channel, method,
+                                        self.default_unary_timeout_s)
 
         if method.type is Method.Type.SERVER_STREAMING:
-            return server_streaming_method_client(
-                self, rpcs, channel, method, self.default_stream_timeout_s)
+            return _server_streaming_method_client(
+                self, self.rpcs, channel, method,
+                self.default_stream_timeout_s)
 
         if method.type is Method.Type.CLIENT_STREAMING:
-            return ClientStreamingMethodClient(self, rpcs, channel, method,
+            return ClientStreamingMethodClient(self, self.rpcs, channel,
+                                               method,
                                                self.default_unary_timeout_s)
 
         if method.type is Method.Type.BIDIRECTIONAL_STREAMING:
             return BidirectionalStreamingMethodClient(
-                self, rpcs, channel, method, self.default_stream_timeout_s)
+                self, self.rpcs, channel, method,
+                self.default_stream_timeout_s)
 
         raise AssertionError(f'Unknown method type {method.type}')
 
-    def process_response(self,
-                         rpcs: client.PendingRpcs,
-                         rpc: client.PendingRpc,
-                         context,
-                         status: Optional[Status],
-                         payload,
-                         *,
-                         args: tuple = (),
-                         kwargs: dict = None) -> None:
+    def handle_response(self,
+                        rpc: PendingRpc,
+                        context,
+                        payload,
+                        *,
+                        args: tuple = (),
+                        kwargs: dict = None) -> None:
         """Invokes the callback associated with this RPC.
 
         Any additional positional and keyword args passed through
@@ -388,7 +460,40 @@
             kwargs = {}
 
         try:
-            context(rpc, status, payload, *args, **kwargs)
+            context.response(rpc, payload, *args, **kwargs)
         except:  # pylint: disable=bare-except
-            rpcs.send_cancel(rpc)
-            _LOG.exception('Callback %s for %s raised exception', context, rpc)
+            self.rpcs.send_cancel(rpc)
+            _LOG.exception('Response callback %s for %s raised exception',
+                           context.response, rpc)
+
+    def handle_completion(self,
+                          rpc: PendingRpc,
+                          context,
+                          status: Status,
+                          *,
+                          args: tuple = (),
+                          kwargs: dict = None):
+        if kwargs is None:
+            kwargs = {}
+
+        try:
+            context.completion(rpc, status, *args, **kwargs)
+        except:  # pylint: disable=bare-except
+            _LOG.exception('Completion callback %s for %s raised exception',
+                           context.completion, rpc)
+
+    def handle_error(self,
+                     rpc: PendingRpc,
+                     context,
+                     status: Status,
+                     *,
+                     args: tuple = (),
+                     kwargs: dict = None) -> None:
+        if kwargs is None:
+            kwargs = {}
+
+        try:
+            context.error(rpc, status, *args, **kwargs)
+        except:  # pylint: disable=bare-except
+            _LOG.exception('Error callback %s for %s raised exception',
+                           context.error, rpc)
diff --git a/pw_rpc/py/pw_rpc/client.py b/pw_rpc/py/pw_rpc/client.py
index 2fd7ee1..d82a4de 100644
--- a/pw_rpc/py/pw_rpc/client.py
+++ b/pw_rpc/py/pw_rpc/client.py
@@ -16,8 +16,8 @@
 import abc
 from dataclasses import dataclass
 import logging
-from typing import Collection, Dict, Iterable, Iterator, List, NamedTuple
-from typing import Optional
+from typing import (Any, Collection, Dict, Iterable, Iterator, List,
+                    NamedTuple, Optional)
 
 from google.protobuf.message import DecodeError
 from pw_rpc_protos.internal.packet_pb2 import PacketType, RpcPacket
@@ -122,39 +122,73 @@
     This interface defines the semantics for invoking an RPC on a particular
     client.
     """
+    def __init__(self):
+        self.client: 'Client' = None
+        self.rpcs: PendingRpcs = None
+
     @abc.abstractmethod
-    def method_client(self, rpcs: PendingRpcs, channel: Channel,
-                      method: Method):
+    def method_client(self, channel: Channel, method: Method) -> Any:
         """Returns an object that invokes a method using the given channel."""
 
     @abc.abstractmethod
-    def process_response(self,
-                         rpcs: PendingRpcs,
-                         rpc: PendingRpc,
-                         context,
-                         status: Optional[Status],
-                         payload,
-                         *,
-                         args: tuple = (),
-                         kwargs: dict = None) -> None:
-        """Processes a response from the RPC server.
+    def handle_response(self,
+                        rpc: PendingRpc,
+                        context: Any,
+                        payload: Any,
+                        *,
+                        args: tuple = (),
+                        kwargs: dict = None) -> Any:
+        """Handles a response from the RPC server.
 
         Args:
-          rpcs: The PendingRpcs object used by the client.
           rpc: Information about the pending RPC
           context: Arbitrary context object associated with the pending RPC
-          status: If set, this is the last packet for this RPC. None otherwise.
-          payload: A protobuf message, if present. None otherwise.
+          payload: A protobuf message
+          args, kwargs: Arbitrary arguments passed to the ClientImpl
+        """
+
+    @abc.abstractmethod
+    def handle_completion(self,
+                          rpc: PendingRpc,
+                          context: Any,
+                          status: Status,
+                          *,
+                          args: tuple = (),
+                          kwargs: dict = None) -> Any:
+        """Handles the successful completion of an RPC.
+
+        Args:
+          rpc: Information about the pending RPC
+          context: Arbitrary context object associated with the pending RPC
+          status: Status returned from the RPC
+          args, kwargs: Arbitrary arguments passed to the ClientImpl
+        """
+
+    @abc.abstractmethod
+    def handle_error(self,
+                     rpc: PendingRpc,
+                     context,
+                     status: Status,
+                     *,
+                     args: tuple = (),
+                     kwargs: dict = None):
+        """Handles the abnormal termination of an RPC.
+
+        args:
+          rpc: Information about the pending RPC
+          context: Arbitrary context object associated with the pending RPC
+          status: which error occurred
+          args, kwargs: Arbitrary arguments passed to the ClientImpl
         """
 
 
 class ServiceClient(descriptors.ServiceAccessor):
     """Navigates the methods in a service provided by a ChannelClient."""
-    def __init__(self, rpcs: PendingRpcs, client_impl: ClientImpl,
-                 channel: Channel, service: Service):
+    def __init__(self, client_impl: ClientImpl, channel: Channel,
+                 service: Service):
         super().__init__(
             {
-                method: client_impl.method_client(rpcs, channel, method)
+                method: client_impl.method_client(channel, method)
                 for method in service.methods
             },
             as_attrs='members')
@@ -173,13 +207,11 @@
 
 class Services(descriptors.ServiceAccessor[ServiceClient]):
     """Navigates the services provided by a ChannelClient."""
-    def __init__(self, rpcs: PendingRpcs, client_impl, channel: Channel,
+    def __init__(self, client_impl, channel: Channel,
                  services: Collection[Service]):
         super().__init__(
-            {
-                s: ServiceClient(rpcs, client_impl, channel, s)
-                for s in services
-            },
+            {s: ServiceClient(client_impl, channel, s)
+             for s in services},
             as_attrs='packages')
 
         self._channel = channel
@@ -286,15 +318,15 @@
     def __init__(self, impl: ClientImpl, channels: Iterable[Channel],
                  services: Iterable[Service]):
         self._impl = impl
+        self._impl.client = self
+        self._impl.rpcs = PendingRpcs()
 
         self.services = descriptors.Services(services)
 
-        self._rpcs = PendingRpcs()
-
         self._channels_by_id = {
-            channel.id: ChannelClient(
-                self, channel,
-                Services(self._rpcs, self._impl, channel, self.services))
+            channel.id:
+            ChannelClient(self, channel,
+                          Services(self._impl, channel, self.services))
             for channel in channels
         }
 
@@ -374,7 +406,7 @@
         payload = _decode_payload(rpc, packet)
 
         try:
-            context = self._rpcs.get_pending(rpc, status)
+            context = self._impl.rpcs.get_pending(rpc, status)
         except KeyError:
             channel_client.channel.output(  # type: ignore
                 packets.encode_client_error(packet,
@@ -383,18 +415,28 @@
             return Status.OK
 
         if packet.type == PacketType.SERVER_ERROR:
+            assert status is not None and not status.ok()
             _LOG.warning('%s: invocation failed with %s', rpc, status)
-
-            # Do not return yet -- call process_response so the ClientImpl can
-            # do any necessary cleanup.
-
-        self._impl.process_response(self._rpcs,
-                                    rpc,
+            self._impl.handle_error(rpc,
                                     context,
                                     status,
-                                    payload,
                                     args=impl_args,
                                     kwargs=impl_kwargs)
+            return Status.OK
+
+        if payload is not None:
+            self._impl.handle_response(rpc,
+                                       context,
+                                       payload,
+                                       args=impl_args,
+                                       kwargs=impl_kwargs)
+        if status is not None:
+            self._impl.handle_completion(rpc,
+                                         context,
+                                         status,
+                                         args=impl_args,
+                                         kwargs=impl_kwargs)
+
         return Status.OK
 
     def _look_up_service_and_method(