pw_transfer: Add progress callback to write transfers in Python client
This allows users to optionally provide a transfer progress callback
when starting a write transfer from the Python client. If provided, the
callback is invoked with the state of the transfer as chunks are
received from the server.
Change-Id: I5dfaf239e3feb7bac161ca66b429f6084daad8e9
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/62023
Commit-Queue: Alexei Frolov <frolv@google.com>
Reviewed-by: Wyatt Hepler <hepler@google.com>
Reviewed-by: Joe Ethier <jethier@google.com>
diff --git a/pw_transfer/docs.rst b/pw_transfer/docs.rst
index 5282e73..effad8c 100644
--- a/pw_transfer/docs.rst
+++ b/pw_transfer/docs.rst
@@ -78,7 +78,7 @@
Python
======
.. automodule:: pw_transfer.transfer
- :members: Manager, Error
+ :members: ProgressStats, Manager, Error
**Example**
diff --git a/pw_transfer/py/pw_transfer/transfer.py b/pw_transfer/py/pw_transfer/transfer.py
index 382aa8f..144940c 100644
--- a/pw_transfer/py/pw_transfer/transfer.py
+++ b/pw_transfer/py/pw_transfer/transfer.py
@@ -15,6 +15,7 @@
import abc
import asyncio
+from dataclasses import dataclass
import logging
import threading
from typing import Any, Callable, Dict, Optional
@@ -26,6 +27,21 @@
_LOG = logging.getLogger(__name__)
+@dataclass(frozen=True)
+class ProgressStats:
+ total_size_bytes: Optional[int]
+ bytes_confirmed_received: int
+ bytes_sent: int
+
+ def __str__(self) -> str:
+ total = str(
+ self.total_size_bytes) if self.total_size_bytes else 'unknown'
+ return f'{self.bytes_confirmed_received} / {total}'
+
+
+ProgressCallback = Callable[[ProgressStats], Any]
+
+
class _Timer:
"""A timer which invokes a callback after a certain timeout."""
def __init__(self, callback: Callable[[], Any]):
@@ -61,10 +77,13 @@
of transfer, receiving messages from the server and sending the appropriate
messages in response.
"""
- def __init__(self, transfer_id: int, data: bytes,
+ def __init__(self,
+ transfer_id: int,
+ data: bytes,
send_chunk: Callable[[Chunk], None],
- end_transfer: Callable[['_Transfer'],
- None], response_timer: _Timer):
+ end_transfer: Callable[['_Transfer'], None],
+ response_timer: _Timer,
+ progress_callback: ProgressCallback = None):
self.id = transfer_id
self.status = Status.OK
self.data = data
@@ -73,6 +92,7 @@
self._send_chunk = send_chunk
self._end_transfer = end_transfer
self._response_timer = response_timer
+ self._progress_callback = progress_callback
@abc.abstractmethod
async def begin(self) -> None:
@@ -85,13 +105,27 @@
Only called for non-terminating chunks (i.e. those without a status).
"""
+ @abc.abstractmethod
+ def progress_stats(self) -> ProgressStats:
+ """Returns the current progress of the transfer."""
+
def finish(self, status: Status) -> None:
"""Ends the transfer with the specified status."""
self._response_timer.stop()
self.status = status
+ self._invoke_progress_callback()
self.done.set()
self._end_transfer(self)
+ def _invoke_progress_callback(self) -> None:
+ """Invokes the provided progress callback, if any, with the progress."""
+
+ stats = self.progress_stats()
+ _LOG.debug('Transfer %d progress: %s', self.id, stats)
+
+ if self._progress_callback:
+ self._progress_callback(stats)
+
def _send_error(self, error: Status) -> None:
"""Sends an error chunk to the server and finishes the transfer."""
self._send_chunk(Chunk(transfer_id=self.id, status=error.value))
@@ -107,11 +141,14 @@
send_chunk: Callable[[Chunk], None],
end_transfer: Callable[[_Transfer], None],
response_timeout_s: float,
+ progress_callback: ProgressCallback = None,
):
super().__init__(transfer_id, data, send_chunk, end_transfer,
- _Timer(lambda: self.finish(Status.DEADLINE_EXCEEDED)))
+ _Timer(lambda: self.finish(Status.DEADLINE_EXCEEDED)),
+ progress_callback)
self._offset = 0
+ self._bytes_acknowledged = 0
self._response_timeout_s = response_timeout_s
self._max_bytes_to_send = 0
@@ -140,6 +177,7 @@
self.id, chunk.offset, self._offset)
self._offset = chunk.offset
+ self._bytes_acknowledged = chunk.offset
if self._offset > len(self.data):
# Bad offset; terminate the transfer.
@@ -170,11 +208,23 @@
self._send_chunk(write_chunk)
+ # Update the caller with the transfer server's current progress.
+ self._invoke_progress_callback()
+
if self._chunk_delay_us:
await asyncio.sleep(self._chunk_delay_us / 1e6)
+ # Assume that the server will acknowledge these bytes. If it doesn't,
+ # this will be reset on the next chunk.
+ self._bytes_acknowledged = self._offset
+
self._response_timer.start(self._response_timeout_s)
+ def progress_stats(self) -> ProgressStats:
+ return ProgressStats(total_size_bytes=len(self.data),
+ bytes_confirmed_received=self._bytes_acknowledged,
+ bytes_sent=self._offset)
+
def _next_chunk(self) -> Chunk:
"""Returns the next Chunk message to send in the data transfer."""
chunk = Chunk(transfer_id=self.id, offset=self._offset)
@@ -205,9 +255,10 @@
max_retries: int = 3,
max_bytes_to_receive: int = 8192,
max_chunk_size: int = 1024,
- chunk_delay_us: int = None):
+ chunk_delay_us: int = None,
+ progress_callback: ProgressCallback = None):
super().__init__(transfer_id, bytes(), send_chunk, end_transfer,
- _Timer(self._on_timeout))
+ _Timer(self._on_timeout), progress_callback)
self._response_timeout_s = response_timeout_s
self._chunk_timeout_count = 0
@@ -247,6 +298,8 @@
self._pending_bytes -= len(chunk.data)
self._offset += len(chunk.data)
+ self._invoke_progress_callback()
+
if chunk.HasField('remaining_bytes'):
if chunk.remaining_bytes == 0:
# No more data to read. Acknowledge receipt and finish.
@@ -268,6 +321,12 @@
# Wait for the next pending chunk.
self._response_timer.start(self._response_timeout_s)
+ def progress_stats(self) -> ProgressStats:
+ # TODO(frolv): Get the transfer size from the server if it is known.
+ return ProgressStats(total_size_bytes=None,
+ bytes_confirmed_received=self._offset,
+ bytes_sent=self._offset)
+
def _send_transfer_parameters(self):
"""Sends an updated transfer parameters chunk to the server."""
@@ -352,7 +411,9 @@
self._loop.call_soon_threadsafe(self._quit_event.set)
self._thread.join()
- def read(self, transfer_id: int) -> bytes:
+ def read(self,
+ transfer_id: int,
+ progress_callback: ProgressCallback = None) -> bytes:
"""Receives ("downloads") data from the server.
Raises:
@@ -362,9 +423,11 @@
if transfer_id in self._read_transfers:
raise ValueError(f'Read transfer {transfer_id} already exists')
- transfer = _ReadTransfer(transfer_id, self._send_read_chunk,
+ transfer = _ReadTransfer(transfer_id,
+ self._send_read_chunk,
self._end_read_transfer,
- self._default_response_timeout_s)
+ self._default_response_timeout_s,
+ progress_callback=progress_callback)
self._start_read_transfer(transfer)
transfer.done.wait()
@@ -374,9 +437,19 @@
return transfer.data
- def write(self, transfer_id: int, data: bytes) -> None:
+ def write(self,
+ transfer_id: int,
+ data: bytes,
+ progress_callback: ProgressCallback = None) -> None:
"""Transmits ("uploads") data to the server.
+ Args:
+ transfer_id: ID of the write transfer
+ data: Data to send to the server.
+ progress_callback: Optional callback periodically invoked throughout
+ the transfer with the transfer state. Can be used to provide user-
+ facing status updates such as progress bars.
+
Raises:
Error: the transfer failed to complete
"""
@@ -384,9 +457,12 @@
if transfer_id in self._write_transfers:
raise ValueError(f'Write transfer {transfer_id} already exists')
- transfer = _WriteTransfer(transfer_id, data, self._send_write_chunk,
+ transfer = _WriteTransfer(transfer_id,
+ data,
+ self._send_write_chunk,
self._end_write_transfer,
- self._default_response_timeout_s)
+ self._default_response_timeout_s,
+ progress_callback=progress_callback)
self._start_write_transfer(transfer)
transfer.done.wait()
diff --git a/pw_transfer/py/tests/transfer_test.py b/pw_transfer/py/tests/transfer_test.py
index e83a950..2d04df3 100644
--- a/pw_transfer/py/tests/transfer_test.py
+++ b/pw_transfer/py/tests/transfer_test.py
@@ -126,6 +126,34 @@
self.assertTrue(self._sent_chunks[-1].HasField('status'))
self.assertEqual(self._sent_chunks[-1].status, 0)
+ def test_read_transfer_progress_callback(self):
+ manager = transfer.Manager(self._service,
+ default_response_timeout_s=0.3)
+
+ self._enqueue_server_responses(
+ _Method.READ,
+ ((
+ transfer_pb2.Chunk(
+ transfer_id=3, offset=0, data=b'abc', remaining_bytes=3),
+ transfer_pb2.Chunk(
+ transfer_id=3, offset=3, data=b'def', remaining_bytes=0),
+ ), ),
+ )
+
+ progress = []
+
+ data = manager.read(3, progress.append)
+ self.assertEqual(data, b'abcdef')
+ self.assertEqual(len(self._sent_chunks), 2)
+ self.assertTrue(self._sent_chunks[-1].HasField('status'))
+ self.assertEqual(self._sent_chunks[-1].status, 0)
+ self.assertEqual(len(progress), 3)
+ self.assertEqual(progress, [
+ transfer.ProgressStats(None, 3, 3),
+ transfer.ProgressStats(None, 6, 6),
+ transfer.ProgressStats(None, 6, 6)
+ ])
+
def test_read_transfer_retry_bad_offset(self):
"""Server responds with an unexpected offset in a read transfer."""
manager = transfer.Manager(self._service,
@@ -299,6 +327,39 @@
self.assertEqual(self._sent_chunks[1].data, b'data to ')
self.assertEqual(self._sent_chunks[2].data, b'write')
+ def test_write_transfer_progress_callback(self):
+ manager = transfer.Manager(self._service,
+ default_response_timeout_s=0.3)
+
+ self._enqueue_server_responses(
+ _Method.WRITE,
+ (
+ (transfer_pb2.Chunk(transfer_id=4,
+ offset=0,
+ pending_bytes=8,
+ max_chunk_size_bytes=8), ),
+ (transfer_pb2.Chunk(transfer_id=4,
+ offset=8,
+ pending_bytes=8,
+ max_chunk_size_bytes=8), ),
+ (transfer_pb2.Chunk(transfer_id=4, status=Status.OK.value), ),
+ ),
+ )
+
+ progress = []
+
+ manager.write(4, b'data to write', progress.append)
+ self.assertEqual(len(self._sent_chunks), 3)
+ self.assertEqual(self._received_data(), b'data to write')
+ self.assertEqual(self._sent_chunks[1].data, b'data to ')
+ self.assertEqual(self._sent_chunks[2].data, b'write')
+ self.assertEqual(len(progress), 3)
+ self.assertEqual(progress, [
+ transfer.ProgressStats(13, 0, 8),
+ transfer.ProgressStats(13, 8, 13),
+ transfer.ProgressStats(13, 13, 13)
+ ])
+
def test_write_transfer_rewind(self):
"""Write transfer in which the server re-requests an earlier offset."""
manager = transfer.Manager(self._service,