# blob: 53c2c39f83d6cfec50c72d714da1d1a66a0704f0 [file] [log] [blame]
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Classes for read and write transfers."""
import abc
import asyncio
from dataclasses import dataclass
import logging
import math
import threading
from typing import Any, Callable, Optional
from pw_status import Status
try:
from pw_transfer import transfer_pb2
except ImportError:
# For the bazel build, which puts generated protos in a different location.
from pigweed.pw_transfer import transfer_pb2 # type: ignore
# Module-level logger, named after the containing package (__package__)
# rather than this individual module.
_LOG = logging.getLogger(__package__)
@dataclass(frozen=True)
class ProgressStats:
    """Immutable snapshot of the progress of a transfer.

    Attributes:
      bytes_sent: Number of bytes sent so far.
      bytes_confirmed_received: Number of bytes confirmed as received.
      total_size_bytes: Total size of the transfer, or None if unknown.
    """
    bytes_sent: int
    bytes_confirmed_received: int
    total_size_bytes: Optional[int]

    def percent_received(self) -> float:
        """Returns the confirmed-received percentage of the total size.

        Returns:
          A percentage in the range of the inputs, or NaN when the total size
          is unknown (None) or zero, since no meaningful ratio exists then.
        """
        # A zero-byte transfer is reported with a total size of 0; dividing by
        # it would raise ZeroDivisionError, so treat it like an unknown size.
        if not self.total_size_bytes:
            return math.nan

        return self.bytes_confirmed_received / self.total_size_bytes * 100

    def __str__(self) -> str:
        """Returns a one-line, human-readable progress summary."""
        # A missing or zero total is displayed as 'unknown'.
        total = (str(self.total_size_bytes)
                 if self.total_size_bytes else 'unknown')
        return (f'{self.percent_received():5.1f}% ({self.bytes_sent} B sent, '
                f'{self.bytes_confirmed_received} B received of {total} B)')
# Type of a progress callback: it is called with a ProgressStats snapshot and
# may return anything.
ProgressCallback = Callable[[ProgressStats], Any]
class _Timer:
    """Single-shot asyncio timer that runs a callback when it expires."""

    def __init__(self, timeout_s: float, callback: Callable[[], Any]):
        """Creates an idle timer.

        Args:
          timeout_s: Default delay, in seconds, used by start().
          callback: Function called with no arguments when the timer expires.
        """
        self._task: Optional[asyncio.Task[Any]] = None
        self._callback = callback
        self.timeout_s = timeout_s

    def start(self, timeout_s: float = None) -> None:
        """Arms the timer, replacing any timer that is still pending.

        Because a pending timer is cancelled first, calling this repeatedly
        pushes the expiry back, which gives watchdog-like behavior: the
        callback only runs after a full timeout without a restart.

        Args:
          timeout_s: Delay in seconds for this run only. When None, the
            timeout given at construction is used.
        """
        self.stop()
        delay_s = timeout_s if timeout_s is not None else self.timeout_s
        self._task = asyncio.create_task(self._run(delay_s))

    def stop(self) -> None:
        """Cancels the pending timer, if there is one."""
        pending, self._task = self._task, None
        if pending is not None:
            pending.cancel()

    async def _run(self, timeout_s: float) -> None:
        """Sleeps for timeout_s seconds, then invokes the callback."""
        await asyncio.sleep(timeout_s)
        # Forget the task before running the callback, so that a callback
        # which calls start() does not cancel the task it is running in.
        self._task = None
        self._callback()
class Transfer(abc.ABC):
    """A client-side data transfer through a Manager.

    Subclasses are responsible for implementing all of the logic for their type
    of transfer, receiving messages from the server and sending the appropriate
    messages in response.
    """

    def __init__(self,
                 session_id: int,
                 send_chunk: Callable[[transfer_pb2.Chunk], None],
                 end_transfer: Callable[['Transfer'], None],
                 response_timeout_s: float,
                 initial_response_timeout_s: float,
                 max_retries: int,
                 progress_callback: Optional[ProgressCallback] = None):
        """Initializes the state shared by all transfer types.

        Args:
          session_id: Identifier of this transfer; exposed as `id`.
          send_chunk: Function used to send a chunk to the server.
          end_transfer: Function called with this transfer when it finishes.
          response_timeout_s: Seconds to wait for a chunk from the server
            before retrying.
          initial_response_timeout_s: Seconds to wait for the first chunk from
            the server after the transfer begins.
          max_retries: Number of consecutive timeouts tolerated before the
            transfer finishes with DEADLINE_EXCEEDED.
          progress_callback: Optional function called with a ProgressStats
            each time progress is reported.
        """
        self.id = session_id
        self.status = Status.OK
        # Set once the transfer has finished, successfully or not.
        self.done = threading.Event()

        self._send_chunk = send_chunk
        self._end_transfer = end_transfer

        self._retries = 0
        self._max_retries = max_retries
        self._response_timer = _Timer(response_timeout_s, self._on_timeout)
        self._initial_response_timeout_s = initial_response_timeout_s

        self._progress_callback = progress_callback

    async def begin(self) -> None:
        """Sends the initial chunk of the transfer."""
        self._send_chunk(self._initial_chunk())
        # The first response gets its own (separately configured) timeout.
        self._response_timer.start(self._initial_response_timeout_s)

    @property
    @abc.abstractmethod
    def data(self) -> bytes:
        """Returns the data read or written in this transfer."""

    @abc.abstractmethod
    def _initial_chunk(self) -> transfer_pb2.Chunk:
        """Returns the initial chunk to notify the server of the transfer."""

    async def handle_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Processes an incoming chunk from the server.

        Handles terminating chunks (i.e. those with a status) and forwards
        non-terminating chunks to _handle_data_chunk.

        Args:
          chunk: The chunk received from the server.
        """
        self._response_timer.stop()
        self._retries = 0  # Received data from service, so reset the retries.

        _LOG.debug('Received chunk\n%s', str(chunk).rstrip())

        # Status chunks are only used to terminate a transfer. They do not
        # contain any data that requires processing.
        if chunk.HasField('status'):
            self.finish(Status(chunk.status))
            return

        await self._handle_data_chunk(chunk)

        # Start the timeout for the server to send a chunk in response. If
        # handling the chunk finished the transfer (completion or error), no
        # further chunk is expected, so do not leave a stray timer running.
        if not self.done.is_set():
            self._response_timer.start()

    @abc.abstractmethod
    async def _handle_data_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Handles a chunk that contains or requests data."""

    @abc.abstractmethod
    def _retry_after_timeout(self) -> None:
        """Retries after a timeout occurs."""

    def _on_timeout(self) -> None:
        """Handles a timeout while waiting for a chunk.

        Retries until max_retries is exceeded, at which point the transfer is
        finished with DEADLINE_EXCEEDED.
        """
        # A timer may still fire after the transfer has ended; ignore it.
        if self.done.is_set():
            return

        self._retries += 1
        if self._retries > self._max_retries:
            self.finish(Status.DEADLINE_EXCEEDED)
            return

        _LOG.debug('Received no responses for %.3fs; retrying %d/%d',
                   self._response_timer.timeout_s, self._retries,
                   self._max_retries)
        self._retry_after_timeout()
        self._response_timer.start()

    def finish(self, status: Status, skip_callback: bool = False) -> None:
        """Ends the transfer with the specified status.

        Args:
          status: Final status of the transfer.
          skip_callback: If True, the end_transfer function is not called.
        """
        self._response_timer.stop()
        self.status = status

        if status.ok():
            # On success, report that all of the data was transferred.
            total_size = len(self.data)
            self._update_progress(total_size, total_size, total_size)

        if not skip_callback:
            self._end_transfer(self)

        # Set done last so that the transfer has been fully cleaned up.
        self.done.set()

    def _update_progress(self, bytes_sent: int, bytes_confirmed_received: int,
                         total_size_bytes: Optional[int]) -> None:
        """Invokes the provided progress callback, if any, with the progress.

        Args:
          bytes_sent: Number of bytes sent so far.
          bytes_confirmed_received: Number of bytes confirmed as received.
          total_size_bytes: Total size of the transfer, or None if unknown.
        """
        stats = ProgressStats(bytes_sent, bytes_confirmed_received,
                              total_size_bytes)
        _LOG.debug('Transfer %d progress: %s', self.id, stats)

        if self._progress_callback:
            self._progress_callback(stats)

    def _send_error(self, error: Status) -> None:
        """Sends an error chunk to the server and finishes the transfer.

        Args:
          error: Status sent to the server and recorded as the final status.
        """
        self._send_chunk(
            transfer_pb2.Chunk(
                session_id=self.id,
                status=error.value,
                type=transfer_pb2.Chunk.Type.TRANSFER_COMPLETION))
        self.finish(error)
class WriteTransfer(Transfer):
    """A client -> server write transfer."""

    def __init__(
        self,
        session_id: int,
        data: bytes,
        send_chunk: Callable[[transfer_pb2.Chunk], None],
        end_transfer: Callable[[Transfer], None],
        response_timeout_s: float,
        initial_response_timeout_s: float,
        max_retries: int,
        progress_callback: Optional[ProgressCallback] = None,
    ):
        """Initializes a write transfer.

        Args:
          session_id: Identifier of this transfer; also used as the resource
            ID in the initial chunk.
          data: The bytes to write to the server.
          send_chunk: Function used to send a chunk to the server.
          end_transfer: Function called with this transfer when it finishes.
          response_timeout_s: Seconds to wait for a chunk from the server
            before retrying.
          initial_response_timeout_s: Seconds to wait for the first chunk from
            the server after the transfer begins.
          max_retries: Number of consecutive timeouts tolerated before the
            transfer fails.
          progress_callback: Optional function called with a ProgressStats
            each time progress is reported.
        """
        super().__init__(session_id, send_chunk, end_transfer,
                         response_timeout_s, initial_response_timeout_s,
                         max_retries, progress_callback)
        self._data = data

        # Guard this class with a lock since a transfer parameters update might
        # arrive while responding to a prior update.
        self._lock = asyncio.Lock()
        self._offset = 0
        self._window_end_offset = 0
        # Stays 0 until the server provides max_chunk_size_bytes.
        self._max_chunk_size = 0
        self._chunk_delay_us: Optional[int] = None

        # The window ID increments for each parameters update.
        self._window_id = 0

        # Chunk that is resent when the server stops responding.
        self._last_chunk = self._initial_chunk()

    @property
    def data(self) -> bytes:
        """Returns the data being written in this transfer."""
        return self._data

    def _initial_chunk(self) -> transfer_pb2.Chunk:
        """Returns the TRANSFER_START chunk that opens the transfer."""
        # TODO(frolv): session_id should not be set here, but assigned by the
        # server during an initial handshake.
        return transfer_pb2.Chunk(session_id=self.id,
                                  resource_id=self.id,
                                  type=transfer_pb2.Chunk.Type.TRANSFER_START)

    async def _handle_data_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Processes an incoming chunk from the server.

        In a write transfer, the server only sends transfer parameter updates
        to the client. When a message is received, update local parameters and
        send data accordingly.

        Args:
          chunk: The transfer parameters chunk received from the server.
        """
        async with self._lock:
            # Claim a new window; any loop still sending for an older window
            # notices the changed ID and stops.
            self._window_id += 1
            window_id = self._window_id

            if not self._handle_parameters_update(chunk):
                return

            bytes_acknowledged = chunk.offset

        while True:
            if self._chunk_delay_us:
                await asyncio.sleep(self._chunk_delay_us / 1e6)

            async with self._lock:
                if self.done.is_set():
                    return

                if window_id != self._window_id:
                    _LOG.debug('Transfer %d: Skipping stale window', self.id)
                    return

                write_chunk = self._next_chunk()

                # An empty chunk is only valid when the window is already
                # exhausted. Otherwise (no chunk size provided by the server,
                # or a window end behind the current offset) the offset would
                # never reach the window end and this loop would spin forever
                # sending empty chunks, so abort the transfer instead.
                if (not write_chunk.data
                        and self._offset != self._window_end_offset):
                    _LOG.error(
                        'Transfer %d: cannot make progress (offset %d, '
                        'window end offset %d, max chunk size %d); aborting',
                        self.id, self._offset, self._window_end_offset,
                        self._max_chunk_size)
                    self._send_error(Status.INTERNAL)
                    return

                self._offset += len(write_chunk.data)
                sent_requested_bytes = self._offset == self._window_end_offset

                self._send_chunk(write_chunk)
                self._update_progress(self._offset, bytes_acknowledged,
                                      len(self.data))

                if sent_requested_bytes:
                    break

        # Remember the final chunk of the window for retries after a timeout.
        self._last_chunk = write_chunk

    def _handle_parameters_update(self, chunk: transfer_pb2.Chunk) -> bool:
        """Updates transfer state based on a transfer parameters update.

        Args:
          chunk: The transfer parameters chunk received from the server.

        Returns:
          True if the parameters were applied; False if they were invalid, in
          which case an error has been sent and the transfer finished.
        """
        # Chunks without a type are treated as retransmit requests.
        retransmit = True
        if chunk.HasField('type'):
            retransmit = (
                chunk.type == transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT
                or chunk.type == transfer_pb2.Chunk.Type.TRANSFER_START)

        if chunk.offset > len(self.data):
            # Bad offset; terminate the transfer.
            _LOG.error(
                'Transfer %d: server requested invalid offset %d (size %d)',
                self.id, chunk.offset, len(self.data))

            self._send_error(Status.OUT_OF_RANGE)
            return False

        if chunk.pending_bytes == 0:
            _LOG.error(
                'Transfer %d: service requested 0 bytes (invalid); aborting',
                self.id)

            self._send_error(Status.INTERNAL)
            return False

        if retransmit:
            # Check whether the client has sent a previous data offset, which
            # indicates that some chunks were lost in transmission.
            if chunk.offset < self._offset:
                _LOG.debug('Write transfer %d rolling back: offset %d from %d',
                           self.id, chunk.offset, self._offset)

            self._offset = chunk.offset

            # Retransmit is the default behavior for older versions of the
            # transfer protocol. The window_end_offset field is not guaranteed
            # to be set in these version, so it must be calculated.
            max_bytes_to_send = min(chunk.pending_bytes,
                                    len(self.data) - self._offset)
            self._window_end_offset = self._offset + max_bytes_to_send
        else:
            assert chunk.type == transfer_pb2.Chunk.Type.PARAMETERS_CONTINUE

            # Extend the window to the new end offset specified by the server.
            self._window_end_offset = min(chunk.window_end_offset,
                                          len(self.data))

        if chunk.HasField('max_chunk_size_bytes'):
            self._max_chunk_size = chunk.max_chunk_size_bytes

        if chunk.HasField('min_delay_microseconds'):
            self._chunk_delay_us = chunk.min_delay_microseconds

        return True

    def _retry_after_timeout(self) -> None:
        """Resends the most recently recorded chunk."""
        self._send_chunk(self._last_chunk)

    def _next_chunk(self) -> transfer_pb2.Chunk:
        """Returns the next Chunk message to send in the data transfer."""
        chunk = transfer_pb2.Chunk(session_id=self.id,
                                   offset=self._offset,
                                   type=transfer_pb2.Chunk.Type.TRANSFER_DATA)

        # Limited by both the chunk size and what is left of the window.
        max_bytes_in_chunk = min(self._max_chunk_size,
                                 self._window_end_offset - self._offset)
        chunk.data = self.data[self._offset:self._offset + max_bytes_in_chunk]

        # Mark the final chunk of the transfer.
        if len(self.data) - self._offset <= max_bytes_in_chunk:
            chunk.remaining_bytes = 0

        return chunk
