pw_transfer: Add progress callback to write transfers in Python client

This allows users to optionally provide a transfer progress callback
when starting a write transfer from the Python client. If provided, the
callback is invoked with the state of the transfer as chunks are
received from the server.

Change-Id: I5dfaf239e3feb7bac161ca66b429f6084daad8e9
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/62023
Commit-Queue: Alexei Frolov <frolv@google.com>
Reviewed-by: Wyatt Hepler <hepler@google.com>
Reviewed-by: Joe Ethier <jethier@google.com>
diff --git a/pw_transfer/docs.rst b/pw_transfer/docs.rst
index 5282e73..effad8c 100644
--- a/pw_transfer/docs.rst
+++ b/pw_transfer/docs.rst
@@ -78,7 +78,7 @@
 Python
 ======
 .. automodule:: pw_transfer.transfer
-  :members: Manager, Error
+  :members: ProgressStats, Manager, Error
 
 **Example**
 
diff --git a/pw_transfer/py/pw_transfer/transfer.py b/pw_transfer/py/pw_transfer/transfer.py
index 382aa8f..144940c 100644
--- a/pw_transfer/py/pw_transfer/transfer.py
+++ b/pw_transfer/py/pw_transfer/transfer.py
@@ -15,6 +15,7 @@
 
 import abc
 import asyncio
+from dataclasses import dataclass
 import logging
 import threading
 from typing import Any, Callable, Dict, Optional
@@ -26,6 +27,21 @@
 _LOG = logging.getLogger(__name__)
 
 
+@dataclass(frozen=True)
+class ProgressStats:
+    total_size_bytes: Optional[int]
+    bytes_confirmed_received: int
+    bytes_sent: int
+
+    def __str__(self) -> str:
+        total = str(
+            self.total_size_bytes) if self.total_size_bytes else 'unknown'
+        return f'{self.bytes_confirmed_received} / {total}'
+
+
+ProgressCallback = Callable[[ProgressStats], Any]
+
+
 class _Timer:
     """A timer which invokes a callback after a certain timeout."""
     def __init__(self, callback: Callable[[], Any]):
@@ -61,10 +77,13 @@
     of transfer, receiving messages from the server and sending the appropriate
     messages in response.
     """
-    def __init__(self, transfer_id: int, data: bytes,
+    def __init__(self,
+                 transfer_id: int,
+                 data: bytes,
                  send_chunk: Callable[[Chunk], None],
-                 end_transfer: Callable[['_Transfer'],
-                                        None], response_timer: _Timer):
+                 end_transfer: Callable[['_Transfer'], None],
+                 response_timer: _Timer,
+                 progress_callback: ProgressCallback = None):
         self.id = transfer_id
         self.status = Status.OK
         self.data = data
@@ -73,6 +92,7 @@
         self._send_chunk = send_chunk
         self._end_transfer = end_transfer
         self._response_timer = response_timer
+        self._progress_callback = progress_callback
 
     @abc.abstractmethod
     async def begin(self) -> None:
@@ -85,13 +105,27 @@
         Only called for non-terminating chunks (i.e. those without a status).
         """
 
+    @abc.abstractmethod
+    def progress_stats(self) -> ProgressStats:
+        """Returns the current progress of the transfer."""
+
     def finish(self, status: Status) -> None:
         """Ends the transfer with the specified status."""
         self._response_timer.stop()
         self.status = status
+        self._invoke_progress_callback()
         self.done.set()
         self._end_transfer(self)
 
+    def _invoke_progress_callback(self) -> None:
+        """Invokes the provided progress callback, if any, with the progress."""
+
+        stats = self.progress_stats()
+        _LOG.debug('Transfer %d progress: %s', self.id, stats)
+
+        if self._progress_callback:
+            self._progress_callback(stats)
+
     def _send_error(self, error: Status) -> None:
         """Sends an error chunk to the server and finishes the transfer."""
         self._send_chunk(Chunk(transfer_id=self.id, status=error.value))
@@ -107,11 +141,14 @@
         send_chunk: Callable[[Chunk], None],
         end_transfer: Callable[[_Transfer], None],
         response_timeout_s: float,
