#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Test fixture for pw_transfer integration tests."""
import argparse
import asyncio
import contextlib
from dataclasses import dataclass
import logging
import pathlib
from pathlib import Path
import sys
import tempfile
from typing import BinaryIO, Iterable, List, NamedTuple, Optional
import unittest

from google.protobuf import text_format

from pigweed.pw_protobuf.pw_protobuf_protos import status_pb2
from pigweed.pw_transfer.integration_test import config_pb2
from rules_python.python.runfiles import runfiles
# Module-wide logger for the fixture. Records at DEBUG and above are written
# to stdout through the stream handler attached below.
_LOG = logging.getLogger('pw_transfer_intergration_test_proxy')
# Go through the public setter instead of assigning `.level` directly:
# setLevel() validates the level and invalidates the logger's cached
# isEnabledFor() results.
_LOG.setLevel(logging.DEBUG)
_LOG.addHandler(logging.StreamHandler(sys.stdout))
class LogMonitor:
    """Monitors lines read from the reader, and logs them.

    A background task reads the stream line by line, re-logs every line with
    a prefix, and pushes the raw bytes onto an internal queue. Callers either
    watch that queue for a specific line (`wait_for_line`) or discard its
    contents until the stream ends (`wait_for_eof`). End of stream is
    represented on the queue by an empty bytes object.
    """

    class Error(Exception):
        """Raised if wait_for_line reaches EOF before expected line."""

    def __init__(self, prefix: str, reader: asyncio.StreamReader):
        """Initializer.

        Schedules the background reading task with asyncio.create_task(), so
        it must be called while an event loop is running.

        Args:
          prefix: Prepended to read lines before they are logged.
          reader: StreamReader to read lines from.
        """
        self._prefix = prefix
        self._reader = reader
        # Queue of messages waiting to be monitored.
        self._queue: asyncio.Queue = asyncio.Queue()
        # Relog any messages read from the reader, and enqueue them for
        # monitoring.
        self._relog_and_enqueue_task = asyncio.create_task(
            self._relog_and_enqueue()
        )

    async def wait_for_line(self, msg: str):
        """Wait for a line containing msg to be read from the reader.

        Lines that do not contain msg are consumed and dropped.

        Args:
          msg: Substring to look for in each line.

        Raises:
          LogMonitor.Error: EOF was reached before a matching line was seen.
        """
        while True:
            line = await self._queue.get()
            if not line:
                # Put the EOF marker back before failing. It is enqueued only
                # once, so swallowing it here would leave any later
                # wait_for_line()/wait_for_eof() call blocked forever on an
                # empty queue.
                self._queue.put_nowait(line)
                raise LogMonitor.Error(
                    f"Reached EOF before getting line matching {msg}"
                )
            # Subprocess output is not guaranteed to be valid UTF-8; replace
            # undecodable bytes instead of raising.
            if msg in line.decode(errors="replace"):
                return

    async def wait_for_eof(self):
        """Wait for the reader to reach EOF, relogging any lines read."""
        # Drain the queue, since we're not monitoring it any more.
        drain_queue = asyncio.create_task(self._drain_queue())
        await asyncio.gather(drain_queue, self._relog_and_enqueue_task)

    async def _relog_and_enqueue(self):
        """Reads lines from the reader, logs them, and puts them in queue."""
        while True:
            line = await self._reader.readline()
            await self._queue.put(line)
            if not line:
                # EOF. Note, we still put the EOF in the queue, so that the
                # queue reader can process it appropriately.
                return
            # Decode leniently: a strict decode error here would kill this
            # task before EOF is enqueued and stall every queue consumer.
            text = line.decode(errors="replace").rstrip()
            _LOG.info(f"{self._prefix} {text}")

    async def _drain_queue(self):
        """Discards queued lines until the EOF marker is reached."""
        while True:
            line = await self._queue.get()
            if not line:
                # EOF.
                return
