pw_rpc: Timeouts in Python pw_rpc client

The callback_client implementation allows specifying a default timeout
for unary and server streaming RPCs. The timeout can be overridden for
an individual RPC call. If no response is received by the timeout, an
RpcTimeout exception is raised.

Change-Id: I932bb6ebbf3bf157adb40bc0536a529c8411c221
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/31140
Reviewed-by: Ewout van Bekkum <ewout@google.com>
Commit-Queue: Wyatt Hepler <hepler@google.com>
diff --git a/pw_rpc/py/callback_client_test.py b/pw_rpc/py/callback_client_test.py
index 18cd66a..2b49545 100755
--- a/pw_rpc/py/callback_client_test.py
+++ b/pw_rpc/py/callback_client_test.py
@@ -65,6 +65,7 @@
         self._client = client.Client.from_modules(
             callback_client.Impl(), [client.Channel(1, self._handle_request)],
             self._protos.modules())
+        self._service = self._client.channel(1).rpcs.pw.test1.PublicService
 
         self._last_request: packet_pb2.RpcPacket = None
         self._next_packets: List[Tuple[bytes, Status]] = []
@@ -135,14 +136,13 @@
         return message
 
     def test_invoke_unary_rpc(self):
-        stub = self._client.channel(1).rpcs.pw.test1.PublicService
-        method = stub.SomeUnary.method
+        method = self._service.SomeUnary.method
 
         for _ in range(3):
             self._enqueue_response(1, method, Status.ABORTED,
                                    method.response_type(payload='0_o'))
 
