pw_build: Python code for generating unit tests

Move code for generating Python and C++ unit tests from pw_hdlc_lite to
pw_build so it can be reused.

Change-Id: Ie2573811a13bded511d1f195928c09820f9a3cdf
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/20600
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
diff --git a/pw_build/py/BUILD.gn b/pw_build/py/BUILD.gn
new file mode 100644
index 0000000..bf28ba3
--- /dev/null
+++ b/pw_build/py/BUILD.gn
@@ -0,0 +1,23 @@
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+
+# gn-format disable
+import("//build_overrides/pigweed.gni")
+
+import("$dir_pw_build/input_group.gni")
+
+# TODO(hepler): Replace this with a pw_python_package.
+pw_input_group("py") {
+  inputs = [ "pw_build/generated_tests.py" ]
+}
diff --git a/pw_build/py/pw_build/generated_tests.py b/pw_build/py/pw_build/generated_tests.py
new file mode 100644
index 0000000..b3c280a
--- /dev/null
+++ b/pw_build/py/pw_build/generated_tests.py
@@ -0,0 +1,187 @@
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Tools for generating Pigweed tests that execute in C++ and Python."""
+
+import argparse
+from dataclasses import dataclass
+from datetime import datetime
+from collections import defaultdict
+import unittest
+
+from typing import (Any, Callable, Dict, Generic, Iterable, Iterator, List,
+                    Sequence, TextIO, TypeVar, Union)
+
+_CPP_HEADER = f"""\
+// Copyright {datetime.now().year} The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+
+// AUTOGENERATED - DO NOT EDIT
+//
+// Generated at {datetime.now().isoformat()}
+
+// clang-format off
+"""
+
+
+class Error(Exception):
+    """Something went wrong when generating tests."""
+
+
+T = TypeVar('T')
+
+
+@dataclass
+class Context(Generic[T]):
+    """Info passed into test generator functions for each test case."""
+    group: str
+    count: int
+    total: int
+    test_case: T
+
+    def cc_name(self) -> str:
+        name = ''.join(w.capitalize()
+                       for w in self.group.replace('-', ' ').split(' '))
+        name = ''.join(c if c.isalnum() else '_' for c in name)
+        return f'{name}_{self.count}' if self.total > 1 else name
+
+    def py_name(self) -> str:
+        name = 'test_' + ''.join(c if c.isalnum() else '_'
+                                 for c in self.group.lower())
+        return f'{name}_{self.count}' if self.total > 1 else name
+
+
+# Test cases are specified as a sequence of strings or test case instances. The
+# strings are used to separate the tests into named groups. For example:
+#
+#   STR_SPLIT_TEST_CASES = (
+#     'Empty input',
+#     MyTestCase('', '', []),
+#     MyTestCase('', 'foo', []),
+#     'Split on single character',
+#     MyTestCase('abcde', 'c', ['ab', 'de']),
+#     ...
+#   )
+#
+GroupOrTest = Union[str, T]
+
+# Python tests are generated by a function that returns a function usable as a
+# unittest.TestCase method.
+PyTest = Callable[[unittest.TestCase], None]
+PyTestGenerator = Callable[[Context[T]], PyTest]
+
+# C++ tests are generated with a function that returns or yields lines of C++
+# code for the given test case.
+CcTestGenerator = Callable[[Context[T]], Iterable[str]]
+
+
+class TestGenerator(Generic[T]):
+    """Generates tests for multiple languages from a series of test cases."""
+    def __init__(self, test_cases: Sequence[GroupOrTest[T]]):
+        self._cases: Dict[str, List[T]] = defaultdict(list)
+        message = ''
+
+        if len(test_cases) < 2:
+            raise Error('At least one test case must be provided')
+
+        if not isinstance(test_cases[0], str):
+            raise Error(
+                'The first item in the test cases must be a group name string')
+
+        for case in test_cases:
+            if isinstance(case, str):
+                message = case
+            else:
+                self._cases[message].append(case)
+
+        if '' in self._cases:
+            raise Error('Empty test group names are not permitted')
+
+    def _test_contexts(self) -> Iterator[Context[T]]:
+        for group, test_list in self._cases.items():
+            for i, test_case in enumerate(test_list, 1):
+                yield Context(group, i, len(test_list), test_case)
+
+    def _generate_python_tests(self, define_py_test: PyTestGenerator):
+        tests: Dict[str, Callable[[Any], None]] = {}
+
+        for ctx in self._test_contexts():
+            test = define_py_test(ctx)
+            test.__name__ = ctx.py_name()
+
+            if test.__name__ in tests:
+                raise Error(
+                    f'Multiple Python tests are named {test.__name__}!')
+
+            tests[test.__name__] = test
+
+        return tests
+
+    def python_tests(self, name: str, define_py_test: PyTestGenerator) -> type:
+        """Returns a Python unittest.TestCase class with tests for each case."""
+        return type(name, (unittest.TestCase, ),
+                    self._generate_python_tests(define_py_test))
+
+    def _generate_cc_tests(self, define_cpp_test: CcTestGenerator, header: str,
+                           footer: str) -> Iterator[str]:
+        yield _CPP_HEADER
+        yield header
+
+        for ctx in self._test_contexts():
+            yield from define_cpp_test(ctx)
+            yield ''
+
+        yield footer
+
+    def cc_tests(self, output: TextIO, define_cpp_test: CcTestGenerator,
+                 header: str, footer: str):
+        """Writes C++ unit tests for each test case to the given file."""
+        for line in self._generate_cc_tests(define_cpp_test, header, footer):
+            output.write(line)
+            output.write('\n')
+
+
+def _to_chars(data: bytes) -> Iterator[str]:
+    for i, byte in enumerate(data):
+        try:
+            char = data[i:i + 1].decode()
+            yield char if char.isprintable() else fr'\x{byte:02x}'
+        except UnicodeDecodeError:
+            yield fr'\x{byte:02x}'
+
+
+def cc_string(data: Union[str, bytes]) -> str:
+    """Returns a C++ string literal version of a byte string or UTF-8 string."""
+    if isinstance(data, str):
+        data = data.encode()
+
+    return '"' + ''.join(_to_chars(data)) + '"'
+
+
+def parse_test_generation_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description='Generate unit test files')
+    parser.add_argument('--generate-cc-test',
+                        type=argparse.FileType('w'),
+                        help='Generate the C++ test file')
+    return parser.parse_known_args()[0]
diff --git a/pw_hdlc_lite/BUILD.gn b/pw_hdlc_lite/BUILD.gn
index 5bcf447..75eb537 100644
--- a/pw_hdlc_lite/BUILD.gn
+++ b/pw_hdlc_lite/BUILD.gn
@@ -103,7 +103,8 @@
 action("generate_decoder_test") {
   outputs = [ "$target_gen_dir/generated_decoder_test.cc" ]
   script = "py/decode_test.py"
-  args = [ "--generate-cc-decode-test" ] + rebase_path(outputs)
+  args = [ "--generate-cc-test" ] + rebase_path(outputs)
+  deps = [ "$dir_pw_build/py" ]
 }
 
 pw_test("decoder_test") {
diff --git a/pw_hdlc_lite/py/decode_test.py b/pw_hdlc_lite/py/decode_test.py
index fbd1279..55c3f35 100755
--- a/pw_hdlc_lite/py/decode_test.py
+++ b/pw_hdlc_lite/py/decode_test.py
@@ -14,13 +14,11 @@
 # the License.
 """Contains the Python decoder tests and generates C++ decoder tests."""
 
-from collections import defaultdict
-from pathlib import Path
-from typing import Any, Callable, Dict, Iterator, List, NamedTuple, Union
-from typing import Optional, TextIO, Tuple
+from typing import Iterator, List, NamedTuple, Tuple
 import unittest
-import sys
 
+from pw_build.generated_tests import Context, PyTest, TestGenerator, GroupOrTest
+from pw_build.generated_tests import parse_test_generation_args
 from pw_hdlc_lite.decode import Frame, FrameDecoder, FrameStatus, NO_ADDRESS
 from pw_hdlc_lite.protocol import frame_check_sequence as fcs
 
@@ -54,10 +52,7 @@
     ],
 )
 
-TestCase = Tuple[bytes, List[Expected]]
-TestCases = Tuple[Union[str, TestCase], ...]
-
-TEST_CASES: TestCases = (
+TEST_CASES: Tuple[GroupOrTest[Tuple[bytes, List[Expected]]], ...] = (
     'Empty payload',
     (_encode(0, 0, b''), [Expected(0, b'\0', b'')]),
     (_encode(55, 0x99, b''), [Expected(55, b'\x99', b'')]),
@@ -158,18 +153,7 @@
 )  # yapf: disable
 # Formatting for the above tuple is very slow, so disable yapf.
 
-
-def _sort_test_cases(test_cases: TestCases) -> Dict[str, List[TestCase]]:
-    cases: Dict[str, List[TestCase]] = defaultdict(list)
-    message = ''
-
-    for case in test_cases:
-        if isinstance(case, str):
-            message = case
-        else:
-            cases[message].append(case)
-
-    return cases
+_TESTS = TestGenerator(TEST_CASES)
 
 
 def _expected(frames: List[Frame]) -> Iterator[str]:
@@ -180,23 +164,7 @@
             yield f'      Status::DATA_LOSS,  // Frame {i}'
 
 
-_CPP_HEADER = f"""\
-// Copyright 2020 The Pigweed Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not
-// use this file except in compliance with the License. You may obtain a copy of
-// the License at
-//
-//     https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-// License for the specific language governing permissions and limitations under
-// the License.
-
-// AUTOGENERATED by {Path(__file__).name} - DO NOT EDIT
-
+_CPP_HEADER = """\
 #include "pw_hdlc_lite/decoder.h"
 
 #include <array>
@@ -206,30 +174,22 @@
 #include "gtest/gtest.h"
 #include "pw_bytes/array.h"
 
-namespace pw::hdlc_lite {{
-namespace {{
-
-// clang-format off
+namespace pw::hdlc_lite {
+namespace {
 """
 
 _CPP_FOOTER = """\
-// clang-format on
-
 }  // namespace
-}  // namespace pw::hdlc_lite
-"""
+}  // namespace pw::hdlc_lite"""
 
 
-def _cpp_test(group: str, count: Optional[int], data: bytes) -> Iterator[str]:
+def _cpp_test(ctx: Context) -> Iterator[str]:
     """Generates a C++ test for the provided test data."""