+        progress_callback: ProgressCallback = None,
     ):
         super().__init__(transfer_id, data, send_chunk, end_transfer,
-                         _Timer(lambda: self.finish(Status.DEADLINE_EXCEEDED)))
+                         _Timer(lambda: self.finish(Status.DEADLINE_EXCEEDED)),
+                         progress_callback)
 
         self._offset = 0
+        self._bytes_acknowledged = 0
         self._response_timeout_s = response_timeout_s
 
         self._max_bytes_to_send = 0
@@ -140,6 +177,7 @@
                        self.id, chunk.offset, self._offset)
 
         self._offset = chunk.offset
+        self._bytes_acknowledged = chunk.offset
 
         if self._offset > len(self.data):
             # Bad offset; terminate the transfer.
@@ -170,11 +208,23 @@
 
             self._send_chunk(write_chunk)
 
+            # Update the caller with the transfer server's current progress.
+            self._invoke_progress_callback()
+
             if self._chunk_delay_us:
                 await asyncio.sleep(self._chunk_delay_us / 1e6)
 
+        # Assume that the server will acknowledge these bytes. If it doesn't,
+        # this will be reset on the next chunk.
+        self._bytes_acknowledged = self._offset
+
         self._response_timer.start(self._response_timeout_s)
 
+    def progress_stats(self) -> ProgressStats:
+        return ProgressStats(total_size_bytes=len(self.data),
+                             bytes_confirmed_received=self._bytes_acknowledged,
+                             bytes_sent=self._offset)
+
     def _next_chunk(self) -> Chunk:
         """Returns the next Chunk message to send in the data transfer."""
         chunk = Chunk(transfer_id=self.id, offset=self._offset)
@@ -205,9 +255,10 @@
                  max_retries: int = 3,
                  max_bytes_to_receive: int = 8192,
                  max_chunk_size: int = 1024,
-                 chunk_delay_us: int = None):
+                 chunk_delay_us: int = None,
+                 progress_callback: ProgressCallback = None):
         super().__init__(transfer_id, bytes(), send_chunk, end_transfer,
-                         _Timer(self._on_timeout))
+                         _Timer(self._on_timeout), progress_callback)
 
         self._response_timeout_s = response_timeout_s
         self._chunk_timeout_count = 0
@@ -247,6 +298,8 @@
         self._pending_bytes -= len(chunk.data)
         self._offset += len(chunk.data)
 
+        self._invoke_progress_callback()
+
         if chunk.HasField('remaining_bytes'):
             if chunk.remaining_bytes == 0:
                 # No more data to read. Acknowledge receipt and finish.
@@ -268,6 +321,12 @@
             # Wait for the next pending chunk.
             self._response_timer.start(self._response_timeout_s)
 
+    def progress_stats(self) -> ProgressStats:
+        # TODO(frolv): Get the transfer size from the server if it is known.
+        return ProgressStats(total_size_bytes=None,
+                             bytes_confirmed_received=self._offset,
+                             bytes_sent=self._offset)
+
     def _send_transfer_parameters(self):
         """Sends an updated transfer parameters chunk to the server."""
 
@@ -352,7 +411,9 @@
             self._loop.call_soon_threadsafe(self._quit_event.set)
             self._thread.join()
 
-    def read(self, transfer_id: int) -> bytes:
+    def read(self,
+             transfer_id: int,
+             progress_callback: ProgressCallback = None) -> bytes:
         """Receives ("downloads") data from the server.
 
         Raises:
@@ -362,9 +423,11 @@
         if transfer_id in self._read_transfers:
             raise ValueError(f'Read transfer {transfer_id} already exists')
 
-        transfer = _ReadTransfer(transfer_id, self._send_read_chunk,
+        transfer = _ReadTransfer(transfer_id,
+                                 self._send_read_chunk,
                                  self._end_read_transfer,
-                                 self._default_response_timeout_s)
+                                 self._default_response_timeout_s,
+                                 progress_callback=progress_callback)
         self._start_read_transfer(transfer)
 
         transfer.done.wait()
@@ -374,9 +437,19 @@
 
         return transfer.data
 
-    def write(self, transfer_id: int, data: bytes) -> None:
+    def write(self,
+              transfer_id: int,
+              data: bytes,
+              progress_callback: ProgressCallback = None) -> None:
         """Transmits ("uploads") data to the server.
 
