blob: dfddd12aaa0ddd779fbbcd7dca7c7c42145cd177 [file] [log] [blame]
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Classes for read and write transfers."""
import abc
import asyncio
from dataclasses import dataclass
import enum
import logging
import math
import threading
from typing import Any, Callable, Optional
from pw_status import Status
try:
from pw_transfer import transfer_pb2
except ImportError:
# For the bazel build, which puts generated protos in a different location.
from pigweed.pw_transfer import transfer_pb2 # type: ignore
# Module-level logger. It is named after the containing package
# (__package__) rather than after this individual module (__name__).
_LOG = logging.getLogger(__package__)
@dataclass(frozen=True)
class ProgressStats:
    """Immutable snapshot of the progress of a transfer."""

    # Number of bytes that have been sent so far.
    bytes_sent: int
    # Number of bytes the receiving side is known to have received.
    bytes_confirmed_received: int
    # Total size of the transfer in bytes, or None when it is not known.
    total_size_bytes: Optional[int]

    def percent_received(self) -> float:
        """Returns the confirmed-received share of the total, in percent.

        Returns NaN when the percentage is undefined: either the total size
        is unknown (None) or it is zero, in which case a division would raise
        ZeroDivisionError.
        """
        if not self.total_size_bytes:
            return math.nan

        return self.bytes_confirmed_received / self.total_size_bytes * 100

    def __str__(self) -> str:
        """Formats the stats as a single human-readable line."""
        # Only a missing size is reported as 'unknown'; a known size of zero
        # bytes is printed as '0'.
        if self.total_size_bytes is None:
            total = 'unknown'
        else:
            total = str(self.total_size_bytes)

        return (f'{self.percent_received():5.1f}% ({self.bytes_sent} B sent, '
                f'{self.bytes_confirmed_received} B received of {total} B)')
# Signature of a progress callback: it is called with a ProgressStats
# snapshot. Its return value is not used by the code in this file.
ProgressCallback = Callable[[ProgressStats], Any]
class _Timer:
    """Single-shot asyncio timer that runs a callback when it expires."""

    def __init__(self, timeout_s: float, callback: Callable[[], Any]):
        """Creates a stopped timer.

        Args:
          timeout_s: Default delay, in seconds, used by start().
          callback: Called with no arguments when the timer expires.
        """
        self.timeout_s = timeout_s
        self._callback = callback
        self._task: Optional[asyncio.Task[Any]] = None

    def start(self, timeout_s: float = None) -> None:
        """(Re)starts the timer.

        Any timer that is already running is stopped first, so calling this
        repeatedly gives watchdog-like behavior: the callback only runs once
        a full timeout passes without another call to start().

        Args:
          timeout_s: Delay in seconds for this run only. When None, the
            default delay given to the constructor is used.
        """
        self.stop()
        delay_s = timeout_s if timeout_s is not None else self.timeout_s
        self._task = asyncio.create_task(self._run(delay_s))

    def stop(self) -> None:
        """Cancels the running timer, if there is one."""
        pending = self._task
        if pending is None:
            return

        pending.cancel()
        self._task = None

    async def _run(self, timeout_s: float) -> None:
        """Sleeps for timeout_s seconds and then invokes the callback."""
        await asyncio.sleep(timeout_s)

        # Forget the task before running the callback: a callback that calls
        # start() then will not cancel the task that is executing it.
        self._task = None
        self._callback()
class Transfer(abc.ABC):
    """A client-side data transfer through a Manager.

    Subclasses are responsible for implementing all of the logic for their type
    of transfer, receiving messages from the server and sending the appropriate
    messages in response.

    Attributes:
      id: The session ID the transfer was created with.
      status: The status of the transfer. Status.OK until finish() is called
        with a different status.
      done: A threading.Event that is set once finish() has run.
    """

    class _State(enum.Enum):
        # Transfer is starting. The server and client are performing an initial
        # handshake and negotiating protocol and feature flags.
        INITIATING = 0

        # Waiting for the other end to send a chunk.
        WAITING = 1

        # Transmitting a window of data to a receiver.
        TRANSMITTING = 2

        # Recovering after one or more chunks was dropped in an active transfer.
        RECOVERY = 3

        # Transfer has completed locally and is waiting for the peer to
        # acknowledge its final status. Only entered by the terminating side of
        # the transfer.
        #
        # The context remains in a TERMINATING state until it receives an
        # acknowledgement from the peer or times out.
        TERMINATING = 4

        # A transfer has fully completed.
        COMPLETE = 5

    def __init__(self,
                 session_id: int,
                 send_chunk: Callable[[transfer_pb2.Chunk], None],
                 end_transfer: Callable[['Transfer'], None],
                 response_timeout_s: float,
                 initial_response_timeout_s: float,
                 max_retries: int,
                 progress_callback: Optional[ProgressCallback] = None):
        """Initializes the state shared by read and write transfers.

        Args:
          session_id: ID of the transfer; stored as self.id.
          send_chunk: Function called with every chunk to send to the server.
          end_transfer: Function called with this transfer by finish(), unless
            finish() is told to skip it.
          response_timeout_s: Default time to wait for a chunk from the server
            before a timeout is handled.
          initial_response_timeout_s: Time to wait for the first response
            after begin().
          max_retries: Number of consecutive timeouts that are retried. One
            more timeout ends the transfer with DEADLINE_EXCEEDED.
          progress_callback: Optional function called with ProgressStats
            whenever progress is reported.
        """
        self.id = session_id
        self.status = Status.OK
        self.done = threading.Event()

        self._send_chunk = send_chunk
        self._end_transfer = end_transfer

        # TODO(b/243679452): Only the legacy protocol begins in a WAITING state.
        # Newer protocols should start as INITIATING.
        self._state = Transfer._State.WAITING

        self._retries = 0
        self._max_retries = max_retries
        self._response_timer = _Timer(response_timeout_s, self._on_timeout)
        self._initial_response_timeout_s = initial_response_timeout_s

        self._progress_callback = progress_callback

    async def begin(self) -> None:
        """Sends the initial chunk of the transfer.

        The response timer is started with the initial response timeout.
        """
        self._send_chunk(self._initial_chunk())
        self._response_timer.start(self._initial_response_timeout_s)

    @property
    @abc.abstractmethod
    def data(self) -> bytes:
        """Returns the data read or written in this transfer."""

    @abc.abstractmethod
    def _initial_chunk(self) -> transfer_pb2.Chunk:
        """Returns the initial chunk to notify the server of the transfer."""

    async def handle_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Processes an incoming chunk from the server.

        Handles terminating chunks (i.e. those with a status) and forwards
        non-terminating chunks to _handle_data_chunk.
        """
        self._response_timer.stop()
        self._retries = 0  # Received data from service, so reset the retries.

        _LOG.debug('Received chunk\n%s', str(chunk).rstrip())

        # Status chunks are only used to terminate a transfer. They do not
        # contain any data that requires processing.
        if chunk.HasField('status'):
            self.finish(Status(chunk.status))
            return

        await self._handle_data_chunk(chunk)

        # The handler may have finished the transfer (final chunk received or
        # an error sent). A finished transfer expects no further chunks, so do
        # not leave a timer task running for it.
        if self._state is Transfer._State.COMPLETE:
            return

        # Start the timeout for the server to send a chunk in response.
        self._response_timer.start()

    @abc.abstractmethod
    async def _handle_data_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Handles a chunk that contains or requests data."""

    @abc.abstractmethod
    def _retry_after_timeout(self) -> None:
        """Retries after a timeout occurs."""

    def _on_timeout(self) -> None:
        """Handles a timeout while waiting for a chunk.

        Retries through _retry_after_timeout() and restarts the timer, or ends
        the transfer with DEADLINE_EXCEEDED once the retries are used up.
        """
        if self._state is Transfer._State.COMPLETE:
            return

        self._retries += 1
        if self._retries > self._max_retries:
            self.finish(Status.DEADLINE_EXCEEDED)
            return

        _LOG.debug('Received no responses for %.3fs; retrying %d/%d',
                   self._response_timer.timeout_s, self._retries,
                   self._max_retries)
        self._retry_after_timeout()
        self._response_timer.start()

    def finish(self, status: Status, skip_callback: bool = False) -> None:
        """Ends the transfer with the specified status.

        Args:
          status: Final status of the transfer; stored as self.status.
          skip_callback: When True, the end_transfer function passed to the
            constructor is not called.
        """
        self._response_timer.stop()
        self.status = status

        try:
            if status.ok():
                total_size = len(self.data)
                self._update_progress(total_size, total_size, total_size)

            if not skip_callback:
                self._end_transfer(self)
        finally:
            # Set done last so that the transfer has been fully cleaned up.
            # This runs even if a callback raises, so that anything waiting
            # on `done` is released instead of blocking forever.
            self._state = Transfer._State.COMPLETE
            self.done.set()

    def _update_progress(self, bytes_sent: int, bytes_confirmed_received: int,
                         total_size_bytes: Optional[int]) -> None:
        """Invokes the provided progress callback, if any, with the progress."""
        stats = ProgressStats(bytes_sent, bytes_confirmed_received,
                              total_size_bytes)
        _LOG.debug('Transfer %d progress: %s', self.id, stats)

        if self._progress_callback:
            self._progress_callback(stats)

    def _send_error(self, error: Status) -> None:
        """Sends an error chunk to the server and finishes the transfer."""
        self._send_chunk(
            transfer_pb2.Chunk(transfer_id=self.id,
                               status=error.value,
                               type=transfer_pb2.Chunk.Type.COMPLETION))
        self.finish(error)
class WriteTransfer(Transfer):
    """A client -> server write transfer."""

    def __init__(
        self,
        session_id: int,
        data: bytes,
        send_chunk: Callable[[transfer_pb2.Chunk], None],
        end_transfer: Callable[[Transfer], None],
        response_timeout_s: float,
        initial_response_timeout_s: float,
        max_retries: int,
        progress_callback: Optional[ProgressCallback] = None,
    ):
        """Initializes a write transfer.

        Args:
          session_id: ID of the transfer. Also used as the resource ID in the
            initial chunk.
          data: The bytes to write to the server.
          send_chunk: Function called with every chunk to send to the server.
          end_transfer: Function called with this transfer when it finishes.
          response_timeout_s: Default time to wait for a chunk from the
            server.
          initial_response_timeout_s: Time to wait for the first response.
          max_retries: Number of consecutive timeouts that are retried.
          progress_callback: Optional function called with ProgressStats.
        """
        super().__init__(session_id, send_chunk, end_transfer,
                         response_timeout_s, initial_response_timeout_s,
                         max_retries, progress_callback)
        self._data = data

        self._offset = 0
        self._window_end_offset = 0

        # A size of 0 means that the server has not provided a chunk size
        # limit.
        self._max_chunk_size = 0
        self._chunk_delay_us: Optional[int] = None

        # The window ID increments for each parameters update.
        self._window_id = 0

        self._bytes_confirmed_received = 0

        # Reference to the most recently scheduled transmit task. The event
        # loop only keeps weak references to tasks, so one is held here.
        self._transmit_task: Optional[asyncio.Task[Any]] = None

        self._last_chunk = self._initial_chunk()

    @property
    def data(self) -> bytes:
        """Returns the data written in this transfer."""
        return self._data

    def _initial_chunk(self) -> transfer_pb2.Chunk:
        """Returns the START chunk that opens the transfer."""
        # TODO(b/243679452): In protocol v2, session_id should be assigned by
        # the server during an initial handshake.
        return transfer_pb2.Chunk(transfer_id=self.id,
                                  resource_id=self.id,
                                  type=transfer_pb2.Chunk.Type.START)

    async def _handle_data_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Processes an incoming chunk from the server.

        In a write transfer, the server only sends transfer parameter updates
        to the client. When a message is received, update local parameters and
        send data accordingly.
        """
        if self._state is Transfer._State.TERMINATING:
            # TODO(b/243679452): Handle termination in protocol v2.
            return

        if self._state is Transfer._State.TRANSMITTING:
            self._state = Transfer._State.WAITING

        assert self._state is Transfer._State.WAITING

        if not self._handle_parameters_update(chunk):
            return

        self._bytes_confirmed_received = chunk.offset
        self._state = Transfer._State.TRANSMITTING

        # A new window ID makes any transmit task scheduled for an earlier
        # window stop when it next runs.
        self._window_id += 1
        self._transmit_task = asyncio.create_task(
            self._transmit_next_chunk(self._window_id))

    async def _transmit_next_chunk(self,
                                   window_id: int,
                                   timeout_us: Optional[int] = None) -> None:
        """Transmits a single data chunk to the server.

        If the chunk completes the active window, returns to a WAITING state.
        Otherwise, schedules another transmission for the next chunk.

        Args:
          window_id: ID of the window this transmission was scheduled for.
          timeout_us: Optional delay in microseconds before sending.
        """
        if timeout_us is not None:
            await asyncio.sleep(timeout_us / 1e6)

        if self._state is not Transfer._State.TRANSMITTING:
            return

        if window_id != self._window_id:
            _LOG.debug('Transfer %d: Skipping stale window', self.id)
            return

        chunk = self._next_chunk()
        self._offset += len(chunk.data)

        # Use >= rather than == so that a window which ends before the current
        # offset counts as complete instead of being retried forever.
        sent_requested_bytes = self._offset >= self._window_end_offset

        self._send_chunk(chunk)
        self._update_progress(self._offset, self._bytes_confirmed_received,
                              len(self.data))

        if sent_requested_bytes:
            self._state = Transfer._State.WAITING
        else:
            self._transmit_task = asyncio.create_task(
                self._transmit_next_chunk(window_id,
                                          timeout_us=self._chunk_delay_us))

        self._last_chunk = chunk

    def _handle_parameters_update(self, chunk: transfer_pb2.Chunk) -> bool:
        """Updates transfer state based on a transfer parameters update.

        Returns:
          True if the parameters were applied. False if they were invalid, in
          which case an error chunk has been sent and the transfer finished.
        """
        retransmit = True
        if chunk.HasField('type'):
            retransmit = (chunk.type
                          == transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT
                          or chunk.type == transfer_pb2.Chunk.Type.START)

        if chunk.offset > len(self.data):
            # Bad offset; terminate the transfer.
            _LOG.error(
                'Transfer %d: server requested invalid offset %d (size %d)',
                self.id, chunk.offset, len(self.data))

            self._send_error(Status.OUT_OF_RANGE)
            return False

        if chunk.pending_bytes == 0:
            _LOG.error(
                'Transfer %d: service requested 0 bytes (invalid); aborting',
                self.id)

            self._send_error(Status.INTERNAL)
            return False

        if retransmit:
            # Check whether the server has requested an offset earlier than
            # the one already sent, which indicates that some chunks were
            # lost in transmission.
            if chunk.offset < self._offset:
                _LOG.debug('Write transfer %d rolling back: offset %d from %d',
                           self.id, chunk.offset, self._offset)

            self._offset = chunk.offset

            # Retransmit is the default behavior for older versions of the
            # transfer protocol. The window_end_offset field is not guaranteed
            # to be set in these version, so it must be calculated.
            max_bytes_to_send = min(chunk.pending_bytes,
                                    len(self.data) - self._offset)
            self._window_end_offset = self._offset + max_bytes_to_send
        else:
            assert chunk.type == transfer_pb2.Chunk.Type.PARAMETERS_CONTINUE

            # Extend the window to the new end offset specified by the server.
            self._window_end_offset = min(chunk.window_end_offset,
                                          len(self.data))

        if chunk.HasField('max_chunk_size_bytes'):
            self._max_chunk_size = chunk.max_chunk_size_bytes

        if chunk.HasField('min_delay_microseconds'):
            self._chunk_delay_us = chunk.min_delay_microseconds

        return True

    def _retry_after_timeout(self) -> None:
        """Resends the last chunk if the transfer is waiting on the server."""
        if self._state is Transfer._State.WAITING:
            self._send_chunk(self._last_chunk)

        # TODO(b/243679452): Handle other states in protocol v2.

    def _next_chunk(self) -> transfer_pb2.Chunk:
        """Returns the next Chunk message to send in the data transfer."""
        chunk = transfer_pb2.Chunk(transfer_id=self.id,
                                   offset=self._offset,
                                   type=transfer_pb2.Chunk.Type.DATA)

        # Never use a negative size if the window ends before the offset.
        max_bytes_in_chunk = max(0, self._window_end_offset - self._offset)

        # Only apply the chunk size limit when the server has provided one.
        # Applying the initial value of 0 would produce empty chunks that
        # never advance the offset.
        if self._max_chunk_size > 0:
            max_bytes_in_chunk = min(self._max_chunk_size, max_bytes_in_chunk)

        chunk.data = self.data[self._offset:self._offset + max_bytes_in_chunk]

        # Mark the final chunk of the transfer.
        if len(self.data) - self._offset <= max_bytes_in_chunk:
            chunk.remaining_bytes = 0

        return chunk