class MonitoredSubprocess:
    """A subprocess with monitored asynchronous communication.

    Instances are built with the async factory `create()`, which is what
    populates the process handle and the stdout/stderr monitors that every
    other method relies on.
    """

    @staticmethod
    async def create(cmd: List[str], prefix: str, stdinput: bytes):
        """Starts the subprocess and writes stdinput to stdin.

        This method returns once stdinput has been written to stdin. The
        MonitoredSubprocess continues to log the process's stderr and stdout
        (with the prefix) until it terminates.

        Args:
          cmd: Command line to execute.
          prefix: Prepended to process logs.
          stdinput: Written to stdin on process startup.

        Returns:
          The new MonitoredSubprocess.
        """
        self = MonitoredSubprocess()
        self._process = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        self._stderr_monitor = LogMonitor(
            f"{prefix} ERR:", self._process.stderr
        )
        self._stdout_monitor = LogMonitor(
            f"{prefix} OUT:", self._process.stdout
        )

        # Write the whole input, then close stdin so the child sees EOF.
        self._process.stdin.write(stdinput)
        await self._process.stdin.drain()
        self._process.stdin.close()
        await self._process.stdin.wait_closed()
        return self

    async def wait_for_line(self, stream: str, msg: str, timeout: float):
        """Wait for a line containing msg to be read on the stream.

        Args:
          stream: Which output to watch: "stdout" or "stderr".
          msg: Substring to look for in each line.
          timeout: Maximum time to wait, in seconds.

        Raises:
          ValueError: stream is neither "stdout" nor "stderr".
        """
        if stream == "stdout":
            monitor = self._stdout_monitor
        elif stream == "stderr":
            monitor = self._stderr_monitor
        else:
            # This must be an f-string, otherwise the message shows the
            # literal text "{stream}" instead of the offending value.
            raise ValueError(
                f"Stream must be 'stdout' or 'stderr', got {stream}"
            )

        await asyncio.wait_for(monitor.wait_for_line(msg), timeout)

    def returncode(self):
        """Returns the process's exit code, or None if it has not exited."""
        return self._process.returncode

    def terminate(self):
        """Terminate the process."""
        self._process.terminate()

    async def wait_for_termination(self, timeout: Optional[float]):
        """Wait for the process to terminate.

        Also waits for both output streams to reach EOF.

        Args:
          timeout: Maximum time to wait, in seconds. None waits indefinitely.
        """
        await asyncio.wait_for(
            asyncio.gather(
                self._process.wait(),
                self._stdout_monitor.wait_for_eof(),
                self._stderr_monitor.wait_for_eof(),
            ),
            timeout,
        )

    async def terminate_and_wait(self, timeout: float):
        """Terminate the process and wait for it to exit.

        Does nothing if the process has already exited.

        Args:
          timeout: Maximum time to wait for the exit, in seconds.
        """
        if self.returncode() is not None:
            # Process already terminated
            return
        self.terminate()
        await self.wait_for_termination(timeout)
class TransferConfig(NamedTuple):
    """Bundles the per-binary configurations used by one test run."""

    # Configuration handed to the transfer server binary.
    server: config_pb2.ServerConfig
    # Configuration handed to the transfer client binary.
    client: config_pb2.ClientConfig
    # Configuration handed to the proxy binary.
    proxy: config_pb2.ProxyConfig
