| # Copyright 2022 The Pigweed Authors |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); you may not |
| # use this file except in compliance with the License. You may obtain a copy of |
| # the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| # License for the specific language governing permissions and limitations under |
| # the License. |
| """Classes for read and write transfers.""" |
| |
| import abc |
| import asyncio |
| from dataclasses import dataclass |
| import logging |
| import math |
| import threading |
| from typing import Any, Callable, Optional |
| |
| from pw_status import Status |
| try: |
| from pw_transfer import transfer_pb2 |
| except ImportError: |
| # For the bazel build, which puts generated protos in a different location. |
| from pigweed.pw_transfer import transfer_pb2 # type: ignore |
| |
# Logger used by every transfer class in this module, named after the package.
_LOG = logging.getLogger(__package__)
| |
| |
@dataclass(frozen=True)
class ProgressStats:
    """Immutable snapshot of a transfer's progress.

    Attributes:
        bytes_sent: Bytes transmitted so far by the side sending data.
        bytes_confirmed_received: Bytes confirmed as received so far.
        total_size_bytes: Total size of the transfer, or None if unknown.
    """
    bytes_sent: int
    bytes_confirmed_received: int
    total_size_bytes: Optional[int]

    def percent_received(self) -> float:
        """Returns the confirmed-received percentage of the total size.

        Returns NaN when the total size is unknown (None) or zero: the
        percentage is undefined in both cases, and dividing by a zero total
        (e.g. for an empty write transfer) would raise ZeroDivisionError.
        """
        if not self.total_size_bytes:
            return math.nan

        return self.bytes_confirmed_received / self.total_size_bytes * 100

    def __str__(self) -> str:
        # Distinguish an unknown size (None) from a genuinely empty one (0).
        if self.total_size_bytes is None:
            total = 'unknown'
        else:
            total = str(self.total_size_bytes)
        return (f'{self.percent_received():5.1f}% ({self.bytes_sent} B sent, '
                f'{self.bytes_confirmed_received} B received of {total} B)')
| |
| |
# Type of the optional progress callback accepted by transfers. It is called
# with a ProgressStats snapshot each time progress is reported; its return
# value is ignored.
ProgressCallback = Callable[[ProgressStats], Any]
| |
| |
class _Timer:
    """A timer which invokes a callback after a certain timeout.

    The timer runs as an asyncio task, so start() must be called while an
    event loop is running.
    """
    def __init__(self, timeout_s: float, callback: Callable[[], Any]):
        # Default timeout used when start() is called without an override.
        self.timeout_s = timeout_s
        self._callback = callback
        # The pending sleep task, or None when no timer is armed.
        self._task: Optional[asyncio.Task[Any]] = None

    def start(self, timeout_s: Optional[float] = None) -> None:
        """Starts a new timer.

        If a timer is already running, it is stopped and a new timer started.
        This can be used to implement watchdog-like behavior, where a callback
        is invoked after some time without a kick.

        Args:
            timeout_s: Timeout for this run only; defaults to self.timeout_s.
        """
        self.stop()
        timeout_s = self.timeout_s if timeout_s is None else timeout_s
        self._task = asyncio.create_task(self._run(timeout_s))

    def stop(self) -> None:
        """Terminates a running timer; does nothing if none is running."""
        if self._task is not None:
            self._task.cancel()
            self._task = None

    async def _run(self, timeout_s: float) -> None:
        await asyncio.sleep(timeout_s)
        # Clear the task before invoking the callback so that a callback which
        # restarts the timer does not cancel the task it is running in.
        self._task = None
        self._callback()