+        Args:
+          transfer_id: ID of the write transfer
+          data: Data to send to the server.
+          progress_callback: Optional callback periodically invoked throughout
+              the transfer with the transfer state. Can be used to provide user-
+              facing status updates such as progress bars.
+
         Raises:
           Error: the transfer failed to complete
         """
@@ -384,9 +457,12 @@
         if transfer_id in self._write_transfers:
             raise ValueError(f'Write transfer {transfer_id} already exists')
 
-        transfer = _WriteTransfer(transfer_id, data, self._send_write_chunk,
+        transfer = _WriteTransfer(transfer_id,
+                                  data,
+                                  self._send_write_chunk,
                                   self._end_write_transfer,
-                                  self._default_response_timeout_s)
+                                  self._default_response_timeout_s,
+                                  progress_callback=progress_callback)
         self._start_write_transfer(transfer)
 
         transfer.done.wait()
diff --git a/pw_transfer/py/tests/transfer_test.py b/pw_transfer/py/tests/transfer_test.py
index e83a950..2d04df3 100644
--- a/pw_transfer/py/tests/transfer_test.py
+++ b/pw_transfer/py/tests/transfer_test.py
@@ -126,6 +126,34 @@
         self.assertTrue(self._sent_chunks[-1].HasField('status'))
         self.assertEqual(self._sent_chunks[-1].status, 0)
 
+    def test_read_transfer_progress_callback(self):
+        manager = transfer.Manager(self._service,
+                                   default_response_timeout_s=0.3)
+
+        self._enqueue_server_responses(
+            _Method.READ,
+            ((
+                transfer_pb2.Chunk(
+                    transfer_id=3, offset=0, data=b'abc', remaining_bytes=3),
+                transfer_pb2.Chunk(
+                    transfer_id=3, offset=3, data=b'def', remaining_bytes=0),
+            ), ),
+        )
+
+        progress = []
+
+        data = manager.read(3, progress.append)
+        self.assertEqual(data, b'abcdef')
+        self.assertEqual(len(self._sent_chunks), 2)
+        self.assertTrue(self._sent_chunks[-1].HasField('status'))
+        self.assertEqual(self._sent_chunks[-1].status, 0)
+        self.assertEqual(len(progress), 3)
+        self.assertEqual(progress, [
+            transfer.ProgressStats(None, 3, 3),
+            transfer.ProgressStats(None, 6, 6),
+            transfer.ProgressStats(None, 6, 6)
+        ])
+
     def test_read_transfer_retry_bad_offset(self):
         """Server responds with an unexpected offset in a read transfer."""
         manager = transfer.Manager(self._service,
@@ -299,6 +327,39 @@
         self.assertEqual(self._sent_chunks[1].data, b'data to ')
         self.assertEqual(self._sent_chunks[2].data, b'write')
 
+    def test_write_transfer_progress_callback(self):
+        manager = transfer.Manager(self._service,
+                                   default_response_timeout_s=0.3)
+
+        self._enqueue_server_responses(
+            _Method.WRITE,
+            (
+                (transfer_pb2.Chunk(transfer_id=4,
+                                    offset=0,
+                                    pending_bytes=8,
+                                    max_chunk_size_bytes=8), ),
+                (transfer_pb2.Chunk(transfer_id=4,
+                                    offset=8,
+                                    pending_bytes=8,
+                                    max_chunk_size_bytes=8), ),
+                (transfer_pb2.Chunk(transfer_id=4, status=Status.OK.value), ),
+            ),
+        )
+
+        progress = []
+
+        manager.write(4, b'data to write', progress.append)
+        self.assertEqual(len(self._sent_chunks), 3)
+        self.assertEqual(self._received_data(), b'data to write')
+        self.assertEqual(self._sent_chunks[1].data, b'data to ')
+        self.assertEqual(self._sent_chunks[2].data, b'write')
+        self.assertEqual(len(progress), 3)
+        self.assertEqual(progress, [
+            transfer.ProgressStats(13, 0, 8),
+            transfer.ProgressStats(13, 8, 13),
+            transfer.ProgressStats(13, 13, 13)
+        ])
+
     def test_write_transfer_rewind(self):
         """Write transfer in which the server re-requests an earlier offset."""
         manager = transfer.Manager(self._service,