blob: 4386bce3e59afd5a5b64a7f12426203d1ac22ef8 [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests transfers between the Python client and C++ service."""
from pathlib import Path
import random
import tempfile
from typing import List, Tuple, Union
import unittest
from pw_hdlc import rpc
from pw_rpc import testing
from pw_status import Status
import pw_transfer
from pw_transfer import transfer_pb2
from pw_transfer_test import test_server_pb2
# How many times each basic read/write test repeats its transfer.
ITERATIONS = 5
# Default response timeout, in seconds, handed to the transfer Manager.
TIMEOUT_S = 0.05
# Exactly 4096 bytes: the 4-byte pattern b'SPAM' repeated 1024 times.
_DATA_4096B = b'SPAM' * 1024
class TransferServiceIntegrationTest(unittest.TestCase):
    """Tests transfers between the Python client and C++ service."""
    # Command that launches the C++ test server; assigned by _main() before
    # the tests run. setUp() appends the per-test temporary directory to it.
    test_server_command: Tuple[str, ...] = ()
    # Port passed to HdlcRpcLocalServerAndClient; assigned by _main().
    port: int

    def setUp(self) -> None:
        """Starts a test server backed by a fresh temporary directory.

        Each cleanup is registered with addCleanup() right after its resource
        is acquired, so nothing leaks if a later step of setUp() fails (in
        which case tearDown() would never run). Cleanups run in LIFO order,
        so the RPC context is closed before the directory it was given is
        deleted.
        """
        self._tempdir = tempfile.TemporaryDirectory(
            prefix=f'pw_transfer_{self.id().rsplit(".", 1)[-1]}_')
        self.addCleanup(self._tempdir.cleanup)
        self.directory = Path(self._tempdir.name)

        # The server receives the directory as its final argument.
        command = (*self.test_server_command, str(self.directory))
        self._context = rpc.HdlcRpcLocalServerAndClient(
            command,
            self.port, [transfer_pb2, test_server_pb2],
            for_testing=True)
        self.addCleanup(self._context.close)

        channel = self._context.client.channel(1)
        self.manager = pw_transfer.Manager(
            channel.rpcs.pw.transfer.Transfer,
            default_response_timeout_s=TIMEOUT_S)
        self._test_server = channel.rpcs.pw.transfer.TestServer

    def transfer_file_path(self, transfer_id: int) -> Path:
        """Returns the file in the server directory used for transfer_id."""
        return self.directory / str(transfer_id)

    def set_content(self, transfer_id: int, data: Union[bytes, str]) -> None:
        """Writes data to the transfer's file and asks the server to reload.

        str data is converted with str.encode() (UTF-8 by default).
        """
        self.transfer_file_path(transfer_id).write_bytes(
            data.encode() if isinstance(data, str) else data)
        self._test_server.ReloadTransferFiles()

    def get_content(self, transfer_id: int) -> bytes:
        """Returns the current bytes stored in the transfer's file."""
        return self.transfer_file_path(transfer_id).read_bytes()

    def _assert_reads(self, transfer_id: int,
                      data: Union[bytes, str]) -> None:
        """Serves data for transfer_id and reads it back ITERATIONS times."""
        expected = data.encode() if isinstance(data, str) else data
        for _ in range(ITERATIONS):
            self.set_content(transfer_id, data)
            self.assertEqual(self.manager.read(transfer_id), expected)

    def _assert_writes(self, transfer_id: int,
                       data: Union[bytes, str]) -> None:
        """Overwrites junk with data ITERATIONS times, checking the file."""
        expected = data.encode() if isinstance(data, str) else data
        for _ in range(ITERATIONS):
            self.set_content(transfer_id, 'junk')
            self.manager.write(transfer_id, data)
            self.assertEqual(self.get_content(transfer_id), expected)

    def test_read_unknown_id(self) -> None:
        with self.assertRaises(pw_transfer.Error) as ctx:
            self.manager.read(99)
        self.assertEqual(ctx.exception.status, Status.NOT_FOUND)

    def test_read_empty(self) -> None:
        self._assert_reads(24, '')

    def test_read_single_byte(self) -> None:
        self._assert_reads(25, '0')

    def test_read_small_amount_of_data(self) -> None:
        self._assert_reads(26, 'hunter2')

    def test_read_large_amount_of_data(self) -> None:
        self._assert_reads(27, '~' * 512)

    def test_write_unknown_id(self) -> None:
        with self.assertRaises(pw_transfer.Error) as ctx:
            self.manager.write(99, '')
        self.assertEqual(ctx.exception.status, Status.NOT_FOUND)

    def test_write_empty(self) -> None:
        self._assert_writes(28, b'')

    def test_write_single_byte(self) -> None:
        self._assert_writes(29, b'$')

    def test_write_small_amount_of_data(self) -> None:
        self._assert_writes(30, b'file transfer')

    def test_write_large_amount_of_data(self) -> None:
        self._assert_writes(31, b'*' * 512)

    def test_write_very_large_amount_of_data(self) -> None:
        # Larger than the transfer service's configured pending_bytes.
        self._assert_writes(32, _DATA_4096B)

    def test_write_string(self) -> None:
        # Write a string instead of bytes.
        self._assert_writes(33, 'hello world')

    def test_write_drop_data_chunks_and_transfer_parameters(self) -> None:
        self.set_content(34, 'junk')

        # Allow the initial packet and first chunk, then drop the second chunk.
        self._context.outgoing_packets.keep(2)
        self._context.outgoing_packets.drop(1)

        # Allow the initial transfer parameters update, then drop the next two.
        self._context.incoming_packets.keep(1)
        self._context.incoming_packets.drop(2)

        with self.assertLogs('pw_transfer', 'DEBUG') as logs:
            self.manager.write(34, _DATA_4096B)

        self.assertEqual(self.get_content(34), _DATA_4096B)

        # Verify that the client retried twice. The "/3" presumably matches
        # the Manager's default max_retries -- TODO confirm.
        messages = [r.getMessage() for r in logs.records]
        retry = f'Received no responses for {TIMEOUT_S:.3f}s; retrying {{}}/3'
        self.assertIn(retry.format(1), messages)
        self.assertIn(retry.format(2), messages)

    def test_write_regularly_drop_packets(self) -> None:
        self.set_content(35, 'junk')

        self._context.outgoing_packets.drop_every(5)  # drop one per window
        self._context.incoming_packets.drop_every(3)

        self.manager.write(35, _DATA_4096B)

        self.assertEqual(self.get_content(35), _DATA_4096B)

    def test_write_randomly_drop_packets(self) -> None:
        # Allow lots of retries since there are lots of drops.
        self.manager.max_retries = 9

        for seed in [1, 5678, 600613]:
            # subTest reports each seed separately. The finally clause resets
            # the packet filters, so a failure for one seed does not leave
            # its drop pattern active for the next seed.
            with self.subTest(seed=seed):
                try:
                    self.set_content(seed, 'junk')

                    rand = random.Random(seed)
                    self._context.incoming_packets.randomly_drop(3, rand)
                    self._context.outgoing_packets.randomly_drop(3, rand)

                    data = bytes(
                        rand.randrange(256)
                        for _ in range(rand.randrange(16384)))

                    self.manager.write(seed, data)
                    self.assertEqual(self.get_content(seed), data)
                finally:
                    self._context.incoming_packets.reset()
                    self._context.outgoing_packets.reset()
def _main(test_server_command: List[str], port: int,
          unittest_args: List[str]) -> None:
    """Configures the integration test case, then hands control to unittest.

    Args:
        test_server_command: Program and arguments that launch the C++ test
            server; each test appends its temporary directory to it.
        port: Port given to HdlcRpcLocalServerAndClient in each test.
        unittest_args: argv forwarded to unittest.main().
    """
    test_case = TransferServiceIntegrationTest
    test_case.test_server_command = tuple(test_server_command)
    test_case.port = port
    unittest.main(argv=unittest_args)
if __name__ == '__main__':
    # The parsed arguments' attribute names must match _main()'s keyword
    # parameters (test_server_command, port, unittest_args).
    _parsed_args = testing.parse_test_server_args()
    _main(**vars(_parsed_args))