| |
| |
class Transfer(abc.ABC):
    """A client-side data transfer through a Manager.

    Subclasses are responsible for implementing all of the logic for their type
    of transfer, receiving messages from the server and sending the appropriate
    messages in response.
    """
    def __init__(self,
                 session_id: int,
                 send_chunk: Callable[[transfer_pb2.Chunk], None],
                 end_transfer: Callable[['Transfer'], None],
                 response_timeout_s: float,
                 initial_response_timeout_s: float,
                 max_retries: int,
                 progress_callback: Optional[ProgressCallback] = None):
        """Sets up the state shared by all transfer types.

        Args:
            session_id: Identifier of this transfer, used as the transfer_id
                of outgoing chunks.
            send_chunk: Callback that sends a chunk to the server.
            end_transfer: Callback invoked with this transfer when it finishes,
                unless finish() is called with skip_callback=True.
            response_timeout_s: Seconds to wait for a server chunk before
                retrying.
            initial_response_timeout_s: Seconds to wait for the server's reply
                to the initial chunk sent by begin().
            max_retries: Number of consecutive timeouts tolerated before the
                transfer fails with DEADLINE_EXCEEDED.
            progress_callback: Optional callback that receives ProgressStats.
        """
        self.id = session_id
        self.status = Status.OK
        # Set as the very last step of finish(), after cleanup is complete.
        self.done = threading.Event()

        self._send_chunk = send_chunk
        self._end_transfer = end_transfer

        # Consecutive timeouts since the last chunk received from the server.
        self._retries = 0
        self._max_retries = max_retries
        self._response_timer = _Timer(response_timeout_s, self._on_timeout)
        self._initial_response_timeout_s = initial_response_timeout_s

        self._progress_callback = progress_callback

    async def begin(self) -> None:
        """Sends the initial chunk of the transfer and arms the timeout."""
        self._send_chunk(self._initial_chunk())
        self._response_timer.start(self._initial_response_timeout_s)

    @property
    @abc.abstractmethod
    def data(self) -> bytes:
        """Returns the data read or written in this transfer."""

    @abc.abstractmethod
    def _initial_chunk(self) -> transfer_pb2.Chunk:
        """Returns the initial chunk to notify the server of the transfer."""

    async def handle_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Processes an incoming chunk from the server.

        Handles terminating chunks (i.e. those with a status) and forwards
        non-terminating chunks to _handle_data_chunk.
        """
        self._response_timer.stop()
        self._retries = 0  # Received data from service, so reset the retries.

        _LOG.debug('Received chunk\n%s', str(chunk).rstrip())

        # Status chunks are only used to terminate a transfer. They do not
        # contain any data that requires processing.
        if chunk.HasField('status'):
            self.finish(Status(chunk.status))
            return

        await self._handle_data_chunk(chunk)

        # Handling the chunk may have completed or aborted the transfer (e.g.
        # the final read chunk, or an error sent back to the server). Only wait
        # for a server response while the transfer is still active; otherwise
        # a stale timer task would be left running after finish().
        if not self.done.is_set():
            self._response_timer.start()

    @abc.abstractmethod
    async def _handle_data_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Handles a chunk that contains or requests data."""

    @abc.abstractmethod
    def _retry_after_timeout(self) -> None:
        """Retries after a timeout occurs."""

    def _on_timeout(self) -> None:
        """Handles a timeout while waiting for a chunk.

        Does nothing if the transfer is already done. Otherwise either fails
        the transfer with DEADLINE_EXCEEDED once max_retries is exceeded, or
        retries via _retry_after_timeout() and restarts the timer.
        """
        if self.done.is_set():
            return

        self._retries += 1
        if self._retries > self._max_retries:
            self.finish(Status.DEADLINE_EXCEEDED)
            return

        _LOG.debug('Received no responses for %.3fs; retrying %d/%d',
                   self._response_timer.timeout_s, self._retries,
                   self._max_retries)
        self._retry_after_timeout()
        self._response_timer.start()

    def finish(self, status: Status, skip_callback: bool = False) -> None:
        """Ends the transfer with the specified status.

        On success, first reports progress as complete. Then invokes the
        end_transfer callback unless skip_callback is True, and finally sets
        the done event.
        """
        self._response_timer.stop()
        self.status = status

        if status.ok():
            total_size = len(self.data)
            self._update_progress(total_size, total_size, total_size)

        if not skip_callback:
            self._end_transfer(self)

        # Set done last so that the transfer has been fully cleaned up.
        self.done.set()

    def _update_progress(self, bytes_sent: int, bytes_confirmed_received: int,
                         total_size_bytes: Optional[int]) -> None:
        """Invokes the provided progress callback, if any, with the progress."""

        stats = ProgressStats(bytes_sent, bytes_confirmed_received,
                              total_size_bytes)
        _LOG.debug('Transfer %d progress: %s', self.id, stats)

        if self._progress_callback:
            self._progress_callback(stats)

    def _send_error(self, error: Status) -> None:
        """Sends an error chunk to the server and finishes the transfer."""
        self._send_chunk(
            transfer_pb2.Chunk(transfer_id=self.id,
                               status=error.value,
                               type=transfer_pb2.Chunk.Type.COMPLETION))
        self.finish(error)