class TransferIntegrationTestHarness:
    """A class to manage transfer integration tests.

    The harness knows where the client, server and proxy binaries live and
    which ports they use, and runs them together for one set of transfers
    via `perform_transfers()`.
    """

    # Prefix for log messages coming from the harness (as opposed to the server,
    # client, or proxy processes). Padded so that the length is the same as
    # "SERVER OUT:".
    # NOTE(review): this literal is 9 characters while "SERVER OUT:" is 11;
    # confirm whether padding spaces were lost.
    _PREFIX = "HARNESS: "

    @dataclass
    class Config:
        """Harness settings; a binary left as None uses the runfiles default."""

        server_port: int = 3300
        client_port: int = 3301
        java_client_binary: Optional[Path] = None
        cpp_client_binary: Optional[Path] = None
        python_client_binary: Optional[Path] = None
        proxy_binary: Optional[Path] = None
        server_binary: Optional[Path] = None

    class TransferExitCodes(NamedTuple):
        """Exit codes of the client and server processes."""

        client: int
        server: int

    def __init__(self, harness_config: Config) -> None:
        """Resolves binary locations and ports from harness_config.

        Args:
          harness_config: Ports to use plus optional binary path overrides.
        """
        # TODO(tpudlik): This is Bazel-only. Support gn, too.
        r = runfiles.Create()

        # Set defaults.
        self._JAVA_CLIENT_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/java_client"
        )
        self._CPP_CLIENT_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/cpp_client"
        )
        self._PYTHON_CLIENT_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/python_client"
        )
        self._PROXY_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/proxy"
        )
        self._SERVER_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/server"
        )

        # Server/client ports are non-optional, so use those.
        self._CLIENT_PORT = harness_config.client_port
        self._SERVER_PORT = harness_config.server_port

        # If the harness configuration specifies overrides, use those.
        if harness_config.java_client_binary is not None:
            self._JAVA_CLIENT_BINARY = harness_config.java_client_binary
        if harness_config.cpp_client_binary is not None:
            self._CPP_CLIENT_BINARY = harness_config.cpp_client_binary
        if harness_config.python_client_binary is not None:
            self._PYTHON_CLIENT_BINARY = harness_config.python_client_binary
        if harness_config.proxy_binary is not None:
            self._PROXY_BINARY = harness_config.proxy_binary
        if harness_config.server_binary is not None:
            self._SERVER_BINARY = harness_config.server_binary

        self._CLIENT_BINARY = {
            "cpp": self._CPP_CLIENT_BINARY,
            "java": self._JAVA_CLIENT_BINARY,
            "python": self._PYTHON_CLIENT_BINARY,
        }

        # Process handles. They exist from construction so that cleanup code
        # can always test them, even if a process was never started.
        self._client: Optional[MonitoredSubprocess] = None
        self._server: Optional[MonitoredSubprocess] = None
        self._proxy: Optional[MonitoredSubprocess] = None

    async def _start_client(
        self, client_type: str, config: config_pb2.ClientConfig
    ):
        """Starts the client of the given type, sending config on its stdin."""
        _LOG.info(f"{self._PREFIX} Starting client with config\n{config}")
        self._client = await MonitoredSubprocess.create(
            [self._CLIENT_BINARY[client_type], str(self._CLIENT_PORT)],
            "CLIENT",
            str(config).encode('ascii'),
        )

    async def _start_server(self, config: config_pb2.ServerConfig):
        """Starts the server, sending config on its stdin."""
        _LOG.info(f"{self._PREFIX} Starting server with config\n{config}")
        self._server = await MonitoredSubprocess.create(
            [self._SERVER_BINARY, str(self._SERVER_PORT)],
            "SERVER",
            str(config).encode('ascii'),
        )

    async def _start_proxy(self, config: config_pb2.ProxyConfig):
        """Starts the proxy, sending config on its stdin."""
        _LOG.info(f"{self._PREFIX} Starting proxy with config\n{config}")
        self._proxy = await MonitoredSubprocess.create(
            [
                self._PROXY_BINARY,
                "--server-port",
                str(self._SERVER_PORT),
                "--client-port",
                str(self._CLIENT_PORT),
            ],
            # Extra space in "PROXY " so that it lines up with "SERVER".
            "PROXY ",
            str(config).encode('ascii'),
        )

    async def perform_transfers(
        self,
        server_config: config_pb2.ServerConfig,
        client_type: str,
        client_config: config_pb2.ClientConfig,
        proxy_config: config_pb2.ProxyConfig,
    ) -> TransferExitCodes:
        """Runs the proxy, server and client for one set of transfers.

        The proxy is started first, then the server, then the client; each of
        the first two must log its readiness line before the next is started.

        Args:
          server_config: Server configuration.
          client_type: Either "cpp", "java", or "python".
          client_config: Client configuration.
          proxy_config: Proxy configuration.

        Returns:
          Exit code of the client and server as a tuple.
        """
        # Timeout for components (server, proxy) to come up or shut down after
        # write is finished or a signal is sent. Approximately arbitrary. Should
        # not be too long so that we catch bugs in the server that prevent it
        # from shutting down.
        TIMEOUT = 5  # seconds

        # Forget handles from any previous run on this harness, so the cleanup
        # below only ever touches processes started by this call.
        self._client = None
        self._server = None
        self._proxy = None

        try:
            await self._start_proxy(proxy_config)
            await self._proxy.wait_for_line(
                "stderr", "Listening for client connection", TIMEOUT
            )

            await self._start_server(server_config)
            await self._server.wait_for_line(
                "stderr", "Starting pw_rpc server on port", TIMEOUT
            )

            await self._start_client(client_type, client_config)
            # No timeout: the client will only exit once the transfer
            # completes, and this can take a long time for large payloads.
            await self._client.wait_for_termination(None)

            # Wait for the server to exit.
            await self._server.wait_for_termination(TIMEOUT)
        finally:
            # Each handle is None if startup failed before that process was
            # created. Checking explicitly keeps a startup error from being
            # replaced by an AttributeError raised during cleanup.

            # Stop the client, if still running. (Only expected if waiting
            # for it was interrupted.)
            if self._client is not None:
                await self._client.terminate_and_wait(TIMEOUT)

            # Stop the server, if still running. (Only expected if the
            # wait_for above timed out.)
            if self._server is not None:
                await self._server.terminate_and_wait(TIMEOUT)

            # Stop the proxy. Unlike the server, we expect it to still be
            # running at this stage.
            if self._proxy is not None:
                await self._proxy.terminate_and_wait(TIMEOUT)

        return self.TransferExitCodes(
            self._client.returncode(), self._server.returncode()
        )
