pw_transfer: Refactor Python client as a state machine
This rewrites parts of the Python transfer client to function as a state
machine, similar to the C++ and Java clients. This is done in
preparation for supporting the v2 transfer protocol with opening and
closing handshake phases.
Change-Id: Ic9ea2c905425f452d8d7694c66b9bc4eac5cfe40
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/108062
Commit-Queue: Alexei Frolov <frolv@google.com>
Reviewed-by: Wyatt Hepler <hepler@google.com>
diff --git a/pw_transfer/py/BUILD.gn b/pw_transfer/py/BUILD.gn
index 95ed4db..d547e20 100644
--- a/pw_transfer/py/BUILD.gn
+++ b/pw_transfer/py/BUILD.gn
@@ -22,7 +22,7 @@
generate_setup = {
metadata = {
name = "pw_transfer"
- version = "0.0.1"
+ version = "0.1.0"
}
}
sources = [
diff --git a/pw_transfer/py/pw_transfer/transfer.py b/pw_transfer/py/pw_transfer/transfer.py
index 22e650e..dfddd12 100644
--- a/pw_transfer/py/pw_transfer/transfer.py
+++ b/pw_transfer/py/pw_transfer/transfer.py
@@ -16,6 +16,7 @@
import abc
import asyncio
from dataclasses import dataclass
+import enum
import logging
import math
import threading
@@ -90,6 +91,31 @@
of transfer, receiving messages from the server and sending the appropriate
messages in response.
"""
+ class _State(enum.Enum):
+ # Transfer is starting. The server and client are performing an initial
+ # handshake and negotiating protocol and feature flags.
+ INITIATING = 0
+
+ # Waiting for the other end to send a chunk.
+ WAITING = 1
+
+ # Transmitting a window of data to a receiver.
+ TRANSMITTING = 2
+
+ # Recovering after one or more chunks was dropped in an active transfer.
+ RECOVERY = 3
+
+ # Transfer has completed locally and is waiting for the peer to
+ # acknowledge its final status. Only entered by the terminating side of
+ # the transfer.
+ #
+ # The context remains in a TERMINATING state until it receives an
+ # acknowledgement from the peer or times out.
+ TERMINATING = 4
+
+ # A transfer has fully completed.
+ COMPLETE = 5
+
def __init__(self,
session_id: int,
send_chunk: Callable[[transfer_pb2.Chunk], None],
@@ -105,6 +131,10 @@
self._send_chunk = send_chunk
self._end_transfer = end_transfer
+ # TODO(b/243679452): Only the legacy protocol begins in a WAITING state.
+ # Newer protocols should start as INITIATING.
+ self._state = Transfer._State.WAITING
+
self._retries = 0
self._max_retries = max_retries
self._response_timer = _Timer(response_timeout_s, self._on_timeout)
@@ -158,7 +188,7 @@
def _on_timeout(self) -> None:
"""Handles a timeout while waiting for a chunk."""
- if self.done.is_set():
+ if self._state is Transfer._State.COMPLETE:
return
self._retries += 1
@@ -185,6 +215,7 @@
self._end_transfer(self)
# Set done last so that the transfer has been fully cleaned up.
+ self._state = Transfer._State.COMPLETE
self.done.set()
def _update_progress(self, bytes_sent: int, bytes_confirmed_received: int,
@@ -225,9 +256,6 @@
max_retries, progress_callback)
self._data = data
- # Guard this class with a lock since a transfer parameters update might
- # arrive while responding to a prior update.
- self._lock = asyncio.Lock()
self._offset = 0
self._window_end_offset = 0
self._max_chunk_size = 0
@@ -236,6 +264,7 @@
# The window ID increments for each parameters update.
self._window_id = 0
+ self._bytes_confirmed_received = 0
self._last_chunk = self._initial_chunk()
@property
@@ -243,8 +272,8 @@
return self._data
def _initial_chunk(self) -> transfer_pb2.Chunk:
- # TODO(frolv): session_id should not be set here, but assigned by the
- # server during an initial handshake.
+ # TODO(b/243679452): In protocol v2, session_id should be assigned by
+ # the server during an initial handshake.
return transfer_pb2.Chunk(transfer_id=self.id,
resource_id=self.id,
type=transfer_pb2.Chunk.Type.START)
@@ -257,40 +286,59 @@
send data accordingly.
"""
- async with self._lock:
- self._window_id += 1
- window_id = self._window_id
+ if self._state is Transfer._State.TERMINATING:
+ # TODO(b/243679452): Handle termination in protocol v2.
+ return
- if not self._handle_parameters_update(chunk):
- return
+ if self._state is Transfer._State.TRANSMITTING:
+ self._state = Transfer._State.WAITING
- bytes_acknowledged = chunk.offset
+ assert self._state is Transfer._State.WAITING
- while True:
- if self._chunk_delay_us:
- await asyncio.sleep(self._chunk_delay_us / 1e6)
+ if not self._handle_parameters_update(chunk):
+ return
- async with self._lock:
- if self.done.is_set():
- return
+ self._bytes_confirmed_received = chunk.offset
+ self._state = Transfer._State.TRANSMITTING
- if window_id != self._window_id:
- _LOG.debug('Transfer %d: Skipping stale window', self.id)
- return
+ self._window_id += 1
+ asyncio.create_task(self._transmit_next_chunk(self._window_id))
- write_chunk = self._next_chunk()
- self._offset += len(write_chunk.data)
- sent_requested_bytes = self._offset == self._window_end_offset
+ async def _transmit_next_chunk(self,
+ window_id: int,
+ timeout_us: int = None) -> None:
+ """Transmits a single data chunk to the server.
- self._send_chunk(write_chunk)
+ If the chunk completes the active window, returns to a WAITING state.
+ Otherwise, schedules another transmission for the next chunk.
+ """
+ if timeout_us is not None:
+ await asyncio.sleep(timeout_us / 1e6)
- self._update_progress(self._offset, bytes_acknowledged,
- len(self.data))
+ if self._state is not Transfer._State.TRANSMITTING:
+ return
- if sent_requested_bytes:
- break
+ if window_id != self._window_id:
+ _LOG.debug('Transfer %d: Skipping stale window', self.id)
+ return
- self._last_chunk = write_chunk
+ chunk = self._next_chunk()
+ self._offset += len(chunk.data)
+
+ sent_requested_bytes = self._offset == self._window_end_offset
+
+ self._send_chunk(chunk)
+ self._update_progress(self._offset, self._bytes_confirmed_received,
+ len(self.data))
+
+ if sent_requested_bytes:
+ self._state = Transfer._State.WAITING
+ else:
+ asyncio.create_task(
+ self._transmit_next_chunk(window_id,
+ timeout_us=self._chunk_delay_us))
+
+ self._last_chunk = chunk
def _handle_parameters_update(self, chunk: transfer_pb2.Chunk) -> bool:
"""Updates transfer state based on a transfer parameters update."""
@@ -348,7 +396,10 @@
return True
def _retry_after_timeout(self) -> None:
- self._send_chunk(self._last_chunk)
+ if self._state is Transfer._State.WAITING:
+ self._send_chunk(self._last_chunk)
+
+ # TODO(b/243679452): Handle other states in protocol v2.
def _next_chunk(self) -> transfer_pb2.Chunk:
"""Returns the next Chunk message to send in the data transfer."""
@@ -408,6 +459,7 @@
self._offset = 0
self._pending_bytes = max_bytes_to_receive
self._window_end_offset = max_bytes_to_receive
+ self._last_chunk_offset: Optional[int] = None
@property
def data(self) -> bytes:
@@ -424,10 +476,43 @@
Once all pending data is received, the transfer parameters are updated.
"""
+ if self._state is Transfer._State.TERMINATING:
+ # TODO(b/243679452): Handle terminating states in protocol v2.
+ return
+
+ if self._state is Transfer._State.RECOVERY:
+ if chunk.offset != self._offset:
+ if self._last_chunk_offset == chunk.offset:
+ _LOG.debug(
+ 'Transfer %d received repeated offset %d: '
+ 'retry detected, resending transfer parameters',
+ self.id, chunk.offset)
+ self._send_chunk(
+ self._transfer_parameters(
+ transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
+ else:
+ _LOG.debug(
+ 'Transfer %d waiting for offset %d, ignoring %d',
+ self.id, self._offset, chunk.offset)
+ self._last_chunk_offset = chunk.offset
+ return
+
+ _LOG.info(
+ 'Transfer %d received expected offset %d, resuming transfer',
+ self.id, chunk.offset)
+ self._state = Transfer._State.WAITING
+
+ assert self._state is Transfer._State.WAITING
+
if chunk.offset != self._offset:
# Initially, the transfer service only supports in-order transfers.
# If data is received out of order, request that the server
# retransmit from the previous offset.
+ _LOG.debug(
+ 'Transfer %d expected offset %d, received %d: '
+ 'entering recovery state', self.id, self._offset, chunk.offset)
+ self._state = Transfer._State.RECOVERY
+
self._send_chunk(
self._transfer_parameters(
transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
@@ -437,6 +522,9 @@
self._pending_bytes -= len(chunk.data)
self._offset += len(chunk.data)
+ # Update the last offset seen so that retries can be detected.
+ self._last_chunk_offset = chunk.offset
+
if chunk.HasField('remaining_bytes'):
if chunk.remaining_bytes == 0:
# No more data to read. Acknowledge receipt and finish.
@@ -498,9 +586,13 @@
transfer_pb2.Chunk.Type.PARAMETERS_CONTINUE))
def _retry_after_timeout(self) -> None:
- self._send_chunk(
- self._transfer_parameters(
- transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
+ if (self._state is Transfer._State.WAITING
+ or self._state is Transfer._State.RECOVERY):
+ self._send_chunk(
+ self._transfer_parameters(
+ transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
+
+ # TODO(b/243679452): Handle other states in protocol v2.
def _transfer_parameters(self, chunk_type: Any) -> transfer_pb2.Chunk:
"""Sends an updated transfer parameters chunk to the server."""
diff --git a/pw_transfer/py/tests/transfer_test.py b/pw_transfer/py/tests/transfer_test.py
index 63a73ee..1af8c35 100644
--- a/pw_transfer/py/tests/transfer_test.py
+++ b/pw_transfer/py/tests/transfer_test.py
@@ -200,6 +200,68 @@
self.assertTrue(self._sent_chunks[-1].HasField('status'))
self.assertEqual(self._sent_chunks[-1].status, 0)
+ def test_read_transfer_recovery_sends_parameters_on_retry(self) -> None:
+ """Server sends the same chunk twice (retry) in a read transfer."""
+ manager = pw_transfer.Manager(
+ self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
+
+ self._enqueue_server_responses(
+ _Method.READ,
+ (
+ (
+ # Bad offset, enter recovery state. Only one parameters
+ # chunk should be sent.
+ transfer_pb2.Chunk(transfer_id=3,
+ offset=1,
+ data=b'234',
+ remaining_bytes=5),
+ transfer_pb2.Chunk(transfer_id=3,
+ offset=4,
+ data=b'567',
+ remaining_bytes=2),
+ transfer_pb2.Chunk(
+ transfer_id=3, offset=7, data=b'8', remaining_bytes=1),
+ ),
+ (
+ # Only one parameters chunk should be sent after the server
+ # retries the same offset twice.
+ transfer_pb2.Chunk(transfer_id=3,
+ offset=1,
+ data=b'234',
+ remaining_bytes=5),
+ transfer_pb2.Chunk(transfer_id=3,
+ offset=4,
+ data=b'567',
+ remaining_bytes=2),
+ transfer_pb2.Chunk(
+ transfer_id=3,
+ offset=7, data=b'8', remaining_bytes=1),
+ transfer_pb2.Chunk(
+ transfer_id=3,
+ offset=7, data=b'8', remaining_bytes=1),
+ ),
+ (transfer_pb2.Chunk(transfer_id=3,
+ offset=0,
+ data=b'123456789',
+ remaining_bytes=0), ),
+ ))
+
+ data = manager.read(3)
+ self.assertEqual(data, b'123456789')
+
+ self.assertEqual(len(self._sent_chunks), 4)
+ self.assertEqual(self._sent_chunks[0].type,
+ transfer_pb2.Chunk.Type.START)
+ self.assertEqual(self._sent_chunks[0].offset, 0)
+ self.assertEqual(self._sent_chunks[1].type,
+ transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT)
+ self.assertEqual(self._sent_chunks[1].offset, 0)
+ self.assertEqual(self._sent_chunks[2].type,
+ transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT)
+ self.assertEqual(self._sent_chunks[2].offset, 0)
+ self.assertEqual(self._sent_chunks[3].type,
+ transfer_pb2.Chunk.Type.COMPLETION)
+
def test_read_transfer_retry_timeout(self) -> None:
"""Server doesn't respond to read transfer parameters."""
manager = pw_transfer.Manager(