pw_transfer: Refactor Python client as a state machine

This rewrites parts of the Python transfer client to function as a state
machine, similar to the C++ and Java clients. This is done in
preparation for supporting the v2 transfer protocol with opening and
closing handshake phases.

Change-Id: Ic9ea2c905425f452d8d7694c66b9bc4eac5cfe40
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/108062
Commit-Queue: Alexei Frolov <frolv@google.com>
Reviewed-by: Wyatt Hepler <hepler@google.com>
diff --git a/pw_transfer/py/BUILD.gn b/pw_transfer/py/BUILD.gn
index 95ed4db..d547e20 100644
--- a/pw_transfer/py/BUILD.gn
+++ b/pw_transfer/py/BUILD.gn
@@ -22,7 +22,7 @@
   generate_setup = {
     metadata = {
       name = "pw_transfer"
-      version = "0.0.1"
+      version = "0.1.0"
     }
   }
   sources = [
diff --git a/pw_transfer/py/pw_transfer/transfer.py b/pw_transfer/py/pw_transfer/transfer.py
index 22e650e..dfddd12 100644
--- a/pw_transfer/py/pw_transfer/transfer.py
+++ b/pw_transfer/py/pw_transfer/transfer.py
@@ -16,6 +16,7 @@
 import abc
 import asyncio
 from dataclasses import dataclass
+import enum
 import logging
 import math
 import threading
@@ -90,6 +91,31 @@
     of transfer, receiving messages from the server and sending the appropriate
     messages in response.
     """
+    class _State(enum.Enum):
+        # Transfer is starting. The server and client are performing an initial
+        # handshake and negotiating protocol and feature flags.
+        INITIATING = 0
+
+        # Waiting for the other end to send a chunk.
+        WAITING = 1
+
+        # Transmitting a window of data to a receiver.
+        TRANSMITTING = 2
+
+        # Recovering after one or more chunks was dropped in an active transfer.
+        RECOVERY = 3
+
+        # Transfer has completed locally and is waiting for the peer to
+        # acknowledge its final status. Only entered by the terminating side of
+        # the transfer.
+        #
+        # The context remains in a TERMINATING state until it receives an
+        # acknowledgement from the peer or times out.
+        TERMINATING = 4
+
+        # A transfer has fully completed.
+        COMPLETE = 5
+
     def __init__(self,
                  session_id: int,
                  send_chunk: Callable[[transfer_pb2.Chunk], None],
@@ -105,6 +131,10 @@
         self._send_chunk = send_chunk
         self._end_transfer = end_transfer
 
+        # TODO(b/243679452): Only the legacy protocol begins in a WAITING state.
+        # Newer protocols should start as INITIATING.
+        self._state = Transfer._State.WAITING
+
         self._retries = 0
         self._max_retries = max_retries
         self._response_timer = _Timer(response_timeout_s, self._on_timeout)
@@ -158,7 +188,7 @@
 
     def _on_timeout(self) -> None:
         """Handles a timeout while waiting for a chunk."""
-        if self.done.is_set():
+        if self._state is Transfer._State.COMPLETE:
             return
 
         self._retries += 1
@@ -185,6 +215,7 @@
             self._end_transfer(self)
 
         # Set done last so that the transfer has been fully cleaned up.
+        self._state = Transfer._State.COMPLETE
         self.done.set()
 
     def _update_progress(self, bytes_sent: int, bytes_confirmed_received: int,
@@ -225,9 +256,6 @@
                          max_retries, progress_callback)
         self._data = data
 
-        # Guard this class with a lock since a transfer parameters update might
-        # arrive while responding to a prior update.
-        self._lock = asyncio.Lock()
         self._offset = 0
         self._window_end_offset = 0
         self._max_chunk_size = 0
@@ -236,6 +264,7 @@
         # The window ID increments for each parameters update.
         self._window_id = 0
 
+        self._bytes_confirmed_received = 0
         self._last_chunk = self._initial_chunk()
 
     @property
@@ -243,8 +272,8 @@
         return self._data
 
     def _initial_chunk(self) -> transfer_pb2.Chunk:
-        # TODO(frolv): session_id should not be set here, but assigned by the
-        # server during an initial handshake.
+        # TODO(b/243679452): In protocol v2, session_id should be assigned by
+        # the server during an initial handshake.
         return transfer_pb2.Chunk(transfer_id=self.id,
                                   resource_id=self.id,
                                   type=transfer_pb2.Chunk.Type.START)
@@ -257,40 +286,59 @@
         send data accordingly.
         """
 
-        async with self._lock:
-            self._window_id += 1
-            window_id = self._window_id
+        if self._state is Transfer._State.TERMINATING:
+            # TODO(b/243679452): Handle termination in protocol v2.
+            return
 
-            if not self._handle_parameters_update(chunk):
-                return
+        if self._state is Transfer._State.TRANSMITTING:
+            self._state = Transfer._State.WAITING
 
-            bytes_acknowledged = chunk.offset
+        assert self._state is Transfer._State.WAITING
 
-        while True:
-            if self._chunk_delay_us:
-                await asyncio.sleep(self._chunk_delay_us / 1e6)
+        if not self._handle_parameters_update(chunk):
+            return
 
-            async with self._lock:
-                if self.done.is_set():
-                    return
+        self._bytes_confirmed_received = chunk.offset
+        self._state = Transfer._State.TRANSMITTING
 
-                if window_id != self._window_id:
-                    _LOG.debug('Transfer %d: Skipping stale window', self.id)
-                    return
+        self._window_id += 1
+        asyncio.create_task(self._transmit_next_chunk(self._window_id))
 
-                write_chunk = self._next_chunk()
-                self._offset += len(write_chunk.data)
-                sent_requested_bytes = self._offset == self._window_end_offset
+    async def _transmit_next_chunk(self,
+                                   window_id: int,
+                                   timeout_us: int = None) -> None:
+        """Transmits a single data chunk to the server.
 
-            self._send_chunk(write_chunk)
+        If the chunk completes the active window, returns to a WAITING state.
+        Otherwise, schedules another transmission for the next chunk.
+        """
+        if timeout_us is not None:
+            await asyncio.sleep(timeout_us / 1e6)
 
-            self._update_progress(self._offset, bytes_acknowledged,
-                                  len(self.data))
+        if self._state is not Transfer._State.TRANSMITTING:
+            return
 
-            if sent_requested_bytes:
-                break
+        if window_id != self._window_id:
+            _LOG.debug('Transfer %d: Skipping stale window', self.id)
+            return
 
-        self._last_chunk = write_chunk
+        chunk = self._next_chunk()
+        self._offset += len(chunk.data)
+
+        sent_requested_bytes = self._offset == self._window_end_offset
+
+        self._send_chunk(chunk)
+        self._update_progress(self._offset, self._bytes_confirmed_received,
+                              len(self.data))
+
+        if sent_requested_bytes:
+            self._state = Transfer._State.WAITING
+        else:
+            asyncio.create_task(
+                self._transmit_next_chunk(window_id,
+                                          timeout_us=self._chunk_delay_us))
+
+        self._last_chunk = chunk
 
     def _handle_parameters_update(self, chunk: transfer_pb2.Chunk) -> bool:
         """Updates transfer state based on a transfer parameters update."""
@@ -348,7 +396,10 @@
         return True
 
     def _retry_after_timeout(self) -> None:
-        self._send_chunk(self._last_chunk)
+        if self._state is Transfer._State.WAITING:
+            self._send_chunk(self._last_chunk)
+
+        # TODO(b/243679452): Handle other states in protocol v2.
 
     def _next_chunk(self) -> transfer_pb2.Chunk:
         """Returns the next Chunk message to send in the data transfer."""
@@ -408,6 +459,7 @@
         self._offset = 0
         self._pending_bytes = max_bytes_to_receive
         self._window_end_offset = max_bytes_to_receive
+        self._last_chunk_offset: Optional[int] = None
 
     @property
     def data(self) -> bytes:
@@ -424,10 +476,43 @@
         Once all pending data is received, the transfer parameters are updated.
         """
 
+        if self._state is Transfer._State.TERMINATING:
+            # TODO(b/243679452): Handle terminating states in protocol v2.
+            return
+
+        if self._state is Transfer._State.RECOVERY:
+            if chunk.offset != self._offset:
+                if self._last_chunk_offset == chunk.offset:
+                    _LOG.debug(
+                        'Transfer %d received repeated offset %d: '
+                        'retry detected, resending transfer parameters',
+                        self.id, chunk.offset)
+                    self._send_chunk(
+                        self._transfer_parameters(
+                            transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
+                else:
+                    _LOG.debug(
+                        'Transfer %d waiting for offset %d, ignoring %d',
+                        self.id, self._offset, chunk.offset)
+                    self._last_chunk_offset = chunk.offset
+                return
+
+            _LOG.info(
+                'Transfer %d received expected offset %d, resuming transfer',
+                self.id, chunk.offset)
+            self._state = Transfer._State.WAITING
+
+        assert self._state is Transfer._State.WAITING
+
         if chunk.offset != self._offset:
             # Initially, the transfer service only supports in-order transfers.
             # If data is received out of order, request that the server
             # retransmit from the previous offset.
+            _LOG.debug(
+                'Transfer %d expected offset %d, received %d: '
+                'entering recovery state', self.id, self._offset, chunk.offset)
+            self._state = Transfer._State.RECOVERY
+
             self._send_chunk(
                 self._transfer_parameters(
                     transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
@@ -437,6 +522,9 @@
         self._pending_bytes -= len(chunk.data)
         self._offset += len(chunk.data)
 
+        # Update the last offset seen so that retries can be detected.
+        self._last_chunk_offset = chunk.offset
+
         if chunk.HasField('remaining_bytes'):
             if chunk.remaining_bytes == 0:
                 # No more data to read. Acknowledge receipt and finish.
@@ -498,9 +586,13 @@
                     transfer_pb2.Chunk.Type.PARAMETERS_CONTINUE))
 
     def _retry_after_timeout(self) -> None:
-        self._send_chunk(
-            self._transfer_parameters(
-                transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
+        if (self._state is Transfer._State.WAITING
+                or self._state is Transfer._State.RECOVERY):
+            self._send_chunk(
+                self._transfer_parameters(
+                    transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
+
+        # TODO(b/243679452): Handle other states in protocol v2.
 
     def _transfer_parameters(self, chunk_type: Any) -> transfer_pb2.Chunk:
         """Sends an updated transfer parameters chunk to the server."""
diff --git a/pw_transfer/py/tests/transfer_test.py b/pw_transfer/py/tests/transfer_test.py
index 63a73ee..1af8c35 100644
--- a/pw_transfer/py/tests/transfer_test.py
+++ b/pw_transfer/py/tests/transfer_test.py
@@ -200,6 +200,68 @@
         self.assertTrue(self._sent_chunks[-1].HasField('status'))
         self.assertEqual(self._sent_chunks[-1].status, 0)
 
+    def test_read_transfer_recovery_sends_parameters_on_retry(self) -> None:
+        """Server sends the same chunk twice (retry) in a read transfer."""
+        manager = pw_transfer.Manager(
+            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
+
+        self._enqueue_server_responses(
+            _Method.READ,
+            (
+                (
+                    # Bad offset, enter recovery state. Only one parameters
+                    # chunk should be sent.
+                    transfer_pb2.Chunk(transfer_id=3,
+                                       offset=1,
+                                       data=b'234',
+                                       remaining_bytes=5),
+                    transfer_pb2.Chunk(transfer_id=3,
+                                       offset=4,
+                                       data=b'567',
+                                       remaining_bytes=2),
+                    transfer_pb2.Chunk(
+                        transfer_id=3, offset=7, data=b'8', remaining_bytes=1),
+                ),
+                (
+                    # Only one parameters chunk should be sent after the server
+                    # retries the same offset twice.
+                    transfer_pb2.Chunk(transfer_id=3,
+                                       offset=1,
+                                       data=b'234',
+                                       remaining_bytes=5),
+                    transfer_pb2.Chunk(transfer_id=3,
+                                       offset=4,
+                                       data=b'567',
+                                       remaining_bytes=2),
+                    transfer_pb2.Chunk(
+                        transfer_id=3,
+                        offset=7, data=b'8', remaining_bytes=1),
+                    transfer_pb2.Chunk(
+                        transfer_id=3,
+                        offset=7, data=b'8', remaining_bytes=1),
+                ),
+                (transfer_pb2.Chunk(transfer_id=3,
+                                    offset=0,
+                                    data=b'123456789',
+                                    remaining_bytes=0), ),
+            ))
+
+        data = manager.read(3)
+        self.assertEqual(data, b'123456789')
+
+        self.assertEqual(len(self._sent_chunks), 4)
+        self.assertEqual(self._sent_chunks[0].type,
+                         transfer_pb2.Chunk.Type.START)
+        self.assertEqual(self._sent_chunks[0].offset, 0)
+        self.assertEqual(self._sent_chunks[1].type,
+                         transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT)
+        self.assertEqual(self._sent_chunks[1].offset, 0)
+        self.assertEqual(self._sent_chunks[2].type,
+                         transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT)
+        self.assertEqual(self._sent_chunks[2].offset, 0)
+        self.assertEqual(self._sent_chunks[3].type,
+                         transfer_pb2.Chunk.Type.COMPLETION)
+
     def test_read_transfer_retry_timeout(self) -> None:
         """Server doesn't respond to read transfer parameters."""
         manager = pw_transfer.Manager(