class BasicTransfer(NamedTuple):
    """Describes one transfer within a multi-transfer test sequence."""

    # Resource ID the transfer operates on.
    id: int
    # Direction of the transfer (read from or write to the server).
    type: config_pb2.TransferAction.TransferType.ValueType
    # Payload that is expected to be transferred.
    data: bytes
class TransferIntegrationTest(unittest.TestCase):
    """A base class for transfer integration tests.

    This significantly reduces the boiler plate required for building
    integration test cases for pw_transfer. This class does not include any
    tests itself, but instead bundles together much of the boiler plate required
    for making an integration test for pw_transfer using this test fixture.
    """

    # Settings used to build the harness in setUpClass().
    HARNESS_CONFIG = TransferIntegrationTestHarness.Config()

    @classmethod
    def setUpClass(cls):
        """Builds the harness shared by all tests of the class."""
        cls.harness = TransferIntegrationTestHarness(cls.HARNESS_CONFIG)

    @staticmethod
    def default_server_config() -> config_pb2.ServerConfig:
        """Returns a new server config with default options."""
        return config_pb2.ServerConfig(
            chunk_size_bytes=216,
            pending_bytes=32 * 1024,
            chunk_timeout_seconds=5,
            transfer_service_retries=4,
            extend_window_divisor=32,
        )

    @staticmethod
    def default_client_config() -> config_pb2.ClientConfig:
        """Returns a new client config with default options."""
        return config_pb2.ClientConfig(
            max_retries=5,
            max_lifetime_retries=1500,
            initial_chunk_timeout_ms=4000,
            chunk_timeout_ms=4000,
        )

    @staticmethod
    def default_proxy_config() -> config_pb2.ProxyConfig:
        """Returns a new proxy config with default options.

        Both filter stacks consist of an HDLC packetizer followed by a data
        dropper with a fixed seed.
        """
        return text_format.Parse(
            """
                client_filter_stack: [
                    { hdlc_packetizer: {} },
                    { data_dropper: {rate: 0.01, seed: 1649963713563718435} }
                ]

                server_filter_stack: [
                    { hdlc_packetizer: {} },
                    { data_dropper: {rate: 0.01, seed: 1649963713563718436} }
            ]""",
            config_pb2.ProxyConfig(),
        )

    @staticmethod
    def default_config() -> TransferConfig:
        """Returns a new transfer config with default options."""
        return TransferConfig(
            TransferIntegrationTest.default_server_config(),
            TransferIntegrationTest.default_client_config(),
            TransferIntegrationTest.default_proxy_config(),
        )

    def _run_transfers(self, client_type: str, config: TransferConfig):
        """Runs the harness to completion and returns the exit codes."""
        return asyncio.run(
            self.harness.perform_transfers(
                config.server, client_type, config.client, config.proxy
            )
        )

    def do_single_write(
        self,
        client_type: str,
        config: TransferConfig,
        resource_id: int,
        data: bytes,
        protocol_version=config_pb2.TransferAction.ProtocolVersion.LATEST,
        permanent_resource_id=False,
        expected_status=status_pb2.StatusCode.OK,
    ) -> None:
        """Performs a single client-to-server write of the provided data.

        Modifies `config` in place: a destination is registered for the
        resource on the server and a transfer action is appended for the
        client.

        Args:
          client_type: Either "cpp", "java", or "python".
          config: Configuration to extend and run with.
          resource_id: ID of the resource to write.
          data: Payload to write.
          protocol_version: Protocol version to set on the transfer action.
          permanent_resource_id: If true, the output file is set as the
            resource's default destination path instead of being appended to
            its destination path list.
          expected_status: Status the transfer action is expected to finish
            with. The written data is only verified when this is OK.
        """
        with tempfile.NamedTemporaryFile() as f_payload, \
                tempfile.NamedTemporaryFile() as f_server_output:
            if permanent_resource_id:
                config.server.resources[
                    resource_id
                ].default_destination_path = f_server_output.name
            else:
                config.server.resources[resource_id].destination_paths.append(
                    f_server_output.name
                )
            config.client.transfer_actions.append(
                config_pb2.TransferAction(
                    resource_id=resource_id,
                    file_path=f_payload.name,
                    transfer_type=(
                        config_pb2.TransferAction.TransferType.WRITE_TO_SERVER
                    ),
                    protocol_version=protocol_version,
                    expected_status=int(expected_status),
                )
            )

            f_payload.write(data)
            f_payload.flush()  # Ensure contents are there to read!
            exit_codes = self._run_transfers(client_type, config)

            self.assertEqual(exit_codes.client, 0)
            self.assertEqual(exit_codes.server, 0)
            if expected_status == status_pb2.StatusCode.OK:
                self.assertEqual(f_server_output.read(), data)

    def do_single_read(
        self,
        client_type: str,
        config: TransferConfig,
        resource_id: int,
        data: bytes,
        protocol_version=config_pb2.TransferAction.ProtocolVersion.LATEST,
        permanent_resource_id=False,
        expected_status=status_pb2.StatusCode.OK,
    ) -> None:
        """Performs a single server-to-client read of the provided data.

        Modifies `config` in place: a source is registered for the resource on
        the server and a transfer action is appended for the client.

        Args:
          client_type: Either "cpp", "java", or "python".
          config: Configuration to extend and run with.
          resource_id: ID of the resource to read.
          data: Payload the server should serve.
          protocol_version: Protocol version to set on the transfer action.
          permanent_resource_id: If true, the payload file is set as the
            resource's default source path instead of being appended to its
            source path list.
          expected_status: Status the transfer action is expected to finish
            with. The read data is only verified when this is OK.
        """
        with tempfile.NamedTemporaryFile() as f_payload, \
                tempfile.NamedTemporaryFile() as f_client_output:
            if permanent_resource_id:
                config.server.resources[
                    resource_id
                ].default_source_path = f_payload.name
            else:
                config.server.resources[resource_id].source_paths.append(
                    f_payload.name
                )
            config.client.transfer_actions.append(
                config_pb2.TransferAction(
                    resource_id=resource_id,
                    file_path=f_client_output.name,
                    transfer_type=(
                        config_pb2.TransferAction.TransferType.READ_FROM_SERVER
                    ),
                    protocol_version=protocol_version,
                    expected_status=int(expected_status),
                )
            )

            f_payload.write(data)
            f_payload.flush()  # Ensure contents are there to read!
            exit_codes = self._run_transfers(client_type, config)

            self.assertEqual(exit_codes.client, 0)
            self.assertEqual(exit_codes.server, 0)
            if expected_status == status_pb2.StatusCode.OK:
                self.assertEqual(f_client_output.read(), data)

    def do_basic_transfer_sequence(
        self,
        client_type: str,
        config: TransferConfig,
        transfers: Iterable[BasicTransfer],
    ) -> None:
        """Performs multiple reads/writes in a single client/server session.

        Modifies `config` in place: for every transfer a server-side path is
        registered and a client transfer action is appended.

        Args:
          client_type: Either "cpp", "java", or "python".
          config: Configuration to extend and run with.
          transfers: Transfers to perform, in order.

        Raises:
          ValueError: A transfer has a type other than READ_FROM_SERVER or
            WRITE_TO_SERVER.
        """

        class ReadbackSet(NamedTuple):
            server_file: BinaryIO
            client_file: BinaryIO
            expected_data: bytes

        transfer_results: List[ReadbackSet] = []
        transfer_type = config_pb2.TransferAction.TransferType

        # Every temporary file is registered with the stack so that all of
        # them are closed (and thereby deleted) when the block exits, including
        # when setup raises part-way through.
        with contextlib.ExitStack() as open_files:
            for transfer in transfers:
                server_file = open_files.enter_context(
                    tempfile.NamedTemporaryFile()
                )
                client_file = open_files.enter_context(
                    tempfile.NamedTemporaryFile()
                )

                if transfer.type == transfer_type.READ_FROM_SERVER:
                    server_file.write(transfer.data)
                    server_file.flush()
                    config.server.resources[transfer.id].source_paths.append(
                        server_file.name
                    )
                elif transfer.type == transfer_type.WRITE_TO_SERVER:
                    client_file.write(transfer.data)
                    client_file.flush()
                    config.server.resources[
                        transfer.id
                    ].destination_paths.append(server_file.name)
                else:
                    raise ValueError('Unknown TransferType')

                config.client.transfer_actions.append(
                    config_pb2.TransferAction(
                        resource_id=transfer.id,
                        file_path=client_file.name,
                        transfer_type=transfer.type,
                    )
                )

                transfer_results.append(
                    ReadbackSet(server_file, client_file, transfer.data)
                )

            exit_codes = self._run_transfers(client_type, config)

            for i, result in enumerate(transfer_results):
                with self.subTest(i=i):
                    # Need to seek to the beginning of the file to read written
                    # data.
                    result.client_file.seek(0, 0)
                    result.server_file.seek(0, 0)
                    self.assertEqual(
                        result.client_file.read(), result.expected_data
                    )
                    self.assertEqual(
                        result.server_file.read(), result.expected_data
                    )

        # Check exit codes at the end as they provide less useful info.
        self.assertEqual(exit_codes.client, 0)
        self.assertEqual(exit_codes.server, 0)
