pw_unit_test: Handle bad first response; timeouts

- Check that the first message is a "test_case_start" message. This
  prevents undefined variable errors if the first message is dropped.
- Support providing a timeout to use for the pw.unit_test.Run RPC.

Change-Id: I3f1df140cb07097969c6d7cf79b9c30a21172458
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/32140
Reviewed-by: Ewout van Bekkum <ewout@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
Commit-Queue: Wyatt Hepler <hepler@google.com>
diff --git a/pw_rpc/py/pw_rpc/callback_client.py b/pw_rpc/py/pw_rpc/callback_client.py
index ead480b..a332458 100644
--- a/pw_rpc/py/pw_rpc/callback_client.py
+++ b/pw_rpc/py/pw_rpc/callback_client.py
@@ -56,12 +56,12 @@
 _LOG = logging.getLogger(__name__)
 
 
-class _UseDefault(enum.Enum):
+class UseDefault(enum.Enum):
     """Marker for args that should use a default value, when None is valid."""
     VALUE = 0
 
 
-_OptionalTimeout = Union[_UseDefault, float, None]
+OptionalTimeout = Union[UseDefault, float, None]
 
 Callback = Callable[[client.PendingRpc, Optional[Status], Any], Any]
 
@@ -162,12 +162,12 @@
     """Used to iterate over a queue.SimpleQueue."""
     def __init__(self, method_client: _MethodClient,
                  responses: queue.SimpleQueue,
-                 default_timeout_s: _OptionalTimeout):
+                 default_timeout_s: OptionalTimeout):
         self._method_client = method_client
         self._queue = responses
         self.status: Optional[Status] = None
 
-        if default_timeout_s is _UseDefault.VALUE:
+        if default_timeout_s is UseDefault.VALUE:
             self.default_timeout_s = self._method_client.default_timeout_s
         else:
             self.default_timeout_s = default_timeout_s
@@ -179,13 +179,13 @@
     def responses(self,
                   *,
                   block: bool = True,
-                  timeout_s: _OptionalTimeout = _UseDefault.VALUE):
+                  timeout_s: OptionalTimeout = UseDefault.VALUE):
         """Returns an iterator of stream responses.
 
         Args:
           timeout_s: timeout in seconds; None blocks indefinitely
         """
-        if timeout_s is _UseDefault.VALUE:
+        if timeout_s is UseDefault.VALUE:
             timeout_s = self.default_timeout_s
 
         try:
@@ -258,7 +258,7 @@
     def call(self: _MethodClient,
              _rpc_request_proto=None,
              *,
-             pw_rpc_timeout_s=_UseDefault.VALUE,
+             pw_rpc_timeout_s=UseDefault.VALUE,
              **request_fields) -> UnaryResponse:
         responses: queue.SimpleQueue = queue.SimpleQueue()
 
@@ -268,7 +268,7 @@
 
         self.reinvoke(enqueue_response, _rpc_request_proto, **request_fields)
 
-        if pw_rpc_timeout_s is _UseDefault.VALUE:
+        if pw_rpc_timeout_s is UseDefault.VALUE:
             pw_rpc_timeout_s = self.default_timeout_s
 
         try:
@@ -297,7 +297,7 @@
     def call(self: _MethodClient,
              _rpc_request_proto=None,
              *,
-             pw_rpc_timeout_s=_UseDefault.VALUE,
+             pw_rpc_timeout_s=UseDefault.VALUE,
              **request_fields) -> StreamingResponses:
         responses: queue.SimpleQueue = queue.SimpleQueue()
         self.reinvoke(
diff --git a/pw_unit_test/py/pw_unit_test/rpc.py b/pw_unit_test/py/pw_unit_test/rpc.py
index c2ec22b..5ec0cfd 100644
--- a/pw_unit_test/py/pw_unit_test/rpc.py
+++ b/pw_unit_test/py/pw_unit_test/rpc.py
@@ -19,6 +19,7 @@
 from typing import Iterable
 
 import pw_rpc.client
+from pw_rpc.callback_client import OptionalTimeout, UseDefault
 from pw_unit_test_proto import unit_test_pb2
 
 _LOG = logging.getLogger(__name__)
@@ -110,11 +111,11 @@
         log('        Actual: %s', expectation.evaluated_expression)
 
 
-def run_tests(
-    rpcs: pw_rpc.client.Services,
-    report_passed_expectations: bool = False,
-    event_handlers: Iterable[EventHandler] = (LoggingEventHandler(), )
-) -> bool:
+def run_tests(rpcs: pw_rpc.client.Services,
+              report_passed_expectations: bool = False,
+              event_handlers: Iterable[EventHandler] = (
+                  LoggingEventHandler(), ),
+              timeout_s: OptionalTimeout = UseDefault.VALUE) -> bool:
     """Runs unit tests on a device over Pigweed RPC.
 
     Calls each of the provided event handlers as test events occur, and returns
@@ -122,15 +123,27 @@
     """
     unit_test_service = rpcs.pw.unit_test.UnitTest  # type: ignore[attr-defined]
 
-    all_tests_passed = False
-    for response in unit_test_service.Run(
-            report_passed_expectations=report_passed_expectations):
-        if response.HasField('test_case_start'):
-            raw_test_case = response.test_case_start
-            current_test_case = TestCase(raw_test_case.suite_name,
-                                         raw_test_case.test_name,
-                                         raw_test_case.file_name)
+    test_responses = iter(
+        unit_test_service.Run(
+            report_passed_expectations=report_passed_expectations,
+            timeout_s=timeout_s))
 
+    # Read the first response, which must be a test_case_start message.
+    first_response = next(test_responses)
+    if not first_response.HasField('test_case_start'):
+        raise ValueError(
+            'Expected a "test_case_start" response from pw.unit_test.Run, '
+            'but received a different message type. A response may have been '
+            'dropped.')
+
+    raw_test_case = first_response.test_case_start
+    current_test_case = TestCase(raw_test_case.suite_name,
+                                 raw_test_case.test_name,
+                                 raw_test_case.file_name)
+
+    all_tests_passed = False
+
+    for response in test_responses:
         for event_handler in event_handlers:
             if response.HasField('test_run_start'):
                 event_handler.run_all_tests_start()