| #!/usr/bin/env python3 |
| |
| # Copyright 2020 The Pigweed Authors |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); you may not |
| # use this file except in compliance with the License. You may obtain a copy of |
| # the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| # License for the specific language governing permissions and limitations under |
| # the License. |
| """Checks and fixes formatting for source files. |
| |
| This uses clang-format, gn format, gofmt, and python -m yapf to format source |
| code. These tools must be available on the path when this script is invoked! |
| """ |
| |
| import argparse |
| import collections |
| import difflib |
| import logging |
| import os |
| from pathlib import Path |
| import re |
| import subprocess |
| import sys |
| import tempfile |
| from typing import Callable, Collection, Dict, Iterable, List, NamedTuple |
| from typing import Optional, Pattern, Tuple, Union |
| |
| try: |
| import pw_presubmit |
| except ImportError: |
| # Append the pw_presubmit package path to the module search path to allow |
| # running this module without installing the pw_presubmit package. |
| sys.path.append(os.path.dirname(os.path.dirname( |
| os.path.abspath(__file__)))) |
| import pw_presubmit |
| |
| import pw_cli.color |
| import pw_cli.env |
| from pw_presubmit.presubmit import FileFilter |
| from pw_presubmit import cli, git_repo |
| from pw_presubmit.tools import exclude_paths, file_summary, log_run, plural |
| |
| _LOG: logging.Logger = logging.getLogger(__name__) |
| _COLOR = pw_cli.color.colors() |
| |
| |
| def _colorize_diff_line(line: str) -> str: |
| if line.startswith('--- ') or line.startswith('+++ '): |
| return _COLOR.bold_white(line) |
| if line.startswith('-'): |
| return _COLOR.red(line) |
| if line.startswith('+'): |
| return _COLOR.green(line) |
| if line.startswith('@@ '): |
| return _COLOR.cyan(line) |
| return line |
| |
| |
| def colorize_diff(lines: Iterable[str]) -> str: |
| """Takes a diff str or list of str lines and returns a colorized version.""" |
| if isinstance(lines, str): |
| lines = lines.splitlines(True) |
| |
| return ''.join(_colorize_diff_line(line) for line in lines) |
| |
| |
| def _diff(path, original: bytes, formatted: bytes) -> str: |
| return colorize_diff( |
| difflib.unified_diff( |
| original.decode(errors='replace').splitlines(True), |
| formatted.decode(errors='replace').splitlines(True), |
| f'{path} (original)', f'{path} (reformatted)')) |
| |
| |
| Formatter = Callable[[str, bytes], bytes] |
| |
| |
| def _diff_formatted(path, formatter: Formatter) -> Optional[str]: |
| """Returns a diff comparing a file to its formatted version.""" |
| with open(path, 'rb') as fd: |
| original = fd.read() |
| |
| formatted = formatter(path, original) |
| |
| return None if formatted == original else _diff(path, original, formatted) |
| |
| |
| def _check_files(files, formatter: Formatter) -> Dict[Path, str]: |
| errors = {} |
| |
| for path in files: |
| difference = _diff_formatted(path, formatter) |
| if difference: |
| errors[path] = difference |
| |
| return errors |
| |
| |
| def _clang_format(*args: str, **kwargs) -> bytes: |
| return log_run(['clang-format', '--style=file', *args], |
| stdout=subprocess.PIPE, |
| check=True, |
| **kwargs).stdout |
| |
| |
| def clang_format_check(files: Iterable[Path]) -> Dict[Path, str]: |
| """Checks formatting; returns {path: diff} for files with bad formatting.""" |
| return _check_files(files, lambda path, _: _clang_format(path)) |
| |
| |
| def clang_format_fix(files: Iterable) -> Dict[Path, str]: |
| """Fixes formatting for the provided files in place.""" |
| _clang_format('-i', *files) |
| return {} |
| |
| |
| def check_gn_format(files: Iterable[Path]) -> Dict[Path, str]: |
| """Checks formatting; returns {path: diff} for files with bad formatting.""" |
| return _check_files( |
| files, lambda _, data: log_run(['gn', 'format', '--stdin'], |
| input=data, |
| stdout=subprocess.PIPE, |
| check=True).stdout) |
| |
| |
| def fix_gn_format(files: Iterable[Path]) -> Dict[Path, str]: |
| """Fixes formatting for the provided files in place.""" |
| log_run(['gn', 'format', *files], check=True) |
| return {} |
| |
| |
| def check_bazel_format(files: Iterable[Path]) -> Dict[Path, str]: |
| """Checks formatting; returns {path: diff} for files with bad formatting.""" |
| errors: Dict[Path, str] = {} |
| |
| def _format_temp(path: Union[Path, str], data: bytes) -> bytes: |
| # buildifier doesn't have an option to output the changed file, so |
| # copy the file to a temp location, run buildifier on it, read that |
| # modified copy, and return its contents. |
| with tempfile.TemporaryDirectory() as temp: |
| build = Path(temp) / os.path.basename(path) |
| build.write_bytes(data) |
| |
| proc = log_run(['buildifier', build], capture_output=True) |
| if proc.returncode: |
| stderr = proc.stderr.decode(errors='replace') |
| stderr = stderr.replace(str(build), str(path)) |
| errors[Path(path)] = stderr |
| return build.read_bytes() |
| |
| result = _check_files(files, _format_temp) |
| result.update(errors) |
| return result |
| |
| |
| def fix_bazel_format(files: Iterable[Path]) -> Dict[Path, str]: |
| """Fixes formatting for the provided files in place.""" |
| errors = {} |
| for path in files: |
| proc = log_run(['buildifier', path], capture_output=True) |
| if proc.returncode: |
| errors[path] = proc.stderr.decode() |
| return errors |
| |
| |
| def check_go_format(files: Iterable[Path]) -> Dict[Path, str]: |
| """Checks formatting; returns {path: diff} for files with bad formatting.""" |
| return _check_files( |
| files, lambda path, _: log_run( |
| ['gofmt', path], stdout=subprocess.PIPE, check=True).stdout) |
| |
| |
| def fix_go_format(files: Iterable[Path]) -> Dict[Path, str]: |
| """Fixes formatting for the provided files in place.""" |
| log_run(['gofmt', '-w', *files], check=True) |
| return {} |
| |
| |
| def _yapf(*args, **kwargs) -> subprocess.CompletedProcess: |
| return log_run(['python', '-m', 'yapf', '--parallel', *args], |
| capture_output=True, |
| **kwargs) |
| |
| |
| _DIFF_START = re.compile(r'^--- (.*)\s+\(original\)$', flags=re.MULTILINE) |
| |
| |
| def check_py_format(files: Iterable[Path]) -> Dict[Path, str]: |
| """Checks formatting; returns {path: diff} for files with bad formatting.""" |
| process = _yapf('--diff', *files) |
| |
| errors: Dict[Path, str] = {} |
| |
| if process.stdout: |
| raw_diff = process.stdout.decode(errors='replace') |
| |
| matches = tuple(_DIFF_START.finditer(raw_diff)) |
| for start, end in zip(matches, (*matches[1:], None)): |
| errors[Path(start.group(1))] = colorize_diff( |
| raw_diff[start.start():end.start() if end else None]) |
| |
| if process.stderr: |
| _LOG.error('yapf encountered an error:\n%s', |
| process.stderr.decode(errors='replace').rstrip()) |
| errors.update({file: '' for file in files if file not in errors}) |
| |
| return errors |
| |
| |
| def fix_py_format(files: Iterable) -> Dict[Path, str]: |
| """Fixes formatting for the provided files in place.""" |
| _yapf('--in-place', *files, check=True) |
| return {} |
| |
| |
| _TRAILING_SPACE = re.compile(rb'[ \t]+$', flags=re.MULTILINE) |
| |
| |
| def _check_trailing_space(paths: Iterable[Path], fix: bool) -> Dict[Path, str]: |
| """Checks for and optionally removes trailing whitespace.""" |
| errors = {} |
| |
| for path in paths: |
| with path.open('rb') as fd: |
| contents = fd.read() |
| |
| corrected = _TRAILING_SPACE.sub(b'', contents) |
| if corrected != contents: |
| errors[path] = _diff(path, contents, corrected) |
| |
| if fix: |
| with path.open('wb') as fd: |
| fd.write(corrected) |
| |
| return errors |
| |
| |
| def check_trailing_space(files: Iterable[Path]) -> Dict[Path, str]: |
| return _check_trailing_space(files, fix=False) |
| |
| |
| def fix_trailing_space(files: Iterable[Path]) -> Dict[Path, str]: |
| _check_trailing_space(files, fix=True) |
| return {} |
| |
| |
| def print_format_check(errors: Dict[Path, str], |
| show_fix_commands: bool) -> None: |
| """Prints and returns the result of a check_*_format function.""" |
| if not errors: |
| # Don't print anything in the all-good case. |
| return |
| |
| # Show the format fixing diff suggested by the tooling (with colors). |
| _LOG.warning('Found %d files with formatting errors. Format changes:', |
| len(errors)) |
| for diff in errors.values(): |
| print(diff, end='') |
| |
| # Show a copy-and-pastable command to fix the issues. |
| if show_fix_commands: |
| |
| def path_relative_to_cwd(path): |
| try: |
| return Path(path).resolve().relative_to(Path.cwd().resolve()) |
| except ValueError: |
| return Path(path).resolve() |
| |
| message = (f' pw format --fix {path_relative_to_cwd(path)}' |
| for path in errors) |
| _LOG.warning('To fix formatting, run:\n\n%s\n', '\n'.join(message)) |
| |
| |
| class CodeFormat(NamedTuple): |
| language: str |
| filter: FileFilter |
| check: Callable[[Iterable], Dict[Path, str]] |
| fix: Callable[[Iterable], Dict[Path, str]] |
| |
| @property |
| def extensions(self): |
| # TODO(b/23842636): Switch calls of this to using "filter" and remove. |
| return self.filter.endswith |
| |
| |
| CPP_HEADER_EXTS = frozenset(('.h', '.hpp', '.hxx', '.h++', '.hh', '.H')) |
| CPP_SOURCE_EXTS = frozenset( |
| ('.c', '.cpp', '.cxx', '.c++', '.cc', '.C', '.inc', '.inl')) |
| CPP_EXTS = CPP_HEADER_EXTS.union(CPP_SOURCE_EXTS) |
| CPP_FILE_FILTER = FileFilter(endswith=CPP_EXTS, |
| exclude=(r'\.pb\.h$', r'\.pb\.c$')) |
| |
| C_FORMAT = CodeFormat('C and C++', CPP_FILE_FILTER, clang_format_check, |
| clang_format_fix) |
| |
| PROTO_FORMAT: CodeFormat = CodeFormat('Protocol buffer', |
| FileFilter(endswith=('.proto', )), |
| clang_format_check, clang_format_fix) |
| |
| JAVA_FORMAT: CodeFormat = CodeFormat('Java', FileFilter(endswith=('.java', )), |
| clang_format_check, clang_format_fix) |
| |
| JAVASCRIPT_FORMAT: CodeFormat = CodeFormat('JavaScript', |
| FileFilter(endswith=('.js', )), |
| clang_format_check, |
| clang_format_fix) |
| |
| GO_FORMAT: CodeFormat = CodeFormat('Go', FileFilter(endswith=('.go', )), |
| check_go_format, fix_go_format) |
| |
| PYTHON_FORMAT: CodeFormat = CodeFormat('Python', |
| FileFilter(endswith=('.py', )), |
| check_py_format, fix_py_format) |
| |
| GN_FORMAT: CodeFormat = CodeFormat('GN', FileFilter(endswith=('.gn', '.gni')), |
| check_gn_format, fix_gn_format) |
| |
| BAZEL_FORMAT: CodeFormat = CodeFormat( |
| 'Bazel', |
| FileFilter(endswith=('BUILD', '.bazel', '.bzl'), name=('WORKSPACE')), |
| check_bazel_format, fix_bazel_format) |
| |
| COPYBARA_FORMAT: CodeFormat = CodeFormat('Copybara', |
| FileFilter(endswith=('.bara.sky', )), |
| check_bazel_format, fix_bazel_format) |
| |
| # TODO(b/234881054): Add real code formatting support for CMake |
| CMAKE_FORMAT: CodeFormat = CodeFormat( |
| 'CMake', FileFilter(endswith=('CMakeLists.txt', '.cmake')), |
| check_trailing_space, fix_trailing_space) |
| |
| RST_FORMAT: CodeFormat = CodeFormat('reStructuredText', |
| FileFilter(endswith=('.rst', )), |
| check_trailing_space, fix_trailing_space) |
| |
| MARKDOWN_FORMAT: CodeFormat = CodeFormat('Markdown', |
| FileFilter(endswith=('.md', )), |
| check_trailing_space, |
| fix_trailing_space) |
| |
| CODE_FORMATS: Tuple[CodeFormat, ...] = ( |
| C_FORMAT, |
| JAVA_FORMAT, |
| JAVASCRIPT_FORMAT, |
| PROTO_FORMAT, |
| GO_FORMAT, |
| PYTHON_FORMAT, |
| GN_FORMAT, |
| BAZEL_FORMAT, |
| COPYBARA_FORMAT, |
| CMAKE_FORMAT, |
| RST_FORMAT, |
| MARKDOWN_FORMAT, |
| ) |
| |
| |
| def presubmit_check( |
| code_format: CodeFormat, |
| *, |
| exclude: Collection[Union[str, Pattern[str]]] = ()) -> Callable: |
| """Creates a presubmit check function from a CodeFormat object. |
| |
| Args: |
| exclude: Additional exclusion regexes to apply. |
| """ |
| |
| # Make a copy of the FileFilter and add in any additional excludes. |
| file_filter = FileFilter(**vars(code_format.filter)) |
| file_filter.exclude += tuple(re.compile(e) for e in exclude) |
| |
| @pw_presubmit.filter_paths(file_filter=file_filter) |
| def check_code_format(ctx: pw_presubmit.PresubmitContext): |
| errors = code_format.check(ctx.paths) |
| print_format_check( |
| errors, |
| # When running as part of presubmit, show the fix command help. |
| show_fix_commands=True, |
| ) |
| if errors: |
| raise pw_presubmit.PresubmitFailure |
| |
| language = code_format.language.lower().replace('+', 'p').replace(' ', '_') |
| check_code_format.__name__ = f'{language}_format' |
| |
| return check_code_format |
| |
| |
| def presubmit_checks( |
| *, |
| exclude: Collection[Union[str, Pattern[str]]] = (), |
| code_formats: Collection[CodeFormat] = CODE_FORMATS |
| ) -> Tuple[Callable, ...]: |
| """Returns a tuple with all supported code format presubmit checks. |
| |
| Args: |
| exclude: Additional exclusion regexes to apply. |
| code_formats: A list of CodeFormat objects to run checks with. |
| """ |
| |
| return tuple(presubmit_check(fmt, exclude=exclude) for fmt in code_formats) |
| |
| |
| class CodeFormatter: |
| """Checks or fixes the formatting of a set of files.""" |
| def __init__(self, |
| files: Iterable[Path], |
| code_formats: Collection[CodeFormat] = CODE_FORMATS): |
| self.paths = list(files) |
| self._formats: Dict[CodeFormat, List] = collections.defaultdict(list) |
| |
| for path in self.paths: |
| for code_format in code_formats: |
| if code_format.filter.matches(path): |
| _LOG.debug('Formatting %s as %s', path, |
| code_format.language) |
| self._formats[code_format].append(path) |
| break |
| else: |
| _LOG.debug('No formatter found for %s', path) |
| |
| def check(self) -> Dict[Path, str]: |
| """Returns {path: diff} for files with incorrect formatting.""" |
| errors: Dict[Path, str] = {} |
| |
| for code_format, files in self._formats.items(): |
| _LOG.debug('Checking %s', ', '.join(str(f) for f in files)) |
| errors.update(code_format.check(files)) |
| |
| return collections.OrderedDict(sorted(errors.items())) |
| |
| def fix(self) -> Dict[Path, str]: |
| """Fixes format errors for supported files in place.""" |
| all_errors: Dict[Path, str] = {} |
| for code_format, files in self._formats.items(): |
| errors = code_format.fix(files) |
| if errors: |
| for path, error in errors.items(): |
| _LOG.error('Failed to format %s', path) |
| for line in error.splitlines(): |
| _LOG.error('%s', line) |
| all_errors.update(errors) |
| continue |
| |
| _LOG.info('Formatted %s', |
| plural(files, code_format.language + ' file')) |
| return all_errors |
| |
| |
| def _file_summary(files: Iterable[Union[Path, str]], base: Path) -> List[str]: |
| try: |
| return file_summary( |
| Path(f).resolve().relative_to(base.resolve()) for f in files) |
| except ValueError: |
| return [] |
| |
| |
| def format_paths_in_repo( |
| paths: Collection[Union[Path, str]], |
| exclude: Collection[Pattern[str]], |
| fix: bool, |
| base: str, |
| code_formats: Collection[CodeFormat] = CODE_FORMATS) -> int: |
| """Checks or fixes formatting for files in a Git repo.""" |
| |
| files = [Path(path).resolve() for path in paths if os.path.isfile(path)] |
| repo = git_repo.root() if git_repo.is_repo() else None |
| |
| # Implement a graceful fallback in case the tracking branch isn't available. |
| if (base == git_repo.TRACKING_BRANCH_ALIAS |
| and not git_repo.tracking_branch(repo)): |
| _LOG.warning( |
| 'Failed to determine the tracking branch, using --base HEAD~1 ' |
| 'instead of listing all files') |
| base = 'HEAD~1' |
| |
| # If this is a Git repo, list the original paths with git ls-files or diff. |
| if repo: |
| project_root = Path(pw_cli.env.pigweed_environment().PW_PROJECT_ROOT) |
| _LOG.info( |
| 'Formatting %s', |
| git_repo.describe_files(repo, Path.cwd(), base, paths, exclude, |
| project_root)) |
| |
| # Add files from Git and remove duplicates. |
| files = sorted( |
| set(exclude_paths(exclude, git_repo.list_files(base, paths))) |
| | set(files)) |
| elif base: |
| _LOG.critical( |
| 'A base commit may only be provided if running from a Git repo') |
| return 1 |
| |
| return format_files(files, fix, repo=repo, code_formats=code_formats) |
| |
| |
| def format_files(paths: Collection[Union[Path, str]], |
| fix: bool, |
| repo: Optional[Path] = None, |
| code_formats: Collection[CodeFormat] = CODE_FORMATS) -> int: |
| """Checks or fixes formatting for the specified files.""" |
| |
| formatter = CodeFormatter(files=(Path(p) for p in paths), |
| code_formats=code_formats) |
| |
| _LOG.info('Checking formatting for %s', plural(formatter.paths, 'file')) |
| |
| for line in _file_summary(paths, repo if repo else Path.cwd()): |
| print(line, file=sys.stderr) |
| |
| check_errors = formatter.check() |
| print_format_check(check_errors, show_fix_commands=(not fix)) |
| |
| if check_errors: |
| if fix: |
| _LOG.info('Applying formatting fixes to %d files', |
| len(check_errors)) |
| fix_errors = formatter.fix() |
| if fix_errors: |
| _LOG.info('Failed to apply formatting fixes') |
| print_format_check(fix_errors, show_fix_commands=False) |
| return 1 |
| |
| _LOG.info('Formatting fixes applied successfully') |
| return 0 |
| |
| _LOG.error('Formatting errors found') |
| return 1 |
| |
| _LOG.info('Congratulations! No formatting changes needed') |
| return 0 |
| |
| |
| def arguments(git_paths: bool) -> argparse.ArgumentParser: |
| """Creates an argument parser for format_files or format_paths_in_repo.""" |
| |
| parser = argparse.ArgumentParser(description=__doc__) |
| |
| if git_paths: |
| cli.add_path_arguments(parser) |
| else: |
| |
| def existing_path(arg: str) -> Path: |
| path = Path(arg) |
| if not path.is_file(): |
| raise argparse.ArgumentTypeError( |
| f'{arg} is not a path to a file') |
| |
| return path |
| |
| parser.add_argument('paths', |
| metavar='path', |
| nargs='+', |
| type=existing_path, |
| help='File paths to check') |
| |
| parser.add_argument('--fix', |
| action='store_true', |
| help='Apply formatting fixes in place.') |
| return parser |
| |
| |
| def main() -> int: |
| """Check and fix formatting for source files.""" |
| return format_paths_in_repo(**vars(arguments(git_paths=True).parse_args())) |
| |
| |
| def _pigweed_upstream_main() -> int: |
| """Check and fix formatting for source files in upstream Pigweed. |
| |
| Excludes third party sources. |
| """ |
| args = arguments(git_paths=True).parse_args() |
| |
| # Exclude paths with third party code from formatting. |
| args.exclude.append(re.compile('^third_party/fuchsia/repo/')) |
| |
| return format_paths_in_repo(**vars(args)) |
| |
| |
| if __name__ == '__main__': |
| try: |
| # If pw_cli is available, use it to initialize logs. |
| from pw_cli import log |
| |
| log.install(logging.INFO) |
| except ImportError: |
| # If pw_cli isn't available, display log messages like a simple print. |
| logging.basicConfig(format='%(message)s', level=logging.INFO) |
| |
| sys.exit(main()) |