pw_build: Python code for generating unit tests
Move code for generating Python and C++ unit tests from pw_hdlc_lite to
pw_build so it can be reused.
Change-Id: Ie2573811a13bded511d1f195928c09820f9a3cdf
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/20600
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
diff --git a/pw_build/py/BUILD.gn b/pw_build/py/BUILD.gn
new file mode 100644
index 0000000..bf28ba3
--- /dev/null
+++ b/pw_build/py/BUILD.gn
@@ -0,0 +1,23 @@
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+
+# gn-format disable
+import("//build_overrides/pigweed.gni")
+
+import("$dir_pw_build/input_group.gni")
+
+# TODO(hepler): Replace this with a pw_python_package.
+pw_input_group("py") {
+ inputs = [ "pw_build/generated_tests.py" ]
+}
diff --git a/pw_build/py/pw_build/generated_tests.py b/pw_build/py/pw_build/generated_tests.py
new file mode 100644
index 0000000..b3c280a
--- /dev/null
+++ b/pw_build/py/pw_build/generated_tests.py
@@ -0,0 +1,187 @@
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Tools for generating Pigweed tests that execute in C++ and Python."""
+
+import argparse
+from dataclasses import dataclass
+from datetime import datetime
+from collections import defaultdict
+import unittest
+
+from typing import (Any, Callable, Dict, Generic, Iterable, Iterator, List,
+ Sequence, TextIO, TypeVar, Union)
+
+_CPP_HEADER = f"""\
+// Copyright {datetime.now().year} The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+
+// AUTOGENERATED - DO NOT EDIT
+//
+// Generated at {datetime.now().isoformat()}
+
+// clang-format off
+"""
+
+
+class Error(Exception):
+ """Something went wrong when generating tests."""
+
+
+T = TypeVar('T')
+
+
+@dataclass
+class Context(Generic[T]):
+ """Info passed into test generator functions for each test case."""
+ group: str
+ count: int
+ total: int
+ test_case: T
+
+ def cc_name(self) -> str:
+ name = ''.join(w.capitalize()
+ for w in self.group.replace('-', ' ').split(' '))
+ name = ''.join(c if c.isalnum() else '_' for c in name)
+ return f'{name}_{self.count}' if self.total > 1 else name
+
+ def py_name(self) -> str:
+ name = 'test_' + ''.join(c if c.isalnum() else '_'
+ for c in self.group.lower())
+ return f'{name}_{self.count}' if self.total > 1 else name
+
+
+# Test cases are specified as a sequence of strings or test case instances. The
+# strings are used to separate the tests into named groups. For example:
+#
+# STR_SPLIT_TEST_CASES = (
+# 'Empty input',
+# MyTestCase('', '', []),
+# MyTestCase('', 'foo', []),
+# 'Split on single character',
+# MyTestCase('abcde', 'c', ['ab', 'de']),
+# ...
+# )
+#
+GroupOrTest = Union[str, T]
+
+# Python tests are generated by a function that returns a function usable as a
+# unittest.TestCase method.
+PyTest = Callable[[unittest.TestCase], None]
+PyTestGenerator = Callable[[Context[T]], PyTest]
+
+# C++ tests are generated with a function that returns or yields lines of C++
+# code for the given test case.
+CcTestGenerator = Callable[[Context[T]], Iterable[str]]
+
+
+class TestGenerator(Generic[T]):
+ """Generates tests for multiple languages from a series of test cases."""
+ def __init__(self, test_cases: Sequence[GroupOrTest[T]]):
+ self._cases: Dict[str, List[T]] = defaultdict(list)
+ message = ''
+
+ if len(test_cases) < 2:
+ raise Error('At least one test case must be provided')
+
+ if not isinstance(test_cases[0], str):
+ raise Error(
+ 'The first item in the test cases must be a group name string')
+
+ for case in test_cases:
+ if isinstance(case, str):
+ message = case
+ else:
+ self._cases[message].append(case)
+
+ if '' in self._cases:
+ raise Error('Empty test group names are not permitted')
+
+ def _test_contexts(self) -> Iterator[Context[T]]:
+ for group, test_list in self._cases.items():
+ for i, test_case in enumerate(test_list, 1):
+ yield Context(group, i, len(test_list), test_case)
+
+ def _generate_python_tests(self, define_py_test: PyTestGenerator):
+ tests: Dict[str, Callable[[Any], None]] = {}
+
+ for ctx in self._test_contexts():
+ test = define_py_test(ctx)
+ test.__name__ = ctx.py_name()
+
+ if test.__name__ in tests:
+ raise Error(
+ f'Multiple Python tests are named {test.__name__}!')
+
+ tests[test.__name__] = test
+
+ return tests
+
+ def python_tests(self, name: str, define_py_test: PyTestGenerator) -> type:
+ """Returns a Python unittest.TestCase class with tests for each case."""
+ return type(name, (unittest.TestCase, ),
+ self._generate_python_tests(define_py_test))
+
+ def _generate_cc_tests(self, define_cpp_test: CcTestGenerator, header: str,
+ footer: str) -> Iterator[str]:
+ yield _CPP_HEADER
+ yield header
+
+ for ctx in self._test_contexts():
+ yield from define_cpp_test(ctx)
+ yield ''
+
+ yield footer
+
+ def cc_tests(self, output: TextIO, define_cpp_test: CcTestGenerator,
+ header: str, footer: str):
+ """Writes C++ unit tests for each test case to the given file."""
+ for line in self._generate_cc_tests(define_cpp_test, header, footer):
+ output.write(line)
+ output.write('\n')
+
+
+def _to_chars(data: bytes) -> Iterator[str]:
+ for i, byte in enumerate(data):
+ try:
+ char = data[i:i + 1].decode()
+ yield char if char.isprintable() else fr'\x{byte:02x}'
+ except UnicodeDecodeError:
+ yield fr'\x{byte:02x}'
+
+
+def cc_string(data: Union[str, bytes]) -> str:
+ """Returns a C++ string literal version of a byte string or UTF-8 string."""
+ if isinstance(data, str):
+ data = data.encode()
+
+ return '"' + ''.join(_to_chars(data)) + '"'
+
+
+def parse_test_generation_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description='Generate unit test files')
+ parser.add_argument('--generate-cc-test',
+ type=argparse.FileType('w'),
+ help='Generate the C++ test file')
+ return parser.parse_known_args()[0]
diff --git a/pw_hdlc_lite/BUILD.gn b/pw_hdlc_lite/BUILD.gn
index 5bcf447..75eb537 100644
--- a/pw_hdlc_lite/BUILD.gn
+++ b/pw_hdlc_lite/BUILD.gn
@@ -103,7 +103,8 @@
action("generate_decoder_test") {
outputs = [ "$target_gen_dir/generated_decoder_test.cc" ]
script = "py/decode_test.py"
- args = [ "--generate-cc-decode-test" ] + rebase_path(outputs)
+ args = [ "--generate-cc-test" ] + rebase_path(outputs)
+ deps = [ "$dir_pw_build/py" ]
}
pw_test("decoder_test") {
diff --git a/pw_hdlc_lite/py/decode_test.py b/pw_hdlc_lite/py/decode_test.py
index fbd1279..55c3f35 100755
--- a/pw_hdlc_lite/py/decode_test.py
+++ b/pw_hdlc_lite/py/decode_test.py
@@ -14,13 +14,11 @@
# the License.
"""Contains the Python decoder tests and generates C++ decoder tests."""
-from collections import defaultdict
-from pathlib import Path
-from typing import Any, Callable, Dict, Iterator, List, NamedTuple, Union
-from typing import Optional, TextIO, Tuple
+from typing import Iterator, List, NamedTuple, Tuple
import unittest
-import sys
+from pw_build.generated_tests import Context, PyTest, TestGenerator, GroupOrTest
+from pw_build.generated_tests import parse_test_generation_args
from pw_hdlc_lite.decode import Frame, FrameDecoder, FrameStatus, NO_ADDRESS
from pw_hdlc_lite.protocol import frame_check_sequence as fcs
@@ -54,10 +52,7 @@
],
)
-TestCase = Tuple[bytes, List[Expected]]
-TestCases = Tuple[Union[str, TestCase], ...]
-
-TEST_CASES: TestCases = (
+TEST_CASES: Tuple[GroupOrTest[Tuple[bytes, List[Expected]]], ...] = (
'Empty payload',
(_encode(0, 0, b''), [Expected(0, b'\0', b'')]),
(_encode(55, 0x99, b''), [Expected(55, b'\x99', b'')]),
@@ -158,18 +153,7 @@
) # yapf: disable
# Formatting for the above tuple is very slow, so disable yapf.
-
-def _sort_test_cases(test_cases: TestCases) -> Dict[str, List[TestCase]]:
- cases: Dict[str, List[TestCase]] = defaultdict(list)
- message = ''
-
- for case in test_cases:
- if isinstance(case, str):
- message = case
- else:
- cases[message].append(case)
-
- return cases
+_TESTS = TestGenerator(TEST_CASES)
def _expected(frames: List[Frame]) -> Iterator[str]:
@@ -180,23 +164,7 @@
yield f' Status::DATA_LOSS, // Frame {i}'
-_CPP_HEADER = f"""\
-// Copyright 2020 The Pigweed Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not
-// use this file except in compliance with the License. You may obtain a copy of
-// the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-// License for the specific language governing permissions and limitations under
-// the License.
-
-// AUTOGENERATED by {Path(__file__).name} - DO NOT EDIT
-
+_CPP_HEADER = """\
#include "pw_hdlc_lite/decoder.h"
#include <array>
@@ -206,30 +174,22 @@
#include "gtest/gtest.h"
#include "pw_bytes/array.h"
-namespace pw::hdlc_lite {{
-namespace {{
-
-// clang-format off
+namespace pw::hdlc_lite {
+namespace {
"""
_CPP_FOOTER = """\
-// clang-format on
-
} // namespace
-} // namespace pw::hdlc_lite
-"""
+} // namespace pw::hdlc_lite"""
-def _cpp_test(group: str, count: Optional[int], data: bytes) -> Iterator[str]:
+def _cpp_test(ctx: Context) -> Iterator[str]:
"""Generates a C++ test for the provided test data."""
+ data, _ = ctx.test_case
frames = list(FrameDecoder().process(data))
data_bytes = ''.join(rf'\x{byte:02x}' for byte in data)
- name = ''.join(w.capitalize() for w in group.replace('-', ' ').split(' '))
- name = ''.join(c if c.isalnum() else '_' for c in name)
- name = name if count is None else name + f'_{count}'
-
- yield f'TEST(Decoder, {name}) {{'
+ yield f'TEST(Decoder, {ctx.cc_name()}) {{'
yield f' static constexpr auto kData = bytes::String("{data_bytes}");\n'
for i, frame in enumerate(frames, 1):
@@ -279,31 +239,14 @@
}}"""
-def _define_cc_tests(test_cases: TestCases, output: TextIO) -> None:
- """Writes C++ tests for all test cases."""
- output.write(_CPP_HEADER)
+def _define_py_test(ctx: Context) -> PyTest:
+ data, expected_frames = ctx.test_case
- for group, test_list in _sort_test_cases(test_cases).items():
- for i, (data, _) in enumerate(test_list, 1):
- count = i if len(test_list) > 1 else None
- for line in _cpp_test(group, count, data):
- output.write(line)
- output.write('\n')
-
- output.write('\n')
-
- output.write(_CPP_FOOTER)
-
-
-def _define_py_test(group: str,
- data: bytes,
- expected_frames: List[Expected],
- count: int = None) -> Callable[[Any], None]:
def test(self) -> None:
# Decode in one call
self.assertEqual(expected_frames,
list(FrameDecoder().process(data)),
- msg=f'{group}: {data!r}')
+ msg=f'{ctx.group}: {data!r}')
# Decode byte-by-byte
decoder = FrameDecoder()
@@ -313,35 +256,18 @@
self.assertEqual(expected_frames,
decoded_frames,
- msg=f'{group} (byte-by-byte): {data!r}')
+ msg=f'{ctx.group} (byte-by-byte): {data!r}')
- name = 'test_' + ''.join(c if c.isalnum() else '_' for c in group.lower())
- test.__name__ = name if count is None else name + f'_{count}'
return test
-def _define_py_tests(test_cases: TestCases = TEST_CASES):
- """Generates a Python test function for each test case."""
- tests: Dict[str, Callable[[Any], None]] = {}
-
- for group, test_list in _sort_test_cases(test_cases).items():
- for i, (data, expected_frames) in enumerate(test_list, 1):
- count = i if len(test_list) > 1 else None
- test = _define_py_test(group, data, expected_frames, count)
-
- assert test.__name__ not in tests, f'Duplicate! {test.__name__}'
- tests[test.__name__] = test
-
- return tests
-
-
# Class that tests all cases in TEST_CASES.
-DecoderTest = type('DecoderTest', (unittest.TestCase, ), _define_py_tests())
+DecoderTest = _TESTS.python_tests('DecoderTest', _define_py_test)
if __name__ == '__main__':
- # If --generate-cc-decode-test is provided, generate the C++ test file.
- if len(sys.argv) >= 2 and sys.argv[1] == '--generate-cc-decode-test':
- with Path(sys.argv[2]).open('w') as file:
- _define_cc_tests(TEST_CASES, file)
- else: # Otherwise, run the unit tests.
+ args = parse_test_generation_args()
+ if args.generate_cc_test:
+ _TESTS.cc_tests(args.generate_cc_test, _cpp_test, _CPP_HEADER,
+ _CPP_FOOTER)
+ else:
unittest.main()
diff --git a/pw_hdlc_lite/py/setup.py b/pw_hdlc_lite/py/setup.py
index 246c2fd..dad879d 100644
--- a/pw_hdlc_lite/py/setup.py
+++ b/pw_hdlc_lite/py/setup.py
@@ -23,4 +23,5 @@
description='Tools for Encoding/Decoding data using the HDLC-Lite protocol',
packages=setuptools.find_packages(),
install_requires=['ipython'],
+ tests_require=['pw_build'],
)