blob: 3ed3882056ea346c1bbcda9e408684b9a9488def [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""device module unit tests"""
from contextlib import contextmanager
import logging
import queue
import threading
import time
import unittest
from pw_hdlc.rpc import RpcClient, HdlcRpcClient, CancellableReader
class QueueFile:
    """A fake file object backed by a queue for testing.

    The test (the "operator") queues bytes, exceptions, or EOF; the code
    under test (the "consumer") pulls them out with read(), which blocks
    until an item is available. A queued ``bytes`` item is buffered and may
    be handed out over several read() calls, but a single read() never
    combines data from two queued items. write() discards its input.
    """

    # Marker object queued by set_read_eof() to signal end of stream.
    EOF = object()

    def __init__(self):
        # Operator puts; consumer gets
        self._q = queue.Queue()

        # Consumer side access only!
        # Unread remainder of the most recently dequeued bytes item.
        self._readbuf = b''
        # Set once the EOF marker has been dequeued; later reads return b''.
        self._eof = False

    ###############
    # Consumer side

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()

    def _read_from_buf(self, size: int) -> bytes:
        """Removes and returns up to ``size`` bytes from the read buffer."""
        data = self._readbuf[:size]
        self._readbuf = self._readbuf[size:]
        return data

    def read(self, size: int = 1) -> bytes:
        """Reads up to ``size`` bytes.

        Buffered data left over from a previously dequeued item is returned
        first, without blocking. Otherwise this blocks on the queue until the
        operator supplies data, an exception, or EOF.

        Args:
            size: Maximum number of bytes to return.

        Returns:
            Between 0 and ``size`` bytes. An empty result means ``size`` was
            0, EOF has been reached, or the operator queued empty data.

        Raises:
            Exception: The exception queued via cause_read_exc() (close()
                queues one too), or if an unrecognized item is dequeued.
        """
        # A zero-byte read is trivially satisfied: it must neither trip the
        # buffer invariant below nor block waiting on the queue.
        if size == 0:
            return b''

        # First try to satisfy the read from buffered data.
        data = self._read_from_buf(size)
        if data:
            return data

        # No more data in the buffer
        assert not self._readbuf

        if self._eof:
            return b''

        # Nothing buffered; block on the queue
        item = self._q.get()
        # NOTE: We can't call Queue.task_done() here because the reader hasn't
        # actually *acted* on the read item yet.

        # Queued data
        if isinstance(item, bytes):
            self._readbuf = item
            return self._read_from_buf(size)

        # Queued exception
        if isinstance(item, Exception):
            raise item

        # Report EOF
        if item is self.EOF:
            self._eof = True
            return b''

        raise Exception('unexpected item type')

    def write(self, data: bytes) -> None:
        """Discards ``data``; writes are not recorded by this fake."""

    #####################
    # Weird middle ground

    # It is a violation of most file-like object APIs for one thread to call
    # close() while another thread is calling read(). The behavior is
    # undefined.
    #
    # - On Linux, close() may wake up a select(), leaving the caller with a bad
    #   file descriptor (which could get reused!)
    # - Or the read() could continue to block indefinitely.
    #
    # We choose to cause a subsequent/parallel read to receive an exception.
    def close(self) -> None:
        """Queues an Exception('closed') so a pending or later read() raises."""
        self.cause_read_exc(Exception('closed'))

    ###############
    # Operator side

    def put_read_data(self, data: bytes) -> None:
        """Queues ``data`` to be returned by subsequent read() calls."""
        self._q.put(data)

    def cause_read_exc(self, exc: Exception) -> None:
        """Queues ``exc`` to be raised by the read() that dequeues it."""
        self._q.put(exc)

    def set_read_eof(self) -> None:
        """Queues the EOF marker; reads after it is dequeued return b''."""
        self._q.put(self.EOF)

    def wait_for_drain(self, timeout=None) -> None:
        """Wait for the queue to drain (be fully consumed).

        "Drained" means every queued item has been dequeued by read(); the
        consumer may still be acting on the last item when this returns.

        Args:
            timeout: The maximum time (in seconds) to wait, or wait forever
                if None.

        Raises:
            TimeoutError: If timeout is given and has elapsed.
        """
        # It would be great to use Queue.join() here, but that requires the
        # consumer to call Queue.task_done(), and we can't do that because
        # the consumer of read() doesn't know anything about it.
        # Instead, we poll. ¯\_(ツ)_/¯
        #
        # A monotonic clock is used so wall-clock adjustments (NTP, manual
        # changes) cannot cause spurious or missed timeouts.
        start_time = time.monotonic()
        while not self._q.empty():
            if timeout is not None:
                elapsed = time.monotonic() - start_time
                if elapsed > timeout:
                    raise TimeoutError(f"Queue not empty after {elapsed} sec")
            time.sleep(0.1)
class QueueFileTest(unittest.TestCase):
    """Unit tests for the QueueFile fake."""

    @staticmethod
    def _loaded_file(*chunks: bytes) -> 'QueueFile':
        """Returns a new QueueFile with each of ``chunks`` queued for reading."""
        qfile = QueueFile()
        for chunk in chunks:
            qfile.put_read_data(chunk)
        return qfile

    def test_read_data(self) -> None:
        qfile = self._loaded_file(b'hello')
        self.assertEqual(qfile.read(5), b'hello')

    def test_read_data_multi_read(self) -> None:
        # One queued chunk can be consumed across several reads.
        qfile = self._loaded_file(b'helloworld')
        self.assertEqual(qfile.read(5), b'hello')
        self.assertEqual(qfile.read(5), b'world')

    def test_read_data_multi_put(self) -> None:
        # Separately queued chunks come back one per read.
        qfile = self._loaded_file(b'hello', b'world')
        self.assertEqual(qfile.read(5), b'hello')
        self.assertEqual(qfile.read(5), b'world')

    def test_read_eof(self) -> None:
        qfile = QueueFile()
        qfile.set_read_eof()
        self.assertEqual(qfile.read(5), b'')

    def test_read_exception(self) -> None:
        qfile = QueueFile()
        msg = 'test exception'
        qfile.cause_read_exc(ValueError(msg))
        with self.assertRaisesRegex(ValueError, msg):
            qfile.read(5)

    def test_wait_for_drain_works(self) -> None:
        qfile = self._loaded_file(b'hello')
        qfile.read()
        # The queue is already empty, so this returns at once; the timeout
        # value is arbitrary.
        try:
            qfile.wait_for_drain(0.1)
        except TimeoutError:
            self.fail("wait_for_drain raised TimeoutError")

    def test_wait_for_drain_raises(self) -> None:
        # Nothing ever reads, so the queue never drains and any timeout
        # expires.
        qfile = self._loaded_file(b'hello')
        with self.assertRaises(TimeoutError):
            qfile.wait_for_drain(0.1)
class Sentinel:
    """Opaque marker object whose repr is the fixed string 'Sentinel'.

    Because object.__str__ falls back to __repr__, str() yields the same text,
    which keeps it readable when used as a log message.
    """

    def __repr__(self) -> str:
        return 'Sentinel'
class _QueueReader(CancellableReader):
    """CancellableReader whose cancellation closes the wrapped object."""

    def cancel_read(self) -> None:
        # In these tests the wrapped object is a QueueFile (see _get_client);
        # its close() queues an exception that wakes a reader blocked in read().
        wrapped = self._base_obj
        wrapped.close()
def _get_client(file) -> RpcClient:
    """Builds an HdlcRpcClient that reads from ``file`` through _QueueReader.

    No proto paths/modules and no channels are registered.
    """
    reader = _QueueReader(file)
    return HdlcRpcClient(reader, paths_or_modules=[], channels=[])
# This should take <10ms but we'll wait up to 1000x longer.
_QUEUE_DRAIN_TIMEOUT = 10.0
class HdlcRpcClientTest(unittest.TestCase):
    """Tests the pw_hdlc.rpc.HdlcRpcClient class."""

    # Logger whose ERROR-level records these tests inspect.
    _LOGGER_NAME = 'pw_hdlc.rpc'

    # NOTE: There is no test here for stream EOF because Serial.read()
    # can return an empty result if configured with timeout != None.
    # The reader thread will continue in this case.

    def test_clean_close_after_stream_close(self) -> None:
        """Assert RpcClient closes cleanly when stream closes."""
        # See b/293595266.
        qfile = QueueFile()
        with self.assert_no_hdlc_rpc_error_logs(), qfile, _get_client(qfile):
            # Feed an empty chunk and wait for it to be consumed, so the
            # reader thread is known to be blocked in read() rather than
            # having exited immediately.
            qfile.put_read_data(b'')
            qfile.wait_for_drain(_QUEUE_DRAIN_TIMEOUT)
        # Leaving the client context calls stop() on the reader thread while
        # it is blocked in qfile.read(). Closing the QueueFile makes that
        # read() raise (an implementation choice of QueueFile); the reader is
        # expected to exit without logging it.
        self.assert_no_background_threads_running()

    def test_device_handles_read_exception(self) -> None:
        """Assert RpcClient closes cleanly when read raises an exception."""
        # See b/293595266.
        qfile = QueueFile()
        logger = logging.getLogger(self._LOGGER_NAME)
        boom = Exception('boom')

        with self.assertLogs(
            logger, level=logging.ERROR
        ) as captured, _get_client(qfile):
            # Make read() raise; the reader is expected to log the
            # exception and exit right away.
            qfile.cause_read_exc(boom)
            qfile.wait_for_drain(_QUEUE_DRAIN_TIMEOUT)

        # Exactly one ERROR record, carrying the injected exception.
        self.assertEqual(len(captured.records), 1)
        record = captured.records[0]
        self.assertIsNotNone(record.exc_info)
        assert record.exc_info is not None  # for mypy
        self.assertEqual(record.exc_info[1], boom)

        self.assert_no_background_threads_running()

    @contextmanager
    def assert_no_hdlc_rpc_error_logs(self):
        """Asserts the managed body logs no ERROR records to pw_hdlc.rpc.

        Yields the assertLogs() capture context.
        """
        logger = logging.getLogger(self._LOGGER_NAME)
        marker = Sentinel()
        with self.assertLogs(logger, level=logging.ERROR) as captured:
            # TODO: b/294861320 - use assertNoLogs() in Python 3.10+
            # assertLogs() fails when nothing is logged, and
            # TestCase.assertNoLogs() only exists from Python 3.10. So emit
            # one known record up front and afterwards check that it is the
            # only record captured.
            logger.error(marker)
            yield captured
        logged = [rec.msg for rec in captured.records]
        self.assertEqual(logged, [marker])

    def assert_no_background_threads_running(self):
        """Asserts the calling thread is the only live thread."""
        live_threads = threading.enumerate()
        self.assertEqual(live_threads, [threading.current_thread()])
if __name__ == '__main__':
    # Executed directly as a script: discover and run this module's tests.
    unittest.main()