| |
| |
class WriteTransfer(Transfer):
    """A client -> server write transfer.

    The server drives the transfer with transfer parameters chunks; the client
    answers each one by sending DATA chunks covering the requested window.
    """
    def __init__(
        self,
        session_id: int,
        data: bytes,
        send_chunk: Callable[[transfer_pb2.Chunk], None],
        end_transfer: Callable[[Transfer], None],
        response_timeout_s: float,
        initial_response_timeout_s: float,
        max_retries: int,
        progress_callback: Optional[ProgressCallback] = None,
    ):
        super().__init__(session_id, send_chunk, end_transfer,
                         response_timeout_s, initial_response_timeout_s,
                         max_retries, progress_callback)
        self._data = data

        # Guard this class with a lock since a transfer parameters update might
        # arrive while responding to a prior update.
        self._lock = asyncio.Lock()
        # Offset of the next byte of data to send.
        self._offset = 0
        # Exclusive end offset of the window the server has requested.
        self._window_end_offset = 0
        # Largest data payload per chunk; 0 until the server provides one.
        self._max_chunk_size = 0
        # Delay between data chunks requested by the server, if any.
        self._chunk_delay_us: Optional[int] = None

        # The window ID increments for each parameters update.
        self._window_id = 0

        # Chunk that _retry_after_timeout() resends.
        self._last_chunk = self._initial_chunk()

    @property
    def data(self) -> bytes:
        """Returns the data being written."""
        return self._data

    def _initial_chunk(self) -> transfer_pb2.Chunk:
        # TODO(frolv): session_id should not be set here, but assigned by the
        # server during an initial handshake.
        return transfer_pb2.Chunk(transfer_id=self.id,
                                  resource_id=self.id,
                                  type=transfer_pb2.Chunk.Type.START)

    async def _handle_data_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Processes an incoming chunk from the server.

        In a write transfer, the server only sends transfer parameter updates
        to the client. When a message is received, update local parameters and
        send data accordingly.

        Each update starts a new window ID; a send loop belonging to an older
        window stops before sending its next chunk.
        """

        async with self._lock:
            self._window_id += 1
            window_id = self._window_id

            if not self._handle_parameters_update(chunk):
                return

        bytes_acknowledged = chunk.offset

        while True:
            if self._chunk_delay_us:
                await asyncio.sleep(self._chunk_delay_us / 1e6)

            async with self._lock:
                if self.done.is_set():
                    return

                if window_id != self._window_id:
                    _LOG.debug('Transfer %d: Skipping stale window', self.id)
                    return

                write_chunk = self._next_chunk()
                self._offset += len(write_chunk.data)
                # Use >= rather than ==: if the window ends before the current
                # offset, equality is never reached and this loop would send
                # empty chunks forever.
                sent_requested_bytes = self._offset >= self._window_end_offset

                self._send_chunk(write_chunk)

                self._update_progress(self._offset, bytes_acknowledged,
                                      len(self.data))

                if sent_requested_bytes:
                    break

        self._last_chunk = write_chunk

    def _handle_parameters_update(self, chunk: transfer_pb2.Chunk) -> bool:
        """Updates transfer state based on a transfer parameters update.

        Returns:
            True if data should be sent for the new window. False if the
            update was invalid; in that case an error chunk has been sent to
            the server and the transfer has been finished.
        """

        retransmit = True
        if chunk.HasField('type'):
            retransmit = (chunk.type
                          == transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT
                          or chunk.type == transfer_pb2.Chunk.Type.START)

        if chunk.offset > len(self.data):
            # Bad offset; terminate the transfer.
            _LOG.error(
                'Transfer %d: server requested invalid offset %d (size %d)',
                self.id, chunk.offset, len(self.data))

            self._send_error(Status.OUT_OF_RANGE)
            return False

        if chunk.pending_bytes == 0:
            _LOG.error(
                'Transfer %d: service requested 0 bytes (invalid); aborting',
                self.id)
            self._send_error(Status.INTERNAL)
            return False

        if retransmit:
            # Check whether the client has sent a previous data offset, which
            # indicates that some chunks were lost in transmission.
            if chunk.offset < self._offset:
                _LOG.debug('Write transfer %d rolling back: offset %d from %d',
                           self.id, chunk.offset, self._offset)

            self._offset = chunk.offset

            # Retransmit is the default behavior for older versions of the
            # transfer protocol. The window_end_offset field is not guaranteed
            # to be set in these version, so it must be calculated.
            max_bytes_to_send = min(chunk.pending_bytes,
                                    len(self.data) - self._offset)
            self._window_end_offset = self._offset + max_bytes_to_send
        else:
            # Only PARAMETERS_CONTINUE is valid here. The chunk comes from the
            # server, so validate it explicitly instead of with an assert,
            # which is stripped under -O and would otherwise escape as an
            # AssertionError.
            if chunk.type != transfer_pb2.Chunk.Type.PARAMETERS_CONTINUE:
                _LOG.error(
                    'Transfer %d: unexpected chunk type %d in write transfer; '
                    'aborting', self.id, chunk.type)
                self._send_error(Status.INTERNAL)
                return False

            # Extend the window to the new end offset specified by the server.
            self._window_end_offset = min(chunk.window_end_offset,
                                          len(self.data))

        if chunk.HasField('max_chunk_size_bytes'):
            self._max_chunk_size = chunk.max_chunk_size_bytes

        if chunk.HasField('min_delay_microseconds'):
            self._chunk_delay_us = chunk.min_delay_microseconds

        if (self._max_chunk_size == 0
                and self._offset < self._window_end_offset):
            # Without a nonzero chunk size no data can ever be sent, so the
            # send loop could never reach the end of the requested window.
            _LOG.error(
                'Transfer %d: server requested data without a max chunk '
                'size; aborting', self.id)
            self._send_error(Status.INTERNAL)
            return False

        return True

    def _retry_after_timeout(self) -> None:
        self._send_chunk(self._last_chunk)

    def _next_chunk(self) -> transfer_pb2.Chunk:
        """Returns the next Chunk message to send in the data transfer.

        The chunk starts at the current offset and holds at most
        max_chunk_size bytes without extending past the window end. It is
        marked final (remaining_bytes = 0) when it reaches the end of data.
        """
        chunk = transfer_pb2.Chunk(transfer_id=self.id,
                                   offset=self._offset,
                                   type=transfer_pb2.Chunk.Type.DATA)
        # Clamp at zero so that a window ending behind the current offset
        # yields an empty chunk instead of a negative length.
        window_remaining = self._window_end_offset - self._offset
        max_bytes_in_chunk = max(0, min(self._max_chunk_size,
                                        window_remaining))

        chunk.data = self.data[self._offset:self._offset + max_bytes_in_chunk]

        # Mark the final chunk of the transfer.
        if len(self.data) - self._offset <= max_bytes_in_chunk:
            chunk.remaining_bytes = 0

        return chunk
| |
| |
class ReadTransfer(Transfer):
    """A client <- server read transfer.

    Although Python can effectively handle an unlimited transfer window, this
    client sets a conservative window and chunk size to avoid overloading the
    device. These are configurable in the constructor.
    """

    # The fractional position within a window at which a receive transfer should
    # extend its window size to minimize the amount of time the transmitter
    # spends blocked.
    #
    # The window is extended once the unreceived part of it is at most
    # 1/EXTEND_WINDOW_DIVISOR of the full window. For example, a divisor of 2
    # extends the window when half of the requested data has been received,
    # and a divisor of 3 extends it when only a third of the window remains.
    EXTEND_WINDOW_DIVISOR = 2

    def __init__(  # pylint: disable=too-many-arguments
            self,
            session_id: int,
            send_chunk: Callable[[transfer_pb2.Chunk], None],
            end_transfer: Callable[[Transfer], None],
            response_timeout_s: float,
            initial_response_timeout_s: float,
            max_retries: int,
            max_bytes_to_receive: int = 8192,
            max_chunk_size: int = 1024,
            chunk_delay_us: Optional[int] = None,
            progress_callback: Optional[ProgressCallback] = None):
        """Sets up a read transfer.

        Args (in addition to those of Transfer):
            max_bytes_to_receive: Size of each window requested from the
                server.
            max_chunk_size: Value sent to the server as max_chunk_size_bytes.
            chunk_delay_us: If set, sent to the server as
                min_delay_microseconds.
        """
        super().__init__(session_id, send_chunk, end_transfer,
                         response_timeout_s, initial_response_timeout_s,
                         max_retries, progress_callback)
        self._max_bytes_to_receive = max_bytes_to_receive
        self._max_chunk_size = max_chunk_size
        self._chunk_delay_us = chunk_delay_us

        # Remaining size reported by the server, or None while unknown.
        self._remaining_transfer_size: Optional[int] = None
        self._data = bytearray()
        # Offset of the next byte expected from the server.
        self._offset = 0
        # Bytes still expected before the current window is exhausted.
        self._pending_bytes = max_bytes_to_receive
        # Exclusive end offset of the current receive window.
        self._window_end_offset = max_bytes_to_receive

    @property
    def data(self) -> bytes:
        """Returns an immutable copy of the data that has been read."""
        return bytes(self._data)

    def _initial_chunk(self) -> transfer_pb2.Chunk:
        return self._transfer_parameters(transfer_pb2.Chunk.Type.START)

    async def _handle_data_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Processes an incoming chunk from the server.

        In a read transfer, the client receives data chunks from the server.
        Once all pending data is received, the transfer parameters are updated.
        """

        if chunk.offset != self._offset:
            # Initially, the transfer service only supports in-order transfers.
            # If data is received out of order, request that the server
            # retransmit from the previous offset.
            self._send_chunk(
                self._transfer_parameters(
                    transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
            return

        self._data += chunk.data
        self._pending_bytes -= len(chunk.data)
        self._offset += len(chunk.data)

        if chunk.HasField('remaining_bytes'):
            if chunk.remaining_bytes == 0:
                # No more data to read. Acknowledge receipt and finish.
                self._send_chunk(
                    transfer_pb2.Chunk(
                        transfer_id=self.id,
                        status=Status.OK.value,
                        type=transfer_pb2.Chunk.Type.COMPLETION))
                self.finish(Status.OK)
                return

            # The server may indicate if the amount of remaining data is known.
            self._remaining_transfer_size = chunk.remaining_bytes
        elif self._remaining_transfer_size is not None:
            # Update the remaining transfer size, if it is known.
            self._remaining_transfer_size -= len(chunk.data)

            # If the transfer size drops to zero, the estimate was inaccurate.
            if self._remaining_transfer_size <= 0:
                self._remaining_transfer_size = None

        total_size = None if self._remaining_transfer_size is None else (
            self._remaining_transfer_size + self._offset)
        self._update_progress(self._offset, self._offset, total_size)

        if chunk.window_end_offset != 0:
            if chunk.window_end_offset < self._offset:
                _LOG.error(
                    'Transfer %d: transmitter sent invalid earlier end offset '
                    '%d (receiver offset %d)', self.id,
                    chunk.window_end_offset, self._offset)
                self._send_error(Status.INTERNAL)
                return

            if chunk.window_end_offset > self._window_end_offset:
                _LOG.error(
                    'Transfer %d: transmitter sent invalid later end offset '
                    '%d (receiver end offset %d)', self.id,
                    chunk.window_end_offset, self._window_end_offset)
                self._send_error(Status.INTERNAL)
                return

            # The transmitter shrank (or restated) the window. The bytes still
            # pending are exactly those up to the new window end. Recompute
            # rather than subtract: subtracting would zero the count whenever
            # the window is unchanged and trigger a spurious retransmit.
            self._window_end_offset = chunk.window_end_offset
            self._pending_bytes = self._window_end_offset - self._offset

        remaining_window_size = self._window_end_offset - self._offset
        extend_window = (remaining_window_size <= self._max_bytes_to_receive /
                         ReadTransfer.EXTEND_WINDOW_DIVISOR)

        if self._pending_bytes == 0:
            # All pending data was received. Send out a new parameters chunk for
            # the next block.
            self._send_chunk(
                self._transfer_parameters(
                    transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
        elif extend_window:
            self._send_chunk(
                self._transfer_parameters(
                    transfer_pb2.Chunk.Type.PARAMETERS_CONTINUE))

    def _retry_after_timeout(self) -> None:
        self._send_chunk(
            self._transfer_parameters(
                transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))

    def _transfer_parameters(self, chunk_type: Any) -> transfer_pb2.Chunk:
        """Builds a transfer parameters chunk of the given type.

        As a side effect, starts a new full-size window at the current offset:
        resets the pending byte count and the window end offset. The chunk is
        returned, not sent; callers pass it to _send_chunk.
        """

        self._pending_bytes = self._max_bytes_to_receive
        self._window_end_offset = self._offset + self._max_bytes_to_receive

        chunk = transfer_pb2.Chunk(transfer_id=self.id,
                                   pending_bytes=self._pending_bytes,
                                   window_end_offset=self._window_end_offset,
                                   max_chunk_size_bytes=self._max_chunk_size,
                                   offset=self._offset,
                                   type=chunk_type)

        if self._chunk_delay_us:
            chunk.min_delay_microseconds = self._chunk_delay_us

        return chunk