+    data, _ = ctx.test_case
     frames = list(FrameDecoder().process(data))
     data_bytes = ''.join(rf'\x{byte:02x}' for byte in data)
 
-    name = ''.join(w.capitalize() for w in group.replace('-', ' ').split(' '))
-    name = ''.join(c if c.isalnum() else '_' for c in name)
-    name = name if count is None else name + f'_{count}'
-
-    yield f'TEST(Decoder, {name}) {{'
+    yield f'TEST(Decoder, {ctx.cc_name()}) {{'
     yield f'  static constexpr auto kData = bytes::String("{data_bytes}");\n'
 
     for i, frame in enumerate(frames, 1):
@@ -279,31 +239,14 @@
 }}"""
 
 
-def _define_cc_tests(test_cases: TestCases, output: TextIO) -> None:
-    """Writes C++ tests for all test cases."""
-    output.write(_CPP_HEADER)
+def _define_py_test(ctx: Context) -> PyTest:
+    data, expected_frames = ctx.test_case
 
-    for group, test_list in _sort_test_cases(test_cases).items():
-        for i, (data, _) in enumerate(test_list, 1):
-            count = i if len(test_list) > 1 else None
-            for line in _cpp_test(group, count, data):
-                output.write(line)
-                output.write('\n')
-
-            output.write('\n')
-
-    output.write(_CPP_FOOTER)
-
-
-def _define_py_test(group: str,
-                    data: bytes,
-                    expected_frames: List[Expected],
-                    count: int = None) -> Callable[[Any], None]:
     def test(self) -> None:
         # Decode in one call
         self.assertEqual(expected_frames,
                          list(FrameDecoder().process(data)),
-                         msg=f'{group}: {data!r}')
+                         msg=f'{ctx.group}: {data!r}')
 
         # Decode byte-by-byte
         decoder = FrameDecoder()
@@ -313,35 +256,18 @@
 
         self.assertEqual(expected_frames,
                          decoded_frames,
-                         msg=f'{group} (byte-by-byte): {data!r}')
+                         msg=f'{ctx.group} (byte-by-byte): {data!r}')
 
-    name = 'test_' + ''.join(c if c.isalnum() else '_' for c in group.lower())
-    test.__name__ = name if count is None else name + f'_{count}'
     return test
 
 
-def _define_py_tests(test_cases: TestCases = TEST_CASES):
-    """Generates a Python test function for each test case."""
-    tests: Dict[str, Callable[[Any], None]] = {}
-
-    for group, test_list in _sort_test_cases(test_cases).items():
-        for i, (data, expected_frames) in enumerate(test_list, 1):
-            count = i if len(test_list) > 1 else None
-            test = _define_py_test(group, data, expected_frames, count)
-
-            assert test.__name__ not in tests, f'Duplicate! {test.__name__}'
-            tests[test.__name__] = test
-
-    return tests
-
-
 # Class that tests all cases in TEST_CASES.
-DecoderTest = type('DecoderTest', (unittest.TestCase, ), _define_py_tests())
+DecoderTest = _TESTS.python_tests('DecoderTest', _define_py_test)
 
 if __name__ == '__main__':
-    # If --generate-cc-decode-test is provided, generate the C++ test file.
-    if len(sys.argv) >= 2 and sys.argv[1] == '--generate-cc-decode-test':
-        with Path(sys.argv[2]).open('w') as file:
-            _define_cc_tests(TEST_CASES, file)
-    else:  # Otherwise, run the unit tests.
+    args = parse_test_generation_args()
+    if args.generate_cc_test:
+        _TESTS.cc_tests(args.generate_cc_test, _cpp_test, _CPP_HEADER,
+                        _CPP_FOOTER)
+    else:
         unittest.main()
diff --git a/pw_hdlc_lite/py/setup.py b/pw_hdlc_lite/py/setup.py
index 246c2fd..dad879d 100644
--- a/pw_hdlc_lite/py/setup.py
+++ b/pw_hdlc_lite/py/setup.py
@@ -23,4 +23,5 @@
     description='Tools for Encoding/Decoding data using the HDLC-Lite protocol',
     packages=setuptools.find_packages(),
     install_requires=['ipython'],
+    tests_require=['pw_build'],
 )