class ReadTransfer(Transfer):
    """A client <- server read transfer.

    Python could accept an effectively unbounded transfer window, but this
    client requests a conservative window and chunk size by default to avoid
    overloading the device. Both limits can be set through the constructor.
    """

    # Controls how early the receive window is extended, so that the
    # transmitter spends as little time blocked as possible. The window is
    # extended once the part of it that has not been received yet is at most
    # max_bytes_to_receive / EXTEND_WINDOW_DIVISOR; with a divisor of 2 that is
    # when half of the requested data has arrived.
    EXTEND_WINDOW_DIVISOR = 2

    def __init__(  # pylint: disable=too-many-arguments
            self,
            session_id: int,
            send_chunk: Callable[[transfer_pb2.Chunk], None],
            end_transfer: Callable[[Transfer], None],
            response_timeout_s: float,
            initial_response_timeout_s: float,
            max_retries: int,
            max_bytes_to_receive: int = 8192,
            max_chunk_size: int = 1024,
            chunk_delay_us: int = None,
            progress_callback: ProgressCallback = None):
        """Initializes a read transfer.

        Args:
          session_id: Identifier of this transfer.
          send_chunk: Function used to send a chunk to the server.
          end_transfer: Function called with this transfer when it finishes.
          response_timeout_s: Seconds to wait for a chunk from the server
            before retrying.
          initial_response_timeout_s: Seconds to wait for the first chunk from
            the server after the transfer begins.
          max_retries: Number of consecutive timeouts tolerated before the
            transfer fails.
          max_bytes_to_receive: Size in bytes of the window requested from the
            server.
          max_chunk_size: Maximum chunk size, in bytes, requested from the
            server.
          chunk_delay_us: Minimum delay between chunks, in microseconds,
            requested from the server; omitted from requests when unset.
          progress_callback: Optional function called with a ProgressStats
            each time progress is reported.
        """
        super().__init__(session_id, send_chunk, end_transfer,
                         response_timeout_s, initial_response_timeout_s,
                         max_retries, progress_callback)
        # Received data and the position of the next expected byte.
        self._data = bytearray()
        self._offset = 0
        self._remaining_transfer_size: Optional[int] = None

        # Parameters requested from the server.
        self._chunk_delay_us = chunk_delay_us
        self._max_chunk_size = max_chunk_size
        self._max_bytes_to_receive = max_bytes_to_receive

        # State of the current window.
        self._pending_bytes = max_bytes_to_receive
        self._window_end_offset = max_bytes_to_receive

    @property
    def data(self) -> bytes:
        """Returns an immutable copy of the data that has been read."""
        return bytes(self._data)

    def _initial_chunk(self) -> transfer_pb2.Chunk:
        """Returns the TRANSFER_START parameters chunk opening the transfer."""
        start = transfer_pb2.Chunk.Type.TRANSFER_START
        return self._transfer_parameters(start)

    def _send_parameters(self, chunk_type: Any) -> None:
        """Builds a transfer parameters chunk of chunk_type and sends it."""
        self._send_chunk(self._transfer_parameters(chunk_type))

    async def _handle_data_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Processes an incoming chunk from the server.

        In a read transfer the server sends data chunks to the client. The
        data is appended to the buffer, progress is reported, and new transfer
        parameters are sent when the window is used up or should be extended.

        Args:
          chunk: The data chunk received from the server.
        """
        chunk_types = transfer_pb2.Chunk.Type

        if chunk.offset != self._offset:
            # Only in-order transfers are supported for now. Out-of-order data
            # is dropped and the server is asked to resend from the offset
            # this client expects.
            self._send_parameters(chunk_types.PARAMETERS_RETRANSMIT)
            return

        payload = chunk.data
        received = len(payload)
        self._data.extend(payload)
        self._pending_bytes -= received
        self._offset += received

        if chunk.HasField('remaining_bytes'):
            if chunk.remaining_bytes == 0:
                # That was the last of the data: acknowledge it and finish.
                completion = transfer_pb2.Chunk(
                    session_id=self.id,
                    status=Status.OK.value,
                    type=chunk_types.TRANSFER_COMPLETION)
                self._send_chunk(completion)
                self.finish(Status.OK)
                return

            # The server reported how much data is left after this chunk.
            self._remaining_transfer_size = chunk.remaining_bytes
        elif self._remaining_transfer_size is not None:
            # Keep a previously reported remaining size up to date.
            self._remaining_transfer_size -= received

            # Reaching zero without a final chunk means the estimate was
            # wrong, so the size is unknown again.
            if self._remaining_transfer_size <= 0:
                self._remaining_transfer_size = None

        if self._remaining_transfer_size is None:
            expected_total = None
        else:
            expected_total = self._remaining_transfer_size + self._offset
        self._update_progress(self._offset, self._offset, expected_total)

        server_window_end = chunk.window_end_offset
        if server_window_end != 0:
            if server_window_end < self._offset:
                _LOG.error(
                    'Transfer %d: transmitter sent invalid earlier end offset '
                    '%d (receiver offset %d)', self.id,
                    chunk.window_end_offset, self._offset)
                self._send_error(Status.INTERNAL)
                return

            if server_window_end > self._window_end_offset:
                _LOG.error(
                    'Transfer %d: transmitter sent invalid later end offset '
                    '%d (receiver end offset %d)', self.id,
                    chunk.window_end_offset, self._window_end_offset)
                self._send_error(Status.INTERNAL)
                return

            # Adopt the window end reported by the transmitter.
            self._window_end_offset = server_window_end
            # NOTE(review): this subtracts the bytes still expected within the
            # adopted window from the pending count - confirm that this is the
            # intended accounting.
            self._pending_bytes -= server_window_end - self._offset

        window_left = self._window_end_offset - self._offset
        should_extend = (
            window_left <=
            self._max_bytes_to_receive / ReadTransfer.EXTEND_WINDOW_DIVISOR)

        if self._pending_bytes == 0:
            # Everything requested has arrived; request the next block.
            self._send_parameters(chunk_types.PARAMETERS_RETRANSMIT)
        elif should_extend:
            self._send_parameters(chunk_types.PARAMETERS_CONTINUE)

    def _retry_after_timeout(self) -> None:
        """Re-requests data from the current offset after a timeout."""
        self._send_parameters(transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT)

    def _transfer_parameters(self, chunk_type: Any) -> transfer_pb2.Chunk:
        """Resets the receive window and returns a parameters chunk for it.

        The chunk is only built here; the caller is responsible for sending it.

        Args:
          chunk_type: Value for the type field of the returned chunk.

        Returns:
          A chunk requesting a full window starting at the current offset.
        """
        # Building a parameters chunk always starts a fresh, full-size window.
        window_size = self._max_bytes_to_receive
        self._pending_bytes = window_size
        self._window_end_offset = self._offset + window_size

        parameters = transfer_pb2.Chunk(
            session_id=self.id,
            pending_bytes=self._pending_bytes,
            window_end_offset=self._window_end_offset,
            max_chunk_size_bytes=self._max_chunk_size,
            offset=self._offset,
            type=chunk_type)

        # The delay is only included when one was configured (and non-zero).
        if self._chunk_delay_us:
            parameters.min_delay_microseconds = self._chunk_delay_us

        return parameters