-            status, response = stub.SomeUnary(magic_number=6)
+            status, response = self._service.SomeUnary(magic_number=6)
 
             self.assertEqual(
                 6,
@@ -152,18 +152,17 @@
             self.assertEqual('0_o', response.payload)
 
     def test_invoke_unary_rpc_with_callback(self):
-        stub = self._client.channel(1).rpcs.pw.test1.PublicService
-        method = stub.SomeUnary.method
+        method = self._service.SomeUnary.method
 
         for _ in range(3):
             self._enqueue_response(1, method, Status.ABORTED,
                                    method.response_type(payload='0_o'))
 
             callback = mock.Mock()
-            stub.SomeUnary.invoke(callback, magic_number=5)
+            self._service.SomeUnary.invoke(callback, magic_number=5)
 
             callback.assert_called_once_with(
-                _rpc(stub.SomeUnary), Status.ABORTED,
+                _rpc(self._service.SomeUnary), Status.ABORTED,
                 method.response_type(payload='0_o'))
 
             self.assertEqual(
@@ -171,7 +170,7 @@
                 self._sent_payload(method.request_type).magic_number)
 
     def test_invoke_unary_rpc_callback_errors_suppressed(self):
-        stub = self._client.channel(1).rpcs.pw.test1.PublicService.SomeUnary
+        stub = self._service.SomeUnary
 
         self._enqueue_response(1, stub.method)
         exception_msg = 'YOU BROKE IT O-]-<'
@@ -187,18 +186,17 @@
         self.assertIs(status, Status.UNKNOWN)
 
     def test_invoke_unary_rpc_with_callback_cancel(self):
-        stub = self._client.channel(1).rpcs.pw.test1.PublicService
         callback = mock.Mock()
 
         for _ in range(3):
-            call = stub.SomeUnary.invoke(callback, magic_number=55)
+            call = self._service.SomeUnary.invoke(callback, magic_number=55)
 
             self.assertIsNotNone(self._last_request)
             self._last_request = None
 
             # Try to invoke the RPC again before cancelling, which is an error.
             with self.assertRaises(client.Error):
-                stub.SomeUnary.invoke(callback, magic_number=56)
+                self._service.SomeUnary.invoke(callback, magic_number=56)
 
             self.assertTrue(call.cancel())
             self.assertFalse(call.cancel())  # Already cancelled, returns False
@@ -209,19 +207,17 @@
         callback.assert_not_called()
 
     def test_reinvoke_unary_rpc(self):
-        stub = self._client.channel(1).rpcs.pw.test1.PublicService
         callback = mock.Mock()
 
         # The reinvoke method ignores pending rpcs, so can be called repeatedly.
         for _ in range(3):
             self._last_request = None
-            stub.SomeUnary.reinvoke(callback, magic_number=55)
+            self._service.SomeUnary.reinvoke(callback, magic_number=55)
             self.assertEqual(self._last_request.type,
                              packets.PacketType.REQUEST)
 
     def test_invoke_server_streaming(self):
-        stub = self._client.channel(1).rpcs.pw.test1.PublicService
-        method = stub.SomeServerStreaming.method
+        method = self._service.SomeServerStreaming.method
 
         rep1 = method.response_type(payload='!!!')
         rep2 = method.response_type(payload='?')
@@ -231,16 +227,16 @@
             self._enqueue_response(1, method, response=rep2)
             self._enqueue_stream_end(1, method, Status.ABORTED)
 
-            self.assertEqual([rep1, rep2],
-                             list(stub.SomeServerStreaming(magic_number=4)))
+            self.assertEqual(
+                [rep1, rep2],
+                list(self._service.SomeServerStreaming(magic_number=4)))
 
             self.assertEqual(
                 4,
                 self._sent_payload(method.request_type).magic_number)
 
     def test_invoke_server_streaming_with_callback(self):
-        stub = self._client.channel(1).rpcs.pw.test1.PublicService
-        method = stub.SomeServerStreaming.method
+        method = self._service.SomeServerStreaming.method
 
         rep1 = method.response_type(payload='!!!')
         rep2 = method.response_type(payload='?')
@@ -251,9 +247,9 @@
             self._enqueue_stream_end(1, method, Status.ABORTED)
 
             callback = mock.Mock()
-            stub.SomeServerStreaming.invoke(callback, magic_number=3)
+            self._service.SomeServerStreaming.invoke(callback, magic_number=3)
 
-            rpc = _rpc(stub.SomeServerStreaming)
+            rpc = _rpc(self._service.SomeServerStreaming)
             callback.assert_has_calls([
                 mock.call(rpc, None, method.response_type(payload='!!!')),
                 mock.call(rpc, None, method.response_type(payload='?')),
@@ -265,8 +261,7 @@
                 self._sent_payload(method.request_type).magic_number)
 
     def test_invoke_server_streaming_with_callback_cancel(self):
-        stub = self._client.channel(
-            1).rpcs.pw.test1.PublicService.SomeServerStreaming
+        stub = self._service.SomeServerStreaming
 
         resp = stub.method.response_type(payload='!!!')
         self._enqueue_response(1, stub.method, response=resp)
@@ -343,6 +338,46 @@
         rpc = self._client.channel(1).rpcs.pw.test1.PublicService.SomeUnary
         self.assertIn(rpc.method.full_name, rpc.help())
 
+    def test_default_timeouts_set_on_impl(self):
+        impl = callback_client.Impl(None, 1.5)
+
+        self.assertEqual(impl.default_unary_timeout_s, None)
+        self.assertEqual(impl.default_stream_timeout_s, 1.5)
+
+    def test_default_timeouts_set_for_all_rpcs(self):
+        rpc_client = client.Client.from_modules(callback_client.Impl(
+            99, 100), [client.Channel(1, lambda *a, **b: None)],
+                                                self._protos.modules())
+        rpcs = rpc_client.channel(1).rpcs
+
+        self.assertEqual(
+            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99)
+        self.assertEqual(
+            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
+            100)
+
+    def test_timeout_unary(self):
+        with self.assertRaises(callback_client.RpcTimeout):
+            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)
+
+    def test_timeout_unary_set_default(self):
+        self._service.SomeUnary.default_timeout_s = 0.0001
+
+        with self.assertRaises(callback_client.RpcTimeout):
+            self._service.SomeUnary()
+
+    def test_timeout_server_streaming_iteration(self):
+        responses = self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)
+        with self.assertRaises(callback_client.RpcTimeout):
+            for _ in responses:
+                pass
+
+    def test_timeout_server_streaming_responses(self):
+        responses = self._service.SomeServerStreaming()
+        with self.assertRaises(callback_client.RpcTimeout):
+            for _ in responses.responses(timeout_s=0.0001):
+                pass
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/pw_rpc/py/pw_rpc/callback_client.py b/pw_rpc/py/pw_rpc/callback_client.py
index c8dd9a4..ead480b 100644
--- a/pw_rpc/py/pw_rpc/callback_client.py
+++ b/pw_rpc/py/pw_rpc/callback_client.py
@@ -40,11 +40,12 @@
 kwargs for the message fields (but not both).
 """
 
+import enum
 import inspect
 import logging
 import queue
 import textwrap
-from typing import Any, Callable, NamedTuple, Optional
+from typing import Any, Callable, NamedTuple, Union, Optional
 
 from pw_status import Status
 
@@ -54,16 +55,26 @@
 
 _LOG = logging.getLogger(__name__)
 
+
+class _UseDefault(enum.Enum):
+    """Marker for args that should use a default value, when None is valid."""
+    VALUE = 0
+
+
+_OptionalTimeout = Union[_UseDefault, float, None]
+
 Callback = Callable[[client.PendingRpc, Optional[Status], Any], Any]
 
 
 class _MethodClient:
     """A method that can be invoked for a particular channel."""
     def __init__(self, client_impl: 'Impl', rpcs: client.PendingRpcs,
-                 channel: Channel, method: Method):
+                 channel: Channel, method: Method,
+                 default_timeout_s: Optional[float]):
         self._impl = client_impl
         self._rpcs = rpcs
         self._rpc = client.PendingRpc(channel, method.service, method)
+        self.default_timeout_s: Optional[float] = default_timeout_s
 
     @property
     def channel(self) -> Channel:
@@ -120,6 +131,14 @@
             f'  Returns {annotation}.')
 
 
+class RpcTimeout(Exception):
+    def __init__(self, rpc: client.PendingRpc, timeout: Optional[float]):
+        super().__init__(
+            f'No response received for {rpc.method} after {timeout} s')
+        self.rpc = rpc
+        self.timeout = timeout
+
+
 class _AsyncCall:
     """Represents an ongoing callback-based call."""
 
@@ -141,21 +160,48 @@
 
 class StreamingResponses:
     """Used to iterate over a queue.SimpleQueue."""
-    def __init__(self, method: Method, responses: queue.SimpleQueue):
-        self.method = method
+    def __init__(self, method_client: _MethodClient,
+                 responses: queue.SimpleQueue,
+                 default_timeout_s: _OptionalTimeout):
+        self._method_client = method_client
         self._queue = responses
         self.status: Optional[Status] = None
 
-    def get(self, *, block: bool = True, timeout_s: float = None):
-        while True:
-            self.status, response = self._queue.get(block, timeout_s)
-            if self.status is not None:
-                return
+        if default_timeout_s is _UseDefault.VALUE:
+            self.default_timeout_s = self._method_client.default_timeout_s
+        else:
+            self.default_timeout_s = default_timeout_s
 
-            yield response
+    @property
+    def method(self) -> Method:
+        return self._method_client.method
+
+    def responses(self,
+                  *,
+                  block: bool = True,
+                  timeout_s: _OptionalTimeout = _UseDefault.VALUE):
+        """Returns an iterator of stream responses.
+
+        Args:
+          timeout_s: timeout in seconds; None blocks indefinitely
+        """
+        if timeout_s is _UseDefault.VALUE:
+            timeout_s = self.default_timeout_s
+
+        try:
+            while True:
+                self.status, response = self._queue.get(block, timeout_s)
+                if self.status is not None:
+                    return
+
+                yield response
+        except queue.Empty:
+            pass
+
+        raise RpcTimeout(self._method_client._rpc, timeout_s)  # pylint: disable=protected-access
 
     def __iter__(self):
-        return self.get()
+        return self.responses()
 
     def __repr__(self) -> str:
         return f'{type(self).__name__}({self.method})'
@@ -206,15 +252,31 @@
 
 
 def unary_method_client(client_impl: 'Impl', rpcs: client.PendingRpcs,
-                        channel: Channel, method: Method) -> _MethodClient:
+                        channel: Channel, method: Method,
+                        default_timeout: Optional[float]) -> _MethodClient:
     """Creates an object used to call a unary method."""
-    def call(self, _rpc_request_proto=None, **request_fields) -> UnaryResponse:
+    def call(self: _MethodClient,
+             _rpc_request_proto=None,
+             *,
+             pw_rpc_timeout_s=_UseDefault.VALUE,
+             **request_fields) -> UnaryResponse:
         responses: queue.SimpleQueue = queue.SimpleQueue()
-        self.reinvoke(
-            lambda _, status, payload: responses.put(
-                UnaryResponse(status, payload)), _rpc_request_proto,
-            **request_fields)
-        return responses.get()
+
+        def enqueue_response(_, status: Optional[Status], payload: Any):
+            assert isinstance(status, Status)
+            responses.put(UnaryResponse(status, payload))
+
+        self.reinvoke(enqueue_response, _rpc_request_proto, **request_fields)
+
+        if pw_rpc_timeout_s is _UseDefault.VALUE:
+            pw_rpc_timeout_s = self.default_timeout_s
+
+        try:
+            return responses.get(timeout=pw_rpc_timeout_s)
+        except queue.Empty:
+            pass
+
+        raise RpcTimeout(self._rpc, pw_rpc_timeout_s)  # pylint: disable=protected-access
 
     _update_function_signature(method, call)
 
@@ -223,21 +285,25 @@
     method_client_type = type(
         f'{method.name}_UnaryMethodClient', (_MethodClient, ),
         dict(__call__=call, __doc__=_method_client_docstring(method)))
-    return method_client_type(client_impl, rpcs, channel, method)
+    return method_client_type(client_impl, rpcs, channel, method,
+                              default_timeout)
 
 
 def server_streaming_method_client(client_impl: 'Impl',
                                    rpcs: client.PendingRpcs, channel: Channel,
-                                   method: Method):
+                                   method: Method,
+                                   default_timeout: Optional[float]):
     """Creates an object used to call a server streaming method."""
-    def call(self,
+    def call(self: _MethodClient,
              _rpc_request_proto=None,
+             *,
+             pw_rpc_timeout_s=_UseDefault.VALUE,
              **request_fields) -> StreamingResponses:
         responses: queue.SimpleQueue = queue.SimpleQueue()
         self.reinvoke(
             lambda _, status, payload: responses.put((status, payload)),
             _rpc_request_proto, **request_fields)
-        return StreamingResponses(method, responses)
+        return StreamingResponses(self, responses, pw_rpc_timeout_s)
 
     _update_function_signature(method, call)
 
@@ -246,7 +312,8 @@
     method_client_type = type(
         f'{method.name}_ServerStreamingMethodClient', (_MethodClient, ),
         dict(__call__=call, __doc__=_method_client_docstring(method)))
-    return method_client_type(client_impl, rpcs, channel, method)
+    return method_client_type(client_impl, rpcs, channel, method,
+                              default_timeout)
 
 
 class ClientStreamingMethodClient(_MethodClient):
@@ -267,22 +334,39 @@
 
 class Impl(client.ClientImpl):
     """Callback-based client.ClientImpl."""
+    def __init__(self,
+                 default_unary_timeout_s: Optional[float] = 1.0,
+                 default_stream_timeout_s: Optional[float] = 1.0):
+        self._default_unary_timeout_s = default_unary_timeout_s
+        self._default_stream_timeout_s = default_stream_timeout_s
+
+    @property
+    def default_unary_timeout_s(self) -> Optional[float]:
+        return self._default_unary_timeout_s
+
+    @property
+    def default_stream_timeout_s(self) -> Optional[float]:
+        return self._default_stream_timeout_s
+
     def method_client(self, rpcs: client.PendingRpcs, channel: Channel,
                       method: Method) -> _MethodClient:
         """Returns an object that invokes a method using the given chanel."""
 
         if method.type is Method.Type.UNARY:
-            return unary_method_client(self, rpcs, channel, method)
+            return unary_method_client(self, rpcs, channel, method,
+                                       self.default_unary_timeout_s)
 
         if method.type is Method.Type.SERVER_STREAMING:
-            return server_streaming_method_client(self, rpcs, channel, method)
+            return server_streaming_method_client(
+                self, rpcs, channel, method, self.default_stream_timeout_s)
 
         if method.type is Method.Type.CLIENT_STREAMING:
-            return ClientStreamingMethodClient(self, rpcs, channel, method)
+            return ClientStreamingMethodClient(self, rpcs, channel, method,
+                                               self.default_unary_timeout_s)
 
         if method.type is Method.Type.BIDIRECTIONAL_STREAMING:
-            return BidirectionalStreamingMethodClient(self, rpcs, channel,
-                                                      method)
+            return BidirectionalStreamingMethodClient(
+                self, rpcs, channel, method, self.default_stream_timeout_s)
 
         raise AssertionError(f'Unknown method type {method.type}')