def run_tests_for(test_class_name):
    """Applies command-line harness overrides, then runs unittest.main().

    Recognised options are parsed from sys.argv and written onto
    `test_class_name.HARNESS_CONFIG`; options that were not given leave the
    existing configuration value untouched. All unrecognised arguments are
    passed through to unittest.

    Args:
      test_class_name: The test class (not its name) whose HARNESS_CONFIG
        attribute receives the overrides.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--server-port',
        type=int,
        help=(
            'Port of the integration test server. The proxy will forward '
            'connections to this port'
        ),
    )
    parser.add_argument(
        '--client-port',
        type=int,
        help=(
            'Port on which to listen for connections from integration test '
            'client.'
        ),
    )
    parser.add_argument(
        '--java-client-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the Java transfer client to use in tests',
    )
    parser.add_argument(
        '--cpp-client-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the C++ transfer client to use in tests',
    )
    parser.add_argument(
        '--python-client-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the Python transfer client to use in tests',
    )
    parser.add_argument(
        '--server-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the transfer server to use in tests',
    )
    parser.add_argument(
        '--proxy-binary',
        type=pathlib.Path,
        default=None,
        help=(
            'Path to the proxy binary to use in tests to allow interception '
            'of client/server data'
        ),
    )

    (args, passthrough_args) = parser.parse_known_args()

    # Inherit the default configuration from the class being tested, and only
    # override provided arguments.
    for arg, val in vars(args).items():
        # Compare against None rather than testing truthiness, so that an
        # explicitly supplied falsy value (e.g. port 0) is not dropped.
        if val is not None:
            setattr(test_class_name.HARNESS_CONFIG, arg, val)

    unittest_args = [sys.argv[0]] + passthrough_args
    unittest.main(argv=unittest_args)