class ReadTransfer(Transfer):
    """A client <- server read transfer.

    Although Python can effectively handle an unlimited transfer window, this
    client sets a conservative window and chunk size to avoid overloading the
    device. These are configurable in the constructor.
    """

    # The fractional position within a window at which a receive transfer should
    # extend its window size to minimize the amount of time the transmitter
    # spends blocked.
    #
    # For example, a divisor of 2 will extend the window when half of the
    # requested data has been received, a divisor of three will extend at a
    # third of the window, and so on.
    EXTEND_WINDOW_DIVISOR = 2

    def __init__(  # pylint: disable=too-many-arguments
            self,
            session_id: int,
            send_chunk: Callable[[transfer_pb2.Chunk], None],
            end_transfer: Callable[[Transfer], None],
            response_timeout_s: float,
            initial_response_timeout_s: float,
            max_retries: int,
            max_bytes_to_receive: int = 8192,
            max_chunk_size: int = 1024,
            chunk_delay_us: Optional[int] = None,
            progress_callback: Optional[ProgressCallback] = None):
        """Initializes a read transfer.

        Args:
          session_id: ID of the transfer.
          send_chunk: Function called with every chunk to send to the server.
          end_transfer: Function called with this transfer when it finishes.
          response_timeout_s: Default time to wait for a chunk from the
            server.
          initial_response_timeout_s: Time to wait for the first response.
          max_retries: Number of consecutive timeouts that are retried.
          max_bytes_to_receive: Size in bytes of the receive window that is
            requested from the server.
          max_chunk_size: Maximum chunk size in bytes requested from the
            server.
          chunk_delay_us: Optional minimum delay between chunks, in
            microseconds, requested from the server. Not sent when None or 0.
          progress_callback: Optional function called with ProgressStats.
        """
        super().__init__(session_id, send_chunk, end_transfer,
                         response_timeout_s, initial_response_timeout_s,
                         max_retries, progress_callback)
        self._max_bytes_to_receive = max_bytes_to_receive
        self._max_chunk_size = max_chunk_size
        self._chunk_delay_us = chunk_delay_us

        self._remaining_transfer_size: Optional[int] = None
        self._data = bytearray()
        self._offset = 0

        # Invariant: _pending_bytes == _window_end_offset - _offset.
        self._pending_bytes = max_bytes_to_receive
        self._window_end_offset = max_bytes_to_receive

        self._last_chunk_offset: Optional[int] = None

    @property
    def data(self) -> bytes:
        """Returns an immutable copy of the data that has been read."""
        return bytes(self._data)

    def _initial_chunk(self) -> transfer_pb2.Chunk:
        """Returns a START chunk carrying the initial transfer parameters."""
        return self._transfer_parameters(transfer_pb2.Chunk.Type.START)

    async def _handle_data_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Processes an incoming chunk from the server.

        In a read transfer, the client receives data chunks from the server.
        Once all pending data is received, the transfer parameters are updated.
        """
        if self._state is Transfer._State.TERMINATING:
            # TODO(b/243679452): Handle terminating states in protocol v2.
            return

        if self._state is Transfer._State.RECOVERY:
            if chunk.offset != self._offset:
                if self._last_chunk_offset == chunk.offset:
                    _LOG.debug(
                        'Transfer %d received repeated offset %d: '
                        'retry detected, resending transfer parameters',
                        self.id, chunk.offset)
                    self._send_chunk(
                        self._transfer_parameters(
                            transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
                else:
                    _LOG.debug(
                        'Transfer %d waiting for offset %d, ignoring %d',
                        self.id, self._offset, chunk.offset)
                    self._last_chunk_offset = chunk.offset
                return

            _LOG.info(
                'Transfer %d received expected offset %d, resuming transfer',
                self.id, chunk.offset)
            self._state = Transfer._State.WAITING

        assert self._state is Transfer._State.WAITING

        if chunk.offset != self._offset:
            # Initially, the transfer service only supports in-order transfers.
            # If data is received out of order, request that the server
            # retransmit from the previous offset.
            _LOG.debug(
                'Transfer %d expected offset %d, received %d: '
                'entering recovery state', self.id, self._offset, chunk.offset)
            self._state = Transfer._State.RECOVERY

            self._send_chunk(
                self._transfer_parameters(
                    transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
            return

        self._data += chunk.data
        self._pending_bytes -= len(chunk.data)
        self._offset += len(chunk.data)

        # Update the last offset seen so that retries can be detected.
        self._last_chunk_offset = chunk.offset

        if chunk.HasField('remaining_bytes'):
            if chunk.remaining_bytes == 0:
                # No more data to read. Acknowledge receipt and finish.
                self._send_chunk(
                    transfer_pb2.Chunk(
                        transfer_id=self.id,
                        status=Status.OK.value,
                        type=transfer_pb2.Chunk.Type.COMPLETION))
                self.finish(Status.OK)
                return

            # The server may indicate if the amount of remaining data is known.
            self._remaining_transfer_size = chunk.remaining_bytes
        elif self._remaining_transfer_size is not None:
            # Update the remaining transfer size, if it is known.
            self._remaining_transfer_size -= len(chunk.data)

            # If the transfer size drops to zero, the estimate was inaccurate.
            if self._remaining_transfer_size <= 0:
                self._remaining_transfer_size = None

        total_size = None if self._remaining_transfer_size is None else (
            self._remaining_transfer_size + self._offset)
        self._update_progress(self._offset, self._offset, total_size)

        if chunk.window_end_offset != 0:
            if chunk.window_end_offset < self._offset:
                _LOG.error(
                    'Transfer %d: transmitter sent invalid earlier end offset '
                    '%d (receiver offset %d)', self.id,
                    chunk.window_end_offset, self._offset)
                self._send_error(Status.INTERNAL)
                return

            if chunk.window_end_offset > self._window_end_offset:
                _LOG.error(
                    'Transfer %d: transmitter sent invalid later end offset '
                    '%d (receiver end offset %d)', self.id,
                    chunk.window_end_offset, self._window_end_offset)
                self._send_error(Status.INTERNAL)
                return

            # Adopt the window end reported by the server. The bytes still
            # pending are those between the current offset and that end.
            # Assigning (rather than subtracting) keeps the count correct
            # when the server repeats an unchanged window end.
            self._window_end_offset = chunk.window_end_offset
            self._pending_bytes = self._window_end_offset - self._offset

        remaining_window_size = self._window_end_offset - self._offset
        extend_window = (remaining_window_size <= self._max_bytes_to_receive /
                         ReadTransfer.EXTEND_WINDOW_DIVISOR)

        if self._pending_bytes == 0:
            # All pending data was received. Send out a new parameters chunk for
            # the next block.
            self._send_chunk(
                self._transfer_parameters(
                    transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
        elif extend_window:
            self._send_chunk(
                self._transfer_parameters(
                    transfer_pb2.Chunk.Type.PARAMETERS_CONTINUE))

    def _retry_after_timeout(self) -> None:
        """Resends the transfer parameters if waiting or recovering."""
        if (self._state is Transfer._State.WAITING
                or self._state is Transfer._State.RECOVERY):
            self._send_chunk(
                self._transfer_parameters(
                    transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))

        # TODO(b/243679452): Handle other states in protocol v2.

    def _transfer_parameters(self, chunk_type: Any) -> transfer_pb2.Chunk:
        """Resets the receive window and returns a parameters chunk for it.

        The chunk is only built here, not sent. As a side effect, the window
        is reset to span max_bytes_to_receive bytes from the current offset.

        Args:
          chunk_type: The chunk type to set on the returned chunk.
        """
        self._pending_bytes = self._max_bytes_to_receive
        self._window_end_offset = self._offset + self._max_bytes_to_receive

        chunk = transfer_pb2.Chunk(transfer_id=self.id,
                                   pending_bytes=self._pending_bytes,
                                   window_end_offset=self._window_end_offset,
                                   max_chunk_size_bytes=self._max_chunk_size,
                                   offset=self._offset,
                                   type=chunk_type)

        if self._chunk_delay_us:
            chunk.min_delay_microseconds = self._chunk_delay_us

        return chunk