| # Copyright 2024 The Pigweed Authors |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); you may not |
| # use this file except in compliance with the License. You may obtain a copy of |
| # the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| # License for the specific language governing permissions and limitations under |
| # the License. |
| """Tests for the formatter core.""" |
| |
| from pathlib import Path |
| from tempfile import TemporaryDirectory |
| import unittest |
| |
| from pw_presubmit.format.core import ( |
| FileChecker, |
| FormattedDiff, |
| FormattedFileContents, |
| ) |
| |
| |
| class FakeFileChecker(FileChecker): |
| FORMAT_MAP = { |
| 'foo': 'bar', |
| 'bar': 'bar', |
| 'baz': '\nbaz\n', |
| 'new\n': 'newer\n', |
| } |
| |
| def format_file_in_memory( |
| self, file_path: Path, file_contents: bytes |
| ) -> FormattedFileContents: |
| error = '' |
| formatted = self.FORMAT_MAP.get(file_contents.decode(), None) |
| if formatted is None: |
| error = f'I do not know how to "{file_contents.decode()}".' |
| return FormattedFileContents( |
| ok=not error, |
| formatted_file_contents=formatted.encode() |
| if formatted is not None |
| else b'', |
| error_message=error, |
| ) |
| |
| |
| def _check_files( |
| formatter: FileChecker, file_contents: dict[str, str], dry_run=False |
| ) -> list[FormattedDiff]: |
| with TemporaryDirectory() as tmp: |
| paths = [] |
| for f in file_contents.keys(): |
| file_path = Path(tmp) / f |
| file_path.write_bytes(file_contents[f].encode()) |
| paths.append(file_path) |
| |
| return list(formatter.get_formatting_diffs(paths, dry_run)) |
| |
| |
| class TestFormatCore(unittest.TestCase): |
| """Tests for the format core.""" |
| |
| def setUp(self) -> None: |
| self.formatter = FakeFileChecker() |
| |
| def test_check_files(self): |
| """Tests that check_files() produces diffs as intended.""" |
| file_contents = { |
| 'foo.txt': 'foo', |
| 'bar.txt': 'bar', |
| 'baz.txt': 'baz', |
| 'yep.txt': 'new\n', |
| } |
| expected_diffs = { |
| 'foo.txt': '\n'.join( |
| ( |
| '-foo', |
| '+bar', |
| ' No newline at end of file', |
| ) |
| ), |
| 'baz.txt': '\n'.join( |
| ( |
| '+', |
| ' baz', |
| '-No newline at end of file', |
| ) |
| ), |
| 'yep.txt': '\n'.join( |
| ( |
| '-new', |
| '+newer', |
| ) |
| ), |
| } |
| |
| for result in _check_files(self.formatter, file_contents): |
| filename = result.file_path.name |
| self.assertIn(filename, expected_diffs) |
| self.assertTrue(result.ok) |
| lines = result.diff.splitlines() |
| self.assertEqual( |
| lines.pop(0), f'--- {result.file_path} (original)' |
| ) |
| self.assertEqual( |
| lines.pop(0), f'+++ {result.file_path} (reformatted)' |
| ) |
| self.assertTrue(lines.pop(0).startswith('@@')) |
| |
| self.assertMultiLineEqual( |
| '\n'.join(lines), expected_diffs[filename] |
| ) |
| expected_diffs.pop(filename) |
| |
| self.assertFalse(expected_diffs) |
| |
| def test_check_files_error(self): |
| """Tests that check_files() propagates error messages.""" |
| file_contents = { |
| 'foo.txt': 'broken', |
| 'bar.txt': 'bar', |
| } |
| expected_errors = { |
| 'foo.txt': '\n'.join(('I do not know how to "broken".',)), |
| } |
| for result in _check_files(self.formatter, file_contents): |
| filename = result.file_path.name |
| self.assertIn(filename, expected_errors) |
| self.assertFalse(result.ok) |
| self.assertEqual(result.diff, '') |
| self.assertEqual(result.error_message, expected_errors[filename]) |
| expected_errors.pop(filename) |
| |
| self.assertFalse(expected_errors) |
| |
| def test_check_files_dry_run(self): |
| """Tests that check_files() dry run produces no delta.""" |
| file_contents = { |
| 'foo.txt': 'foo', |
| 'bar.txt': 'bar', |
| 'baz.txt': 'baz', |
| 'yep.txt': 'new\n', |
| } |
| result = _check_files(self.formatter, file_contents, dry_run=True) |
| self.assertFalse(result) |
| |
| |
| if __name__ == '__main__': |
| unittest.main() |