pw_unit_test: Handle bad first response; timeouts
- Check that the first message is a "test_case_start" message. This
prevents undefined variable errors if the first message is dropped.
- Support providing a timeout to use for the pw.unit_test.Run RPC.
Change-Id: I3f1df140cb07097969c6d7cf79b9c30a21172458
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/32140
Reviewed-by: Ewout van Bekkum <ewout@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
Commit-Queue: Wyatt Hepler <hepler@google.com>
diff --git a/pw_rpc/py/pw_rpc/callback_client.py b/pw_rpc/py/pw_rpc/callback_client.py
index ead480b..a332458 100644
--- a/pw_rpc/py/pw_rpc/callback_client.py
+++ b/pw_rpc/py/pw_rpc/callback_client.py
@@ -56,12 +56,12 @@
_LOG = logging.getLogger(__name__)
-class _UseDefault(enum.Enum):
+class UseDefault(enum.Enum):
"""Marker for args that should use a default value, when None is valid."""
VALUE = 0
-_OptionalTimeout = Union[_UseDefault, float, None]
+OptionalTimeout = Union[UseDefault, float, None]
Callback = Callable[[client.PendingRpc, Optional[Status], Any], Any]
@@ -162,12 +162,12 @@
"""Used to iterate over a queue.SimpleQueue."""
def __init__(self, method_client: _MethodClient,
responses: queue.SimpleQueue,
- default_timeout_s: _OptionalTimeout):
+ default_timeout_s: OptionalTimeout):
self._method_client = method_client
self._queue = responses
self.status: Optional[Status] = None
- if default_timeout_s is _UseDefault.VALUE:
+ if default_timeout_s is UseDefault.VALUE:
self.default_timeout_s = self._method_client.default_timeout_s
else:
self.default_timeout_s = default_timeout_s
@@ -179,13 +179,13 @@
def responses(self,
*,
block: bool = True,
- timeout_s: _OptionalTimeout = _UseDefault.VALUE):
+ timeout_s: OptionalTimeout = UseDefault.VALUE):
"""Returns an iterator of stream responses.
Args:
timeout_s: timeout in seconds; None blocks indefinitely
"""
- if timeout_s is _UseDefault.VALUE:
+ if timeout_s is UseDefault.VALUE:
timeout_s = self.default_timeout_s
try:
@@ -258,7 +258,7 @@
def call(self: _MethodClient,
_rpc_request_proto=None,
*,
- pw_rpc_timeout_s=_UseDefault.VALUE,
+ pw_rpc_timeout_s=UseDefault.VALUE,
**request_fields) -> UnaryResponse:
responses: queue.SimpleQueue = queue.SimpleQueue()
@@ -268,7 +268,7 @@
self.reinvoke(enqueue_response, _rpc_request_proto, **request_fields)
- if pw_rpc_timeout_s is _UseDefault.VALUE:
+ if pw_rpc_timeout_s is UseDefault.VALUE:
pw_rpc_timeout_s = self.default_timeout_s
try:
@@ -297,7 +297,7 @@
def call(self: _MethodClient,
_rpc_request_proto=None,
*,
- pw_rpc_timeout_s=_UseDefault.VALUE,
+ pw_rpc_timeout_s=UseDefault.VALUE,
**request_fields) -> StreamingResponses:
responses: queue.SimpleQueue = queue.SimpleQueue()
self.reinvoke(
diff --git a/pw_unit_test/py/pw_unit_test/rpc.py b/pw_unit_test/py/pw_unit_test/rpc.py
index c2ec22b..5ec0cfd 100644
--- a/pw_unit_test/py/pw_unit_test/rpc.py
+++ b/pw_unit_test/py/pw_unit_test/rpc.py
@@ -19,6 +19,7 @@
from typing import Iterable
import pw_rpc.client
+from pw_rpc.callback_client import OptionalTimeout, UseDefault
from pw_unit_test_proto import unit_test_pb2
_LOG = logging.getLogger(__name__)
@@ -110,11 +111,11 @@
log(' Actual: %s', expectation.evaluated_expression)
-def run_tests(
- rpcs: pw_rpc.client.Services,
- report_passed_expectations: bool = False,
- event_handlers: Iterable[EventHandler] = (LoggingEventHandler(), )
-) -> bool:
+def run_tests(rpcs: pw_rpc.client.Services,
+ report_passed_expectations: bool = False,
+ event_handlers: Iterable[EventHandler] = (
+ LoggingEventHandler(), ),
+ timeout_s: OptionalTimeout = UseDefault.VALUE) -> bool:
"""Runs unit tests on a device over Pigweed RPC.
Calls each of the provided event handlers as test events occur, and returns
@@ -122,15 +123,27 @@
"""
unit_test_service = rpcs.pw.unit_test.UnitTest # type: ignore[attr-defined]
- all_tests_passed = False
- for response in unit_test_service.Run(
- report_passed_expectations=report_passed_expectations):
- if response.HasField('test_case_start'):
- raw_test_case = response.test_case_start
- current_test_case = TestCase(raw_test_case.suite_name,
- raw_test_case.test_name,
- raw_test_case.file_name)
+ test_responses = iter(
+ unit_test_service.Run(
+ report_passed_expectations=report_passed_expectations,
+ timeout_s=timeout_s))
+ # Read the first response, which must be a test_case_start message.
+ first_response = next(test_responses)
+ if not first_response.HasField('test_case_start'):
+ raise ValueError(
+ 'Expected a "test_case_start" response from pw.unit_test.Run, '
+ 'but received a different message type. A response may have been '
+ 'dropped.')
+
+ raw_test_case = first_response.test_case_start
+ current_test_case = TestCase(raw_test_case.suite_name,
+ raw_test_case.test_name,
+ raw_test_case.file_name)
+
+ all_tests_passed = False
+
+ for response in test_responses:
for event_handler in event_handlers:
if response.HasField('test_run_start'):
event_handler.run_all_tests_start()