pw_transfer: Initial Python implementation
This implements a basic version of the client-side pw_transfer protocol
in Python (theoretically). Both read and write transfers are supported.
Not all possible error cases are handled yet.
Change-Id: I96d56251951ec5c17fae67bf5eb623d2e3f5b7dc
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/50280
Commit-Queue: Alexei Frolov <frolv@google.com>
Reviewed-by: Wyatt Hepler <hepler@google.com>
diff --git a/pw_env_setup/BUILD.gn b/pw_env_setup/BUILD.gn
index d78567b..ace9db5 100644
--- a/pw_env_setup/BUILD.gn
+++ b/pw_env_setup/BUILD.gn
@@ -52,6 +52,7 @@
"$dir_pw_toolchain/py",
"$dir_pw_trace/py",
"$dir_pw_trace_tokenized/py",
+ "$dir_pw_transfer/py",
"$dir_pw_unit_test/py",
"$dir_pw_watch/py",
diff --git a/pw_transfer/BUILD.gn b/pw_transfer/BUILD.gn
index f7682d1..55fa7b2 100644
--- a/pw_transfer/BUILD.gn
+++ b/pw_transfer/BUILD.gn
@@ -37,6 +37,8 @@
pw_proto_library("proto") {
sources = [ "transfer.proto" ]
+ python_package = "py"
+ prefix = "pw_transfer"
}
pw_test_group("tests") {
diff --git a/pw_transfer/docs.rst b/pw_transfer/docs.rst
index daf559c..193a538 100644
--- a/pw_transfer/docs.rst
+++ b/pw_transfer/docs.rst
@@ -8,6 +8,41 @@
``pw_transfer`` is under construction and so is its documentation.
+-----
+Usage
+-----
+
+Python
+======
+.. automodule:: pw_transfer.transfer
+ :members: Manager, Error
+
+**Example**
+
+.. code-block:: python
+
+ from pw_transfer import transfer
+
+ # Initialize a Pigweed RPC client; see pw_rpc docs for more info.
+ rpc_client = CustomRpcClient()
+ rpcs = rpc_client.channel(1).rpcs
+
+ transfer_service = rpcs.pw.transfer.Transfer
+ transfer_manager = transfer.Manager(transfer_service)
+
+ try:
+ # Read transfer_id 3 from the server.
+ data = transfer_manager.read(3)
+ except transfer.Error as err:
+ print('Failed to read:', err.status)
+
+ try:
+ # Send some data to the server. The transfer manager does not have to be
+ # reinitialized.
+ transfer_manager.write(2, b'hello, world')
+ except transfer.Error as err:
+ print('Failed to write:', err.status)
+
--------
Protocol
--------
diff --git a/pw_transfer/py/BUILD.gn b/pw_transfer/py/BUILD.gn
new file mode 100644
index 0000000..a5bf8ba
--- /dev/null
+++ b/pw_transfer/py/BUILD.gn
@@ -0,0 +1,34 @@
+# Copyright 2021 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+
+import("//build_overrides/pigweed.gni")
+
+import("$dir_pw_build/python.gni")
+import("$dir_pw_docgen/docs.gni")
+
+pw_python_package("py") {
+ generate_setup = {
+ name = "pw_transfer"
+ version = "0.0.1"
+ }
+ sources = [ "pw_transfer/transfer.py" ]
+ tests = [ "transfer_test.py" ]
+ python_deps = [
+ "$dir_pw_rpc/py",
+ "$dir_pw_status/py",
+ ]
+ python_test_deps = [ "$dir_pw_build/py" ]
+ pylintrc = "$dir_pigweed/.pylintrc"
+ proto_library = "..:proto"
+}
diff --git a/pw_transfer/py/pw_transfer/transfer.py b/pw_transfer/py/pw_transfer/transfer.py
new file mode 100644
index 0000000..46e3f7b
--- /dev/null
+++ b/pw_transfer/py/pw_transfer/transfer.py
@@ -0,0 +1,549 @@
+# Copyright 2021 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Provides a simple interface for transferring bulk data over pw_rpc."""
+
+import abc
+import asyncio
+import logging
+import threading
+from typing import Any, Callable, Dict, Optional
+
+from pw_rpc.callback_client import BidirectionalStream
+from pw_status import Status
+from pw_transfer.transfer_pb2 import Chunk
+
+_LOG = logging.getLogger(__name__)
+
+
+class _Timer:
+ """A timer which invokes a callback after a certain timeout."""
+ def __init__(self, callback: Callable[[], Any]):
+ self._callback = callback
+ self._task: Optional[asyncio.Task[Any]] = None
+
+ def start(self, timeout_s: float):
+ """Starts a new timer.
+
+ If a timer is already running, it is stopped and a new timer started.
+ This can be used to implement watchdog-like behavior, where a callback
+ is invoked after some time without a kick.
+ """
+ self.stop()
+ self._task = asyncio.create_task(self._run(timeout_s))
+
+ def stop(self):
+ """Terminates a running timer."""
+ if self._task is not None:
+ self._task.cancel()
+ self._task = None
+
+ async def _run(self, timeout_s: float):
+ await asyncio.sleep(timeout_s)
+ self._callback()
+ self._task = None
+
+
+class _Transfer(abc.ABC):
+ """A client-side data transfer through a Manager.
+
+ Subclasses are responsible for implementing all of the logic for their type
+ of transfer, receiving messages from the server and sending the appropriate
+ messages in response.
+ """
+ def __init__(self, transfer_id: int, data: bytes,
+ send_chunk: Callable[[Chunk], None],
+ end_transfer: Callable[['_Transfer', Status], None]):
+ self.id = transfer_id
+ self.status = Status.OK
+ self.data = data
+ self.done = threading.Event()
+ self._send_chunk = send_chunk
+ self._end_transfer = end_transfer
+
+ @abc.abstractmethod
+ async def begin(self) -> None:
+ """Sends the initial chunk to notify the sever of the transfer."""
+
+ @abc.abstractmethod
+ async def handle_chunk(self, chunk: Chunk) -> bool:
+ """Processes an incoming chunk from the server.
+
+ Returns True if the transfer is complete or False otherwise.
+
+ Only called for non-terminating chunks (i.e. those without a status).
+ """
+
+ def _send_error(self, error: Status) -> None:
+ """Sends an error chunk to the server. Should only be called once."""
+ self.status = error
+ self._send_chunk(Chunk(transfer_id=self.id, status=error.value))
+
+
+class _WriteTransfer(_Transfer):
+ """A client -> server write transfer."""
+ def __init__(
+ self,
+ transfer_id: int,
+ data: bytes,
+ send_chunk: Callable[[Chunk], None],
+ end_transfer: Callable[[_Transfer, Status], None],
+ response_timeout_s: float,
+ ):
+ super().__init__(transfer_id, data, send_chunk, end_transfer)
+ self._offset = 0
+ self._response_timeout_s = response_timeout_s
+ self._response_timer = _Timer(
+ lambda: self._end_transfer(self, Status.DEADLINE_EXCEEDED))
+
+ self._max_bytes_to_send = 0
+ self._max_chunk_size = 0
+ self._chunk_delay_us: Optional[int] = None
+
+ async def begin(self) -> None:
+ """Sends the transfer ID, notifying the server that we want to write."""
+ self._send_chunk(Chunk(transfer_id=self.id))
+ self._response_timer.start(self._response_timeout_s)
+
+ async def handle_chunk(self, chunk: Chunk) -> bool:
+ """Processes an incoming chunk from the server.
+
+ In a write transfer, the server only sends transfer parameter updates
+ to the client. When a message is received, update local parameters and
+ send data accordingly.
+ """
+
+ self._response_timer.stop()
+
+ # Check whether the client has sent a previous data offset, which
+ # indicates that some chunks were lost in transmission.
+ if chunk.offset < self._offset:
+ _LOG.debug('Write transfer %d rolling back to offset %d from %d',
+ self.id, chunk.offset, self._offset)
+
+ self._offset = chunk.offset
+
+ if self._offset > len(self.data):
+ # Bad offset; terminate the transfer.
+ _LOG.error(
+ 'Transfer %d: server requested invalid offset %d (size %d)',
+ self.id, self._offset, len(self.data))
+
+ self._send_error(Status.OUT_OF_RANGE)
+ return True
+
+ self._max_bytes_to_send = min(chunk.pending_bytes,
+ len(self.data) - self._offset)
+
+ if chunk.HasField('max_chunk_size_bytes'):
+ self._max_chunk_size = chunk.max_chunk_size_bytes
+ if chunk.HasField('min_delay_microseconds'):
+ self._chunk_delay_us = chunk.min_delay_microseconds
+
+ while self._max_bytes_to_send > 0:
+ write_chunk = self._next_chunk()
+ self._offset += len(write_chunk.data)
+ self._max_bytes_to_send -= len(write_chunk.data)
+
+ self._send_chunk(write_chunk)
+
+ if self._chunk_delay_us:
+ await asyncio.sleep(self._chunk_delay_us / 1e6)
+
+ self._response_timer.start(self._response_timeout_s)
+ return False
+
+ def _next_chunk(self) -> Chunk:
+ """Returns the next Chunk message to send in the data transfer."""
+ chunk = Chunk(transfer_id=self.id, offset=self._offset)
+
+ if len(self.data) - self._offset <= self._max_chunk_size:
+ # Final chunk of the transfer.
+ chunk.data = self.data[self._offset:]
+ chunk.remaining_bytes = 0
+ else:
+ chunk.data = self.data[self._offset:self._offset +
+ self._max_chunk_size]
+
+ return chunk
+
+
+class _ReadTransfer(_Transfer):
+ """A client <- server read transfer.
+
+ Although Python can effectively handle an unlimited transfer window, this
+ client sets a conservative window and chunk size to avoid overloading the
+ device. These are configurable in the constructor.
+ """
+ def __init__(self,
+ transfer_id: int,
+ send_chunk: Callable[[Chunk], None],
+ end_transfer: Callable[[_Transfer, Status], None],
+ response_timeout_s: float,
+ max_retries: int = 3,
+ max_bytes_to_receive: int = 8192,
+ max_chunk_size: int = 1024,
+ chunk_delay_us: int = None):
+ super().__init__(transfer_id, bytes(), send_chunk, end_transfer)
+
+ self._response_timeout_s = response_timeout_s
+ self._response_timer = _Timer(self._on_timeout)
+ self._chunk_timeout_count = 0
+ self._max_retries = max_retries
+
+ self._max_bytes_to_receive = max_bytes_to_receive
+ self._max_chunk_size = max_chunk_size
+ self._chunk_delay_us = chunk_delay_us
+
+ self._remaining_transfer_size: Optional[int] = None
+ self._offset = 0
+ self._pending_bytes = max_bytes_to_receive
+
+ async def begin(self) -> None:
+ """Sends the initial transfer parameters for the read transfer."""
+ self._send_transfer_parameters()
+
+ async def handle_chunk(self, chunk: Chunk) -> bool:
+ """Processes an incoming chunk from the server.
+
+ In a read transfer, the client receives data chunks from the server.
+ Once all pending data is received, the transfer parameters are updated.
+ """
+
+ self._response_timer.stop()
+ self._chunk_timeout_count = 0
+
+ if chunk.offset != self._offset:
+ # Initially, the transfer service only supports in-order transfers.
+ # If data is received out of order, request that the server
+ # retransmit from the previous offset.
+ self._pending_bytes = 0
+ self._send_transfer_parameters()
+ return False
+
+ self.data += chunk.data
+ self._pending_bytes -= len(chunk.data)
+ self._offset += len(chunk.data)
+
+ if chunk.HasField('remaining_bytes'):
+ if chunk.remaining_bytes == 0:
+ # No more data to read. Acknowledge receipt and finish.
+ self._send_chunk(
+ Chunk(transfer_id=self.id, status=Status.OK.value))
+ return True
+
+ # The server may optionally indicate that it has a known size of
+ # data available. This is not yet used.
+ self._remaining_transfer_size = chunk.remaining_bytes
+
+ if self._pending_bytes == 0:
+ # All pending data was received. Send out a new parameters chunk for
+ # the next block.
+ # _send_transfer_parameters starts the response timer.
+ self._send_transfer_parameters()
+ else:
+ # Wait for the next pending chunk.
+ self._response_timer.start(self._response_timeout_s)
+
+ return False
+
+ def _send_transfer_parameters(self):
+ """Sends an updated transfer parameters chunk to the server."""
+
+ self._pending_bytes = self._max_bytes_to_receive
+
+ chunk = Chunk(transfer_id=self.id,
+ pending_bytes=self._pending_bytes,
+ max_chunk_size_bytes=self._max_chunk_size,
+ offset=self._offset)
+
+ if self._chunk_delay_us:
+ chunk.min_delay_microseconds = self._chunk_delay_us
+
+ self._send_chunk(chunk)
+
+ # Start the timeout for the server to send a read chunk.
+ self._response_timer.start(self._response_timeout_s)
+
+ def _on_timeout(self) -> None:
+ """Handles a timeout while waiting for a chunk.
+
+ In a read transfer, if a chunk is not received by the timeout, the
+ receiver resends the latest transfer parameters to prompt the server
+ for more data. This is done several times, up to a specified number of
+ retries, after which the transfer is terminated.
+ """
+ self._chunk_timeout_count += 1
+
+ if self._chunk_timeout_count > self._max_retries:
+ self._end_transfer(self, Status.DEADLINE_EXCEEDED)
+ else:
+ self._send_transfer_parameters()
+
+
+_TransferDict = Dict[int, _Transfer]
+
+
+class Manager: # pylint: disable=too-many-instance-attributes
+ """A manager for transmitting data through an RPC TransferService.
+
+ This should be initialized with an active Manager over an RPC channel. Only
+ one instance of this class should exist for a configured RPC TransferService
+ -- the Manager supports multiple simultaneous transfers.
+
+ When created, a Manager starts a separate thread in which transfer
+ communications and events are handled.
+ """
+ def __init__(self,
+ rpc_transfer_service,
+ default_response_timeout_s: float = 2.0):
+ """Initializes a Manager on top of a TransferService."""
+ self._service: Any = rpc_transfer_service
+ self._default_response_timeout_s = default_response_timeout_s
+
+ # Ongoing transfers in the service by ID.
+ self._read_transfers: _TransferDict = {}
+ self._write_transfers: _TransferDict = {}
+
+ # RPC streams for read and write transfers. These are shareable by
+ # multiple transfers of the same type.
+ self._read_stream: Optional[BidirectionalStream] = None
+ self._write_stream: Optional[BidirectionalStream] = None
+
+ self._loop = asyncio.new_event_loop()
+
+ # Queues are used for communication between the Manager context and the
+ # dedicated asyncio transfer thread.
+ self._new_transfer_queue: asyncio.Queue = asyncio.Queue()
+ self._read_chunk_queue: asyncio.Queue = asyncio.Queue()
+ self._write_chunk_queue: asyncio.Queue = asyncio.Queue()
+ self._quit_event = asyncio.Event()
+
+ self._thread = threading.Thread(target=self._start_event_loop_thread,
+ daemon=True)
+
+ self._thread.start()
+
+ def __del__(self):
+ # Notify the thread that the transfer manager is being destroyed and
+ # wait for it to exit.
+ if self._thread.is_alive():
+ self._loop.call_soon_threadsafe(self._quit_event.set)
+ self._thread.join()
+
+ def read(self, transfer_id: int) -> bytes:
+ """Receives ("downloads") data from the server.
+
+ Raises:
+ Error: the transfer failed to complete
+ """
+
+ if transfer_id in self._read_transfers:
+ raise ValueError(f'Read transfer {transfer_id} already exists')
+
+ transfer = _ReadTransfer(transfer_id, self._send_read_chunk,
+ self._end_read_transfer,
+ self._default_response_timeout_s)
+ self._start_read_transfer(transfer)
+
+ transfer.done.wait()
+
+ if not transfer.status.ok():
+ raise Error(transfer.id, transfer.status)
+
+ return transfer.data
+
+ def write(self, transfer_id: int, data: bytes) -> None:
+ """Transmits ("uploads") data to the server.
+
+ Raises:
+ Error: the transfer failed to complete
+ """
+
+ if transfer_id in self._write_transfers:
+ raise ValueError(f'Write transfer {transfer_id} already exists')
+
+ transfer = _WriteTransfer(transfer_id, data, self._send_write_chunk,
+ self._end_write_transfer,
+ self._default_response_timeout_s)
+ self._start_write_transfer(transfer)
+
+ transfer.done.wait()
+
+ if not transfer.status.ok():
+ raise Error(transfer.id, transfer.status)
+
+ def _send_read_chunk(self, chunk: Chunk) -> None:
+ assert self._read_stream is not None
+ self._read_stream.send(chunk)
+
+ def _send_write_chunk(self, chunk: Chunk) -> None:
+ assert self._write_stream is not None
+ self._write_stream.send(chunk)
+
+ def _start_event_loop_thread(self):
+ """Entry point for event loop thread that starts an asyncio context."""
+ asyncio.set_event_loop(self._loop)
+
+ # Recreate the async communication channels in the context of the
+ # running event loop.
+ self._new_transfer_queue = asyncio.Queue()
+ self._read_chunk_queue = asyncio.Queue()
+ self._write_chunk_queue = asyncio.Queue()
+ self._quit_event = asyncio.Event()
+
+ self._loop.create_task(self._transfer_event_loop())
+ self._loop.run_forever()
+
+ async def _transfer_event_loop(self):
+ """Main async event loop."""
+ exit_thread = self._loop.create_task(self._quit_event.wait())
+ new_transfer = self._loop.create_task(self._new_transfer_queue.get())
+ read_chunk = self._loop.create_task(self._read_chunk_queue.get())
+ write_chunk = self._loop.create_task(self._write_chunk_queue.get())
+
+ while not self._quit_event.is_set():
+ # Perform a select(2)-like wait for one of several events to occur.
+ done, _ = await asyncio.wait(
+ (exit_thread, new_transfer, read_chunk, write_chunk),
+ return_when=asyncio.FIRST_COMPLETED)
+
+ if exit_thread in done:
+ break
+
+ if new_transfer in done:
+ await new_transfer.result().begin()
+ new_transfer = self._loop.create_task(
+ self._new_transfer_queue.get())
+
+ if read_chunk in done:
+ self._loop.create_task(
+ self._handle_chunk(self._read_transfers,
+ self._end_read_transfer,
+ read_chunk.result()))
+ read_chunk = self._loop.create_task(
+ self._read_chunk_queue.get())
+
+ if write_chunk in done:
+ self._loop.create_task(
+ self._handle_chunk(self._write_transfers,
+ self._end_write_transfer,
+ write_chunk.result()))
+ write_chunk = self._loop.create_task(
+ self._write_chunk_queue.get())
+
+ self._loop.stop()
+
+ @staticmethod
+ async def _handle_chunk(
+ transfers: _TransferDict,
+ completion: Callable[[_Transfer, Status], None],
+ chunk: Chunk,
+ ) -> None:
+ """Processes an incoming chunk from a stream.
+
+ The chunk is dispatched to an active transfer based on its ID. If the
+ transfer indicates that it is complete, the provided completion callback
+ is invoked.
+ """
+
+ try:
+ transfer = transfers[chunk.transfer_id]
+ except KeyError:
+ _LOG.error(
+ 'TransferManager received chunk for unknown transfer %d',
+ chunk.transfer_id)
+ # TODO(frolv): What should be done here, if anything?
+ return
+
+ should_exit = False
+
+ # Status chunks are only used to terminate a transfer. They do not
+ # contain any data that requires processing.
+ if not chunk.HasField('status'):
+ should_exit = await transfer.handle_chunk(chunk)
+
+ if should_exit or chunk.HasField('status'):
+ completion(transfer, Status(chunk.status))
+
+ def _start_read_transfer(self, transfer: _Transfer) -> None:
+ """Begins a new read transfer, opening the stream if it isn't."""
+
+ self._read_transfers[transfer.id] = transfer
+
+ if not self._read_stream:
+ self._read_stream = self._service.Read(
+ lambda _, chunk: self._loop.call_soon_threadsafe(
+ self._read_chunk_queue.put_nowait, chunk))
+
+ _LOG.debug('Starting new read transfer %d', transfer.id)
+ self._loop.call_soon_threadsafe(self._new_transfer_queue.put_nowait,
+ transfer)
+
+ def _end_read_transfer(self, transfer: _Transfer, status: Status) -> None:
+ """Completes a read transfer."""
+ del self._read_transfers[transfer.id]
+
+ if not status.ok():
+ _LOG.error('Read transfer %d terminated with status %s',
+ transfer.id, status)
+ transfer.status = status
+
+ # If no more transfers are using the read stream, close it.
+ if not self._read_transfers and self._read_stream:
+ self._read_stream.cancel()
+ self._read_stream = None
+
+ transfer.done.set()
+
+ def _start_write_transfer(self, transfer: _Transfer) -> None:
+ """Begins a new write transfer, opening the stream if it isn't."""
+
+ self._write_transfers[transfer.id] = transfer
+
+ if not self._write_stream:
+ # TODO(frolv): Make the actual RPC call.
+ self._write_stream = self._service.Write(
+ lambda _, chunk: self._loop.call_soon_threadsafe(
+ self._write_chunk_queue.put_nowait, chunk))
+
+ _LOG.debug('Starting new write transfer %d', transfer.id)
+ self._loop.call_soon_threadsafe(self._new_transfer_queue.put_nowait,
+ transfer)
+
+ def _end_write_transfer(self, transfer: _Transfer, status: Status) -> None:
+ """Completes a write transfer."""
+ del self._write_transfers[transfer.id]
+
+ if not status.ok():
+ _LOG.error('Write transfer %d terminated with status %s',
+ transfer.id, status)
+ transfer.status = status
+
+ # If no more transfers are using the write stream, close it.
+ if not self._write_transfers and self._write_stream:
+ self._write_stream.cancel()
+ self._write_stream = None
+
+ transfer.done.set()
+
+
+class Error(Exception):
+ """Exception raised when a transfer fails.
+
+ Stores the ID of the failed transfer and the error that occurred.
+ """
+ def __init__(self, transfer_id: int, status: Status):
+ super().__init__(f'Transfer {transfer_id} failed with status {status}')
+ self.transfer_id = transfer_id
+ self.status = status
diff --git a/pw_transfer/py/transfer_test.py b/pw_transfer/py/transfer_test.py
new file mode 100644
index 0000000..ff519c2
--- /dev/null
+++ b/pw_transfer/py/transfer_test.py
@@ -0,0 +1,377 @@
+#!/usr/bin/env python3
+# Copyright 2021 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Tests for the transfer service client."""
+
+import enum
+import unittest
+from typing import Iterable, List
+
+from pw_status import Status
+from pw_rpc import callback_client, client, ids, packets
+from pw_rpc.internal import packet_pb2
+
+from pw_transfer import transfer_pb2
+from pw_transfer import transfer
+
+_TRANSFER_SERVICE_ID = ids.calculate('pw.transfer.Transfer')
+
+
+class _Method(enum.Enum):
+ READ = ids.calculate('Read')
+ WRITE = ids.calculate('Write')
+
+
+class TransferManagerTest(unittest.TestCase):
+ """Tests for the transfer manager."""
+ def setUp(self):
+ self._client = client.Client.from_modules(
+ callback_client.Impl(), [client.Channel(1, self._handle_request)],
+ (transfer_pb2, ))
+ self._service = self._client.channel(1).rpcs.pw.transfer.Transfer
+
+ self._sent_chunks: List[transfer_pb2.Chunk] = []
+ self._packets_to_send: List[List[bytes]] = []
+
+ def _enqueue_server_responses(
+ self, method: _Method,
+ responses: Iterable[Iterable[transfer_pb2.Chunk]]) -> None:
+ for group in responses:
+ serialized_group = []
+ for response in group:
+ serialized_group.append(
+ packet_pb2.RpcPacket(
+ type=packet_pb2.PacketType.SERVER_STREAM,
+ channel_id=1,
+ service_id=_TRANSFER_SERVICE_ID,
+ method_id=method.value,
+ status=Status.OK.value,
+ payload=response.SerializeToString()).
+ SerializeToString())
+ self._packets_to_send.append(serialized_group)
+
+ def _handle_request(self, data: bytes) -> None:
+ packet = packets.decode(data)
+ if packet.type is not packet_pb2.PacketType.CLIENT_STREAM:
+ return
+
+ chunk = transfer_pb2.Chunk()
+ chunk.MergeFromString(packet.payload)
+ self._sent_chunks.append(chunk)
+
+ if self._packets_to_send:
+ responses = self._packets_to_send.pop(0)
+ for response in responses:
+ self._client.process_packet(response)
+
+ def _received_data(self) -> bytearray:
+ data = bytearray()
+ for chunk in self._sent_chunks:
+ data.extend(chunk.data)
+ return data
+
+ def test_read_transfer_basic(self):
+ manager = transfer.Manager(self._service,
+ default_response_timeout_s=0.3)
+
+ self._enqueue_server_responses(
+ _Method.READ,
+ ((transfer_pb2.Chunk(
+ transfer_id=3, offset=0, data=b'abc', remaining_bytes=0), ), ),
+ )
+
+ data = manager.read(3)
+ self.assertEqual(data, b'abc')
+ self.assertEqual(len(self._sent_chunks), 2)
+ self.assertTrue(self._sent_chunks[-1].HasField('status'))
+ self.assertEqual(self._sent_chunks[-1].status, 0)
+
+ def test_read_transfer_multichunk(self):
+ manager = transfer.Manager(self._service,
+ default_response_timeout_s=0.3)
+
+ self._enqueue_server_responses(
+ _Method.READ,
+ ((
+ transfer_pb2.Chunk(
+ transfer_id=3, offset=0, data=b'abc', remaining_bytes=3),
+ transfer_pb2.Chunk(
+ transfer_id=3, offset=3, data=b'def', remaining_bytes=0),
+ ), ),
+ )
+
+ data = manager.read(3)
+ self.assertEqual(data, b'abcdef')
+ self.assertEqual(len(self._sent_chunks), 2)
+ self.assertTrue(self._sent_chunks[-1].HasField('status'))
+ self.assertEqual(self._sent_chunks[-1].status, 0)
+
+ def test_read_transfer_retry_bad_offset(self):
+ """Server responds with an unexpected offset in a read transfer."""
+ manager = transfer.Manager(self._service,
+ default_response_timeout_s=0.3)
+
+ self._enqueue_server_responses(
+ _Method.READ,
+ (
+ (
+ transfer_pb2.Chunk(transfer_id=3,
+ offset=0,
+ data=b'123',
+ remaining_bytes=6),
+
+ # Incorrect offset; expecting 3.
+ transfer_pb2.Chunk(transfer_id=3,
+ offset=1,
+ data=b'456',
+ remaining_bytes=3),
+ ),
+ (
+ transfer_pb2.Chunk(transfer_id=3,
+ offset=3,
+ data=b'456',
+ remaining_bytes=3),
+ transfer_pb2.Chunk(transfer_id=3,
+ offset=6,
+ data=b'789',
+ remaining_bytes=0),
+ ),
+ ))
+
+ data = manager.read(3)
+ self.assertEqual(data, b'123456789')
+
+ # Two transfer parameter requests should have been sent.
+ self.assertEqual(len(self._sent_chunks), 3)
+ self.assertTrue(self._sent_chunks[-1].HasField('status'))
+ self.assertEqual(self._sent_chunks[-1].status, 0)
+
+ def test_read_transfer_retry_timeout(self):
+ """Server doesn't respond to read transfer parameters."""
+ manager = transfer.Manager(self._service,
+ default_response_timeout_s=0.3)
+
+ self._enqueue_server_responses(
+ _Method.READ,
+ (
+ (), # Send nothing in response to the initial parameters.
+ (transfer_pb2.Chunk(
+ transfer_id=3, offset=0, data=b'xyz',
+ remaining_bytes=0), ),
+ ))
+
+ data = manager.read(3)
+ self.assertEqual(data, b'xyz')
+
+ # Two transfer parameter requests should have been sent.
+ self.assertEqual(len(self._sent_chunks), 3)
+ self.assertTrue(self._sent_chunks[-1].HasField('status'))
+ self.assertEqual(self._sent_chunks[-1].status, 0)
+
+ def test_read_transfer_timeout(self):
+ manager = transfer.Manager(self._service,
+ default_response_timeout_s=0.3)
+
+ with self.assertRaises(transfer.Error) as context:
+ manager.read(27)
+
+ exception = context.exception
+ self.assertEqual(exception.transfer_id, 27)
+ self.assertEqual(exception.status, Status.DEADLINE_EXCEEDED)
+
+ # The client should have sent four transfer parameters requests: one
+ # initial, and three retries.
+ self.assertEqual(len(self._sent_chunks), 4)
+
+ def test_read_transfer_error(self):
+ manager = transfer.Manager(self._service,
+ default_response_timeout_s=0.3)
+
+ self._enqueue_server_responses(
+ _Method.READ,
+ ((transfer_pb2.Chunk(transfer_id=31,
+ status=Status.NOT_FOUND.value), ), ),
+ )
+
+ with self.assertRaises(transfer.Error) as context:
+ manager.read(31)
+
+ exception = context.exception
+ self.assertEqual(exception.transfer_id, 31)
+ self.assertEqual(exception.status, Status.NOT_FOUND)
+
+ def test_write_transfer_basic(self):
+ manager = transfer.Manager(self._service,
+ default_response_timeout_s=0.3)
+
+ self._enqueue_server_responses(
+ _Method.WRITE,
+ (
+ (transfer_pb2.Chunk(transfer_id=4,
+ offset=0,
+ pending_bytes=32,
+ max_chunk_size_bytes=8), ),
+ (transfer_pb2.Chunk(transfer_id=4, status=Status.OK.value), ),
+ ),
+ )
+
+ manager.write(4, b'hello')
+ self.assertEqual(len(self._sent_chunks), 2)
+ self.assertEqual(self._received_data(), b'hello')
+
+ def test_write_transfer_max_chunk_size(self):
+ manager = transfer.Manager(self._service,
+ default_response_timeout_s=0.3)
+
+ self._enqueue_server_responses(
+ _Method.WRITE,
+ (
+ (transfer_pb2.Chunk(transfer_id=4,
+ offset=0,
+ pending_bytes=32,
+ max_chunk_size_bytes=8), ),
+ (),
+ (transfer_pb2.Chunk(transfer_id=4, status=Status.OK.value), ),
+ ),
+ )
+
+ manager.write(4, b'hello world')
+ self.assertEqual(len(self._sent_chunks), 3)
+ self.assertEqual(self._received_data(), b'hello world')
+ self.assertEqual(self._sent_chunks[1].data, b'hello wo')
+ self.assertEqual(self._sent_chunks[2].data, b'rld')
+
+ def test_write_transfer_multiple_parameters(self):
+ manager = transfer.Manager(self._service,
+ default_response_timeout_s=0.3)
+
+ self._enqueue_server_responses(
+ _Method.WRITE,
+ (
+ (transfer_pb2.Chunk(transfer_id=4,
+ offset=0,
+ pending_bytes=8,
+ max_chunk_size_bytes=8), ),
+ (transfer_pb2.Chunk(transfer_id=4,
+ offset=8,
+ pending_bytes=8,
+ max_chunk_size_bytes=8), ),
+ (transfer_pb2.Chunk(transfer_id=4, status=Status.OK.value), ),
+ ),
+ )
+
+ manager.write(4, b'data to write')
+ self.assertEqual(len(self._sent_chunks), 3)
+ self.assertEqual(self._received_data(), b'data to write')
+ self.assertEqual(self._sent_chunks[1].data, b'data to ')
+ self.assertEqual(self._sent_chunks[2].data, b'write')
+
+ def test_write_transfer_rewind(self):
+ """Write transfer in which the server re-requests an earlier offset."""
+ manager = transfer.Manager(self._service,
+ default_response_timeout_s=0.3)
+
+ self._enqueue_server_responses(
+ _Method.WRITE,
+ (
+ (transfer_pb2.Chunk(transfer_id=4,
+ offset=0,
+ pending_bytes=8,
+ max_chunk_size_bytes=8), ),
+ (transfer_pb2.Chunk(transfer_id=4,
+ offset=8,
+ pending_bytes=8,
+ max_chunk_size_bytes=8), ),
+ (
+ transfer_pb2.Chunk(
+ transfer_id=4,
+ offset=4, # rewind
+ pending_bytes=8,
+ max_chunk_size_bytes=8), ),
+ (
+ transfer_pb2.Chunk(
+ transfer_id=4,
+ offset=12,
+ pending_bytes=16, # update max size
+ max_chunk_size_bytes=16), ),
+ (transfer_pb2.Chunk(transfer_id=4, status=Status.OK.value), ),
+ ),
+ )
+
+ manager.write(4, b'pigweed data transfer')
+ self.assertEqual(len(self._sent_chunks), 5)
+ self.assertEqual(self._sent_chunks[1].data, b'pigweed ')
+ self.assertEqual(self._sent_chunks[2].data, b'data tra')
+ self.assertEqual(self._sent_chunks[3].data, b'eed data')
+ self.assertEqual(self._sent_chunks[4].data, b' transfer')
+
+ def test_write_transfer_bad_offset(self):
+ manager = transfer.Manager(self._service,
+ default_response_timeout_s=0.3)
+
+ self._enqueue_server_responses(
+ _Method.WRITE,
+ (
+ (transfer_pb2.Chunk(transfer_id=4,
+ offset=0,
+ pending_bytes=8,
+ max_chunk_size_bytes=8), ),
+ (
+ transfer_pb2.Chunk(
+ transfer_id=4,
+ offset=100, # larger offset than data
+ pending_bytes=8,
+ max_chunk_size_bytes=8), ),
+ (transfer_pb2.Chunk(transfer_id=4, status=Status.OK.value), ),
+ ),
+ )
+
+ with self.assertRaises(transfer.Error) as context:
+ manager.write(4, b'small data')
+
+ exception = context.exception
+ self.assertEqual(exception.transfer_id, 4)
+ self.assertEqual(exception.status, Status.OUT_OF_RANGE)
+
+ def test_write_transfer_error(self):
+ manager = transfer.Manager(self._service,
+ default_response_timeout_s=0.3)
+
+ self._enqueue_server_responses(
+ _Method.WRITE,
+ ((transfer_pb2.Chunk(transfer_id=21,
+ status=Status.UNAVAILABLE.value), ), ),
+ )
+
+ with self.assertRaises(transfer.Error) as context:
+ manager.write(21, b'no write')
+
+ exception = context.exception
+ self.assertEqual(exception.transfer_id, 21)
+ self.assertEqual(exception.status, Status.UNAVAILABLE)
+
+ def test_write_transfer_timeout(self):
+ manager = transfer.Manager(self._service,
+ default_response_timeout_s=0.3)
+
+ with self.assertRaises(transfer.Error) as context:
+ manager.write(22, b'no server response!')
+
+ exception = context.exception
+ self.assertEqual(exception.transfer_id, 22)
+ self.assertEqual(exception.status, Status.DEADLINE_EXCEEDED)
+
+
+if __name__ == '__main__':
+ unittest.main()