#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for the transfer service client."""
import enum
import math
import unittest
from typing import Iterable, List
from pw_status import Status
from pw_rpc import callback_client, client, ids, packets
from pw_rpc.internal import packet_pb2
import pw_transfer
try:
from pw_transfer import transfer_pb2
except ImportError:
# For the bazel build, which puts generated protos in a different location.
from pigweed.pw_transfer import transfer_pb2 # type: ignore
# Service ID placed in the service_id field of the fake server packets built by
# the test helpers below.
_TRANSFER_SERVICE_ID = ids.calculate('pw.transfer.Transfer')
# Response timeout, in seconds, handed to pw_transfer.Manager by most tests.
# If the default timeout is too short, some tests become flaky on Windows.
DEFAULT_TIMEOUT_S = 0.3
class _Method(enum.Enum):
    """RPC methods of the transfer service used by these tests.

    Each member's value is the ID computed from the method name; the test
    helpers write it into the method_id field of the packets they build.
    """

    READ = ids.calculate('Read')
    WRITE = ids.calculate('Write')
class TransferManagerTest(unittest.TestCase):
    """Tests for the transfer manager.

    The RPC client's only channel writes its output to _handle_request, which
    records every chunk the client sends and replies with the next queued
    group of server packets. Tests script the server's side of a transfer up
    front with _enqueue_server_responses / _enqueue_server_error and then
    inspect self._sent_chunks afterwards.
    """

    def setUp(self) -> None:
        """Creates an RPC client on channel 1 backed by the fake server."""
        self._client = client.Client.from_modules(
            callback_client.Impl(), [client.Channel(1, self._handle_request)],
            (transfer_pb2, ))
        self._service = self._client.channel(1).rpcs.pw.transfer.Transfer

        # Chunks sent by the client, in the order they were sent.
        self._sent_chunks: List[transfer_pb2.Chunk] = []
        # Groups of serialized packets; one group is delivered per chunk sent.
        self._packets_to_send: List[List[bytes]] = []

    def _enqueue_server_responses(
            self, method: _Method,
            responses: Iterable[Iterable[transfer_pb2.Chunk]]) -> None:
        """Queues groups of chunks to send back as SERVER_STREAM packets.

        Args:
          method: The transfer RPC method the packets belong to.
          responses: Groups of chunks. Each group is delivered in full in
              response to a single chunk sent by the client; an empty group
              means that client chunk receives no reply.
        """
        for group in responses:
            self._packets_to_send.append([
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.SERVER_STREAM,
                    channel_id=1,
                    service_id=_TRANSFER_SERVICE_ID,
                    method_id=method.value,
                    status=Status.OK.value,
                    payload=response.SerializeToString()).SerializeToString()
                for response in group
            ])

    def _enqueue_server_error(self, method: _Method, error: Status) -> None:
        """Queues a single SERVER_ERROR packet carrying the given status.

        Args:
          method: The transfer RPC method the error belongs to.
          error: The status to report in the packet.
        """
        self._packets_to_send.append([
            packet_pb2.RpcPacket(type=packet_pb2.PacketType.SERVER_ERROR,
                                 channel_id=1,
                                 service_id=_TRANSFER_SERVICE_ID,
                                 method_id=method.value,
                                 status=error.value).SerializeToString()
        ])

    def _handle_request(self, data: bytes) -> None:
        """Channel output function: records a chunk and replies to it.

        Packets other than CLIENT_STREAM are ignored. For each client chunk,
        the next queued group of packets (if any) is fed back to the client.

        Args:
          data: The encoded RPC packet written by the client.
        """
        packet = packets.decode(data)
        # Protobuf enum values are plain ints in Python, so compare by value;
        # an identity check would depend on interpreter int caching.
        if packet.type != packet_pb2.PacketType.CLIENT_STREAM:
            return

        chunk = transfer_pb2.Chunk()
        chunk.MergeFromString(packet.payload)
        self._sent_chunks.append(chunk)

        if self._packets_to_send:
            responses = self._packets_to_send.pop(0)
            for response in responses:
                self._client.process_packet(response)

    def _received_data(self) -> bytearray:
        """Returns the data fields of all sent chunks, concatenated in order."""
        data = bytearray()
        for chunk in self._sent_chunks:
            data.extend(chunk.data)
        return data

    def test_read_transfer_basic(self) -> None:
        """Reads a transfer that completes in a single chunk."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            ((transfer_pb2.Chunk(
                transfer_id=3, offset=0, data=b'abc', remaining_bytes=0), ), ),
        )

        data = manager.read(3)
        self.assertEqual(data, b'abc')
        self.assertEqual(len(self._sent_chunks), 2)
        self.assertTrue(self._sent_chunks[-1].HasField('status'))
        self.assertEqual(self._sent_chunks[-1].status, 0)

    def test_read_transfer_multichunk(self) -> None:
        """Reads a transfer whose data arrives in two chunks."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            ((
                transfer_pb2.Chunk(
                    transfer_id=3, offset=0, data=b'abc', remaining_bytes=3),
                transfer_pb2.Chunk(
                    transfer_id=3, offset=3, data=b'def', remaining_bytes=0),
            ), ),
        )

        data = manager.read(3)
        self.assertEqual(data, b'abcdef')
        self.assertEqual(len(self._sent_chunks), 2)
        self.assertTrue(self._sent_chunks[-1].HasField('status'))
        self.assertEqual(self._sent_chunks[-1].status, 0)

    def test_read_transfer_progress_callback(self) -> None:
        """Read transfer reports progress once per received chunk."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            ((
                transfer_pb2.Chunk(
                    transfer_id=3, offset=0, data=b'abc', remaining_bytes=3),
                transfer_pb2.Chunk(
                    transfer_id=3, offset=3, data=b'def', remaining_bytes=0),
            ), ),
        )

        progress: List[pw_transfer.ProgressStats] = []

        data = manager.read(3, progress.append)
        self.assertEqual(data, b'abcdef')
        self.assertEqual(len(self._sent_chunks), 2)
        self.assertTrue(self._sent_chunks[-1].HasField('status'))
        self.assertEqual(self._sent_chunks[-1].status, 0)
        self.assertEqual(progress, [
            pw_transfer.ProgressStats(3, 3, 6),
            pw_transfer.ProgressStats(6, 6, 6),
        ])

    def test_read_transfer_retry_bad_offset(self) -> None:
        """Server responds with an unexpected offset in a read transfer."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            (
                (
                    transfer_pb2.Chunk(transfer_id=3,
                                       offset=0,
                                       data=b'123',
                                       remaining_bytes=6),
                    # Incorrect offset; expecting 3.
                    transfer_pb2.Chunk(transfer_id=3,
                                       offset=1,
                                       data=b'456',
                                       remaining_bytes=3),
                ),
                (
                    transfer_pb2.Chunk(transfer_id=3,
                                       offset=3,
                                       data=b'456',
                                       remaining_bytes=3),
                    transfer_pb2.Chunk(transfer_id=3,
                                       offset=6,
                                       data=b'789',
                                       remaining_bytes=0),
                ),
            ))

        data = manager.read(3)
        self.assertEqual(data, b'123456789')

        # Two transfer parameter requests should have been sent, followed by
        # the final status chunk.
        self.assertEqual(len(self._sent_chunks), 3)
        self.assertTrue(self._sent_chunks[-1].HasField('status'))
        self.assertEqual(self._sent_chunks[-1].status, 0)

    def test_read_transfer_retry_timeout(self) -> None:
        """Server doesn't respond to read transfer parameters."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            (
                (),  # Send nothing in response to the initial parameters.
                (transfer_pb2.Chunk(
                    transfer_id=3, offset=0, data=b'xyz',
                    remaining_bytes=0), ),
            ))

        data = manager.read(3)
        self.assertEqual(data, b'xyz')

        # Two transfer parameter requests should have been sent, followed by
        # the final status chunk.
        self.assertEqual(len(self._sent_chunks), 3)
        self.assertTrue(self._sent_chunks[-1].HasField('status'))
        self.assertEqual(self._sent_chunks[-1].status, 0)

    def test_read_transfer_timeout(self) -> None:
        """Read with no server responses fails with DEADLINE_EXCEEDED."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        with self.assertRaises(pw_transfer.Error) as context:
            manager.read(27)

        exception = context.exception
        self.assertEqual(exception.session_id, 27)
        self.assertEqual(exception.status, Status.DEADLINE_EXCEEDED)

        # The client should have sent four transfer parameters requests: one
        # initial, and three retries.
        self.assertEqual(len(self._sent_chunks), 4)

    def test_read_transfer_error(self) -> None:
        """A status chunk from the server is raised with that status."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            ((transfer_pb2.Chunk(transfer_id=31,
                                 status=Status.NOT_FOUND.value), ), ),
        )

        with self.assertRaises(pw_transfer.Error) as context:
            manager.read(31)

        exception = context.exception
        self.assertEqual(exception.session_id, 31)
        self.assertEqual(exception.status, Status.NOT_FOUND)

    def test_read_transfer_server_error(self) -> None:
        """A SERVER_ERROR packet during a read is raised as INTERNAL."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_error(_Method.READ, Status.NOT_FOUND)

        with self.assertRaises(pw_transfer.Error) as context:
            manager.read(31)

        exception = context.exception
        self.assertEqual(exception.session_id, 31)
        self.assertEqual(exception.status, Status.INTERNAL)

    def test_write_transfer_basic(self) -> None:
        """Writes data that fits in a single chunk."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (transfer_pb2.Chunk(transfer_id=4,
                                    offset=0,
                                    pending_bytes=32,
                                    max_chunk_size_bytes=8), ),
                (transfer_pb2.Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        manager.write(4, b'hello')
        self.assertEqual(len(self._sent_chunks), 2)
        self.assertEqual(self._received_data(), b'hello')

    def test_write_transfer_max_chunk_size(self) -> None:
        """Data longer than max_chunk_size_bytes is split across chunks."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (transfer_pb2.Chunk(transfer_id=4,
                                    offset=0,
                                    pending_bytes=32,
                                    max_chunk_size_bytes=8), ),
                (),  # No reply to the first data chunk.
                (transfer_pb2.Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        manager.write(4, b'hello world')
        self.assertEqual(len(self._sent_chunks), 3)
        self.assertEqual(self._received_data(), b'hello world')
        self.assertEqual(self._sent_chunks[1].data, b'hello wo')
        self.assertEqual(self._sent_chunks[2].data, b'rld')

    def test_write_transfer_multiple_parameters(self) -> None:
        """Server sends a second set of parameters partway through a write."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (transfer_pb2.Chunk(transfer_id=4,
                                    offset=0,
                                    pending_bytes=8,
                                    max_chunk_size_bytes=8), ),
                (transfer_pb2.Chunk(transfer_id=4,
                                    offset=8,
                                    pending_bytes=8,
                                    max_chunk_size_bytes=8), ),
                (transfer_pb2.Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        manager.write(4, b'data to write')
        self.assertEqual(len(self._sent_chunks), 3)
        self.assertEqual(self._received_data(), b'data to write')
        self.assertEqual(self._sent_chunks[1].data, b'data to ')
        self.assertEqual(self._sent_chunks[2].data, b'write')

    def test_write_transfer_progress_callback(self) -> None:
        """Write transfer reports progress through the supplied callback."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (transfer_pb2.Chunk(transfer_id=4,
                                    offset=0,
                                    pending_bytes=8,
                                    max_chunk_size_bytes=8), ),
                (transfer_pb2.Chunk(transfer_id=4,
                                    offset=8,
                                    pending_bytes=8,
                                    max_chunk_size_bytes=8), ),
                (transfer_pb2.Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        progress: List[pw_transfer.ProgressStats] = []

        manager.write(4, b'data to write', progress.append)
        self.assertEqual(len(self._sent_chunks), 3)
        self.assertEqual(self._received_data(), b'data to write')
        self.assertEqual(self._sent_chunks[1].data, b'data to ')
        self.assertEqual(self._sent_chunks[2].data, b'write')
        self.assertEqual(progress, [
            pw_transfer.ProgressStats(8, 0, 13),
            pw_transfer.ProgressStats(13, 8, 13),
            pw_transfer.ProgressStats(13, 13, 13)
        ])

    def test_write_transfer_rewind(self) -> None:
        """Write transfer in which the server re-requests an earlier offset."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (transfer_pb2.Chunk(transfer_id=4,
                                    offset=0,
                                    pending_bytes=8,
                                    max_chunk_size_bytes=8), ),
                (transfer_pb2.Chunk(transfer_id=4,
                                    offset=8,
                                    pending_bytes=8,
                                    max_chunk_size_bytes=8), ),
                (
                    transfer_pb2.Chunk(
                        transfer_id=4,
                        offset=4,  # rewind
                        pending_bytes=8,
                        max_chunk_size_bytes=8), ),
                (
                    transfer_pb2.Chunk(
                        transfer_id=4,
                        offset=12,
                        pending_bytes=16,  # update max size
                        max_chunk_size_bytes=16), ),
                (transfer_pb2.Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        manager.write(4, b'pigweed data transfer')
        self.assertEqual(len(self._sent_chunks), 5)
        self.assertEqual(self._sent_chunks[1].data, b'pigweed ')
        self.assertEqual(self._sent_chunks[2].data, b'data tra')
        self.assertEqual(self._sent_chunks[3].data, b'eed data')
        self.assertEqual(self._sent_chunks[4].data, b' transfer')

    def test_write_transfer_bad_offset(self) -> None:
        """Server requests an offset past the end of the data."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (transfer_pb2.Chunk(transfer_id=4,
                                    offset=0,
                                    pending_bytes=8,
                                    max_chunk_size_bytes=8), ),
                (
                    transfer_pb2.Chunk(
                        transfer_id=4,
                        offset=100,  # larger offset than data
                        pending_bytes=8,
                        max_chunk_size_bytes=8), ),
                (transfer_pb2.Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(4, b'small data')

        exception = context.exception
        self.assertEqual(exception.session_id, 4)
        self.assertEqual(exception.status, Status.OUT_OF_RANGE)

    def test_write_transfer_error(self) -> None:
        """A status chunk from the server is raised with that status."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            ((transfer_pb2.Chunk(transfer_id=21,
                                 status=Status.UNAVAILABLE.value), ), ),
        )

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(21, b'no write')

        exception = context.exception
        self.assertEqual(exception.session_id, 21)
        self.assertEqual(exception.status, Status.UNAVAILABLE)

    def test_write_transfer_server_error(self) -> None:
        """A SERVER_ERROR packet during a write is raised as INTERNAL."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_error(_Method.WRITE, Status.NOT_FOUND)

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(21, b'server error')

        exception = context.exception
        self.assertEqual(exception.session_id, 21)
        self.assertEqual(exception.status, Status.INTERNAL)

    def test_write_transfer_timeout_after_initial_chunk(self) -> None:
        """Write whose initial chunk is never answered by the server."""
        manager = pw_transfer.Manager(self._service,
                                      default_response_timeout_s=0.001,
                                      max_retries=2)

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(22, b'no server response!')

        self.assertEqual(
            self._sent_chunks,
            [
                transfer_pb2.Chunk(
                    transfer_id=22,
                    resource_id=22,
                    type=transfer_pb2.Chunk.Type.START),  # initial chunk
                transfer_pb2.Chunk(
                    transfer_id=22,
                    resource_id=22,
                    type=transfer_pb2.Chunk.Type.START),  # retry 1
                transfer_pb2.Chunk(
                    transfer_id=22,
                    resource_id=22,
                    type=transfer_pb2.Chunk.Type.START),  # retry 2
            ])

        exception = context.exception
        self.assertEqual(exception.session_id, 22)
        self.assertEqual(exception.status, Status.DEADLINE_EXCEEDED)

    def test_write_transfer_timeout_after_intermediate_chunk(self) -> None:
        """Write that times out after its data chunks have been sent.

        The server answers the initial chunk with transfer parameters and then
        goes silent, so the client is expected to resend its last data chunk
        for each retry before giving up.
        """
        manager = pw_transfer.Manager(
            self._service,
            default_response_timeout_s=DEFAULT_TIMEOUT_S,
            max_retries=2)

        self._enqueue_server_responses(_Method.WRITE, [[
            transfer_pb2.Chunk(
                transfer_id=22, pending_bytes=10, max_chunk_size_bytes=5)
        ]])

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(22, b'0123456789')

        last_data_chunk = transfer_pb2.Chunk(transfer_id=22,
                                             data=b'56789',
                                             offset=5,
                                             remaining_bytes=0,
                                             type=transfer_pb2.Chunk.Type.DATA)

        self.assertEqual(
            self._sent_chunks,
            [
                transfer_pb2.Chunk(transfer_id=22,
                                   resource_id=22,
                                   type=transfer_pb2.Chunk.Type.START),
                transfer_pb2.Chunk(transfer_id=22,
                                   data=b'01234',
                                   type=transfer_pb2.Chunk.Type.DATA),
                last_data_chunk,  # last chunk
                last_data_chunk,  # retry 1
                last_data_chunk,  # retry 2
            ])

        exception = context.exception
        self.assertEqual(exception.session_id, 22)
        self.assertEqual(exception.status, Status.DEADLINE_EXCEEDED)

    def test_write_zero_pending_bytes_is_internal_error(self) -> None:
        """Transfer parameters with pending_bytes=0 are raised as INTERNAL."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            ((transfer_pb2.Chunk(transfer_id=23, pending_bytes=0), ), ),
        )

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(23, b'no write')

        exception = context.exception
        self.assertEqual(exception.session_id, 23)
        self.assertEqual(exception.status, Status.INTERNAL)
class ProgressStatsTest(unittest.TestCase):
    """Tests for pw_transfer.ProgressStats percentages and string output."""

    def test_received_percent_known_total(self) -> None:
        """percent_received() gives the expected value for a known total."""
        cases = (
            ((75, 0, 100), 0.0),
            ((75, 50, 100), 50.0),
            ((100, 100, 100), 100.0),
        )
        for args, expected in cases:
            self.assertEqual(
                pw_transfer.ProgressStats(*args).percent_received(), expected)

    def test_received_percent_unknown_total(self) -> None:
        """percent_received() is NaN when the total is None."""
        for args in ((75, 50, None), (100, 100, None)):
            percent = pw_transfer.ProgressStats(*args).percent_received()
            self.assertTrue(math.isnan(percent))

    def test_str_known_total(self) -> None:
        """str() mentions all three numbers when the total is known."""
        text = str(pw_transfer.ProgressStats(75, 50, 100))
        for fragment in ('75', '50', '100'):
            self.assertIn(fragment, text)

    def test_str_unknown_total(self) -> None:
        """str() shows 'unknown' in place of a total of None."""
        text = str(pw_transfer.ProgressStats(75, 50, None))
        for fragment in ('75', '50', 'unknown'):
            self.assertIn(fragment, text)
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()