blob: c25f0a9a84aada2f2c11170b3fd0d763706ac660 [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2019 The Pigweed Authors
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Checks and fixes formatting for source files.

This uses clang-format, gn format, gofmt, and python -m yapf to format source
code. These tools must be available on the path when this script is invoked!
"""
import argparse
import collections
import difflib
import logging
import os
from pathlib import Path
import re
import subprocess
import sys
from typing import Callable, Collection, Dict, Iterable, List, NamedTuple
from typing import Optional, Sequence
try:
    import pw_presubmit
except ImportError:
    # Append the pw_presubmit package path to the module search path to allow
    # running this module without installing the pw_presubmit package.
    # NOTE(review): assumes this file lives directly inside the pw_presubmit
    # package directory, so its grandparent directory contains the package.
    sys.path.append(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    import pw_presubmit
from pw_presubmit import file_summary, list_git_files, log_run, plural
# Module-level logger used for all diagnostics emitted by this file.
_LOG: logging.Logger = logging.getLogger(name=__name__)
def _colorize_diff_line(line: str) -> str:
    """Wraps a single unified-diff line in the color matching its kind.

    File headers ('--- ' / '+++ ') are bold white, removed lines red, added
    lines green and hunk headers ('@@ ') aqua. Any other line is returned
    unchanged.
    """
    # Header prefixes must be tested before the bare '-' / '+' prefixes.
    if line.startswith(('--- ', '+++ ')):
        colorizer = pw_presubmit.color_bold_white
    elif line.startswith('-'):
        colorizer = pw_presubmit.color_red
    elif line.startswith('+'):
        colorizer = pw_presubmit.color_green
    elif line.startswith('@@ '):
        colorizer = pw_presubmit.color_aqua
    else:
        return line
    return colorizer(line)
def colorize_diff(lines: Iterable[str]) -> str:
    """Returns a terminal-colorized copy of a unified diff.

    The diff may be given as one string or as an iterable of lines. The
    colored lines are joined without separators, so iterable items should
    carry their own line endings.
    """
    diff_lines = lines.splitlines(True) if isinstance(lines, str) else lines
    return ''.join(map(_colorize_diff_line, diff_lines))
def _diff(path, original: bytes, formatted: bytes) -> str:
    """Returns a colorized unified diff between two versions of a file.

    Args:
      path: the file being compared; only used to label the diff headers.
      original: the file's current contents.
      formatted: the contents as produced by the formatter.
    """
    # difflib works on text, so decode leniently: undecodable bytes become
    # replacement characters instead of aborting the whole check.
    return colorize_diff(
        difflib.unified_diff(
            original.decode(errors='replace').splitlines(True),
            formatted.decode(errors='replace').splitlines(True),
            f'{path} (original)', f'{path} (reformatted)'))
# Formatter callback: called as formatter(path, original_contents) and returns
# the file contents as the formatting tool would produce them.
Formatter = Callable[[str, bytes], bytes]
def _diff_formatted(path, formatter: Formatter) -> Optional[str]:
    """Returns a diff comparing a file to its formatted version.

    Args:
      path: the file to read and format.
      formatter: called with (path, original_contents); returns the
          formatted contents.

    Returns:
      None if the formatter leaves the contents unchanged, otherwise the
      colorized diff from original to formatted.
    """
    with open(path, 'rb') as fd:
        original = fd.read()

    formatted = formatter(path, original)

    return None if formatted == original else _diff(path, original, formatted)
def _check_files(files, formatter: Formatter) -> Dict[Path, str]:
    """Formats each file in memory and collects the resulting diffs.

    Returns a dict, in the iteration order of files, mapping every path whose
    formatted contents differ from its current contents to the diff.
    """
    candidates = ((path, _diff_formatted(path, formatter)) for path in files)
    return {path: diff_text for path, diff_text in candidates if diff_text}
def _clang_format(*args: str, **kwargs) -> bytes:
    """Runs clang-format with the given arguments and returns its stdout.

    '--style=file' is passed so the style comes from the nearest .clang-format
    file. check=True is passed to log_run; assuming log_run forwards to
    subprocess.run (as its use elsewhere in this file suggests), a non-zero
    exit status raises. Extra keyword arguments are forwarded to log_run.
    """
    return log_run('clang-format',
                   '--style=file',
                   *args,
                   stdout=subprocess.PIPE,
                   check=True,
                   **kwargs).stdout
def check_c_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks C and C++ formatting with clang-format.

    Returns {path: diff} for every file clang-format would change.
    """
    def formatted_contents(path, _unused_contents):
        # clang-format reads the file itself, so the contents already loaded
        # by _check_files are not needed here.
        return _clang_format(path)

    return _check_files(files, formatted_contents)
def fix_c_format(files: Iterable) -> None:
    """Rewrites the given C/C++ files in place using `clang-format -i`."""
    clang_args = ['-i']
    clang_args.extend(files)
    _clang_format(*clang_args)
def check_gn_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    # `gn format --stdin` reads the contents from stdin and writes the
    # formatted result to stdout, so the file on disk is left untouched.
    return _check_files(
        files, lambda _, data: log_run('gn',
                                       'format',
                                       '--stdin',
                                       input=data,
                                       stdout=subprocess.PIPE,
                                       check=True).stdout)
def fix_gn_format(files: Iterable[Path]) -> None:
    """Rewrites the given GN files in place with `gn format`."""
    command = ['gn', 'format', *files]
    log_run(*command, check=True)
def check_go_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks Go formatting with gofmt.

    Returns {path: diff} for every file whose gofmt output differs from its
    current contents.
    """
    def run_gofmt(path, _contents):
        # Without -w, gofmt prints the formatted file instead of rewriting it.
        completed = log_run('gofmt', path, stdout=subprocess.PIPE, check=True)
        return completed.stdout

    return _check_files(files, run_gofmt)
def fix_go_format(files: Iterable[Path]) -> None:
    """Rewrites the given Go files in place using `gofmt -w`."""
    gofmt_command = ('gofmt', '-w', *files)
    log_run(*gofmt_command, check=True)
def _yapf(*args, **kwargs) -> subprocess.CompletedProcess:
    """Runs `python -m yapf` with the given arguments.

    Both stdout and stderr are captured, because check_py_format parses the
    diff from stdout and reports anything on stderr. Extra keyword arguments
    (e.g. check=True) are forwarded to log_run.
    """
    return log_run('python',
                   '-m',
                   'yapf',
                   *args,
                   stdout=subprocess.PIPE,
                   stderr=subprocess.PIPE,
                   **kwargs)
# Matches the "--- <path> (original)" header that opens each file's section of
# `yapf --diff` output; group 1 captures the path.
_DIFF_START = re.compile(r'^--- (.*)\s+\(original\)$', re.MULTILINE)
def check_py_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting.

    Runs yapf once over all files with --diff and splits its combined output
    into one colorized diff per file. If yapf writes anything to stderr, the
    error is logged and every file without a diff is also reported (with an
    empty diff), so the check fails instead of passing silently.
    """
    # files is used twice (as yapf arguments and in the stderr fallback), so
    # materialize it in case a one-shot iterator was passed.
    files = list(files)
    if not files:
        # With no file arguments yapf would read from stdin instead.
        return {}

    process = _yapf('--diff', *files)

    errors: Dict[Path, str] = {}

    if process.stdout:
        raw_diff = process.stdout.decode(errors='replace')

        # Each file's diff runs from its "(original)" header up to the next
        # header, or to the end of the output for the last file.
        matches = tuple(_DIFF_START.finditer(raw_diff))
        for start, end in zip(matches, (*matches[1:], None)):
            errors[Path(start.group(1))] = colorize_diff(
                raw_diff[start.start():end.start() if end else None])

    if process.stderr:
        _LOG.error('yapf encountered an error:\n%s',
                   process.stderr.decode(errors='replace').rstrip())
        errors.update({file: '' for file in files if file not in errors})

    return errors
def fix_py_format(files: Iterable):
    """Rewrites the given Python files in place using `yapf --in-place`."""
    yapf_args = ['--in-place', *files]
    _yapf(*yapf_args, check=True)
def print_format_check(
        errors: Dict[Path, str],
        show_fix_commands: bool,
) -> None:
    """Prints the result of a check_*_format function.

    Args:
      errors: {path: diff} as returned by a check_*_format function.
      show_fix_commands: if True, also log a `pw format --fix` command line
          for each file with errors.
    """
    if not errors:
        # Don't print anything in the all-good case.
        return

    # Show the format fixing diff suggested by the tooling (with colors).
    _LOG.warning('Found %d files with formatting errors. Format changes:',
                 len(errors))
    for diff in errors.values():
        print(diff, end='')

    # Show a copy-and-pastable command to fix the issues.
    if show_fix_commands:

        def path_relative_to_cwd(path):
            # relative_to raises ValueError for paths outside the current
            # directory; fall back to the absolute path for those.
            try:
                return Path(path).resolve().relative_to(Path.cwd().resolve())
            except ValueError:
                return Path(path).resolve()

        message = (f' pw format --fix {path_relative_to_cwd(path)}'
                   for path in errors)
        _LOG.warning('To fix formatting, run:\n\n%s\n', '\n'.join(message))
class CodeFormat(NamedTuple):
    """Describes how to check and fix formatting for one language."""
    # Human-readable language name; used in log messages and to derive the
    # presubmit check function's name.
    language: str
    # File name suffixes, matched against paths with str.endswith.
    extensions: Collection[str]
    # Returns {path: diff} for the given files whose formatting is wrong.
    check: Callable[[Iterable], Dict[Path, str]]
    # Reformats the given files in place.
    fix: Callable[[Iterable], None]
C_FORMAT: CodeFormat = CodeFormat(
    'C and C++', frozenset(['.h', '.hh', '.hpp', '.c', '.cc', '.cpp']),
    check_c_format, fix_c_format)

GN_FORMAT: CodeFormat = CodeFormat('GN', ('.gn', '.gni'), check_gn_format,
                                   fix_gn_format)

GO_FORMAT: CodeFormat = CodeFormat('Go', ('.go', ), check_go_format,
                                   fix_go_format)

PYTHON_FORMAT: CodeFormat = CodeFormat('Python', ('.py', ), check_py_format,
                                       fix_py_format)

# Every supported format. CodeFormatter matches files against these in order,
# and PRESUBMIT_CHECKS creates one presubmit check per entry.
CODE_FORMATS: Sequence[CodeFormat] = (
    C_FORMAT,
    GN_FORMAT,
    GO_FORMAT,
    PYTHON_FORMAT,
)
def presubmit_check(code_format: CodeFormat) -> Callable:
    """Creates a presubmit check function from a CodeFormat object.

    The returned function checks the context's paths that match
    code_format.extensions, prints any diffs together with fix commands, and
    raises pw_presubmit.PresubmitFailure if formatting errors were found. Its
    __name__ is '<language>_format' (e.g. 'C and C++' -> 'c_and_cpp_format').
    """
    def check_code_format(ctx: pw_presubmit.PresubmitContext):
        # Only hand the checker files of its own language (the same suffix
        # matching CodeFormatter uses). This is a no-op if ctx.paths was
        # already filtered.
        paths = [
            path for path in ctx.paths
            if any(str(path).endswith(e) for e in code_format.extensions)
        ]
        if not paths:
            # Nothing to check. Avoids invoking tools with no file arguments
            # (yapf, for one, would then read stdin).
            return

        errors = code_format.check(paths)
        # When running as part of presubmit, show the fix command help.
        print_format_check(errors, show_fix_commands=True)
        if errors:
            raise pw_presubmit.PresubmitFailure

    language = code_format.language.lower().replace('+', 'p').replace(' ', '_')
    check_code_format.__name__ = f'{language}_format'

    return check_code_format
# One presubmit check function per supported format, in CODE_FORMATS order.
PRESUBMIT_CHECKS: Sequence[Callable] = tuple(
    map(presubmit_check, CODE_FORMATS))
class CodeFormatter:
    """Checks or fixes the formatting of a set of files."""

    def __init__(self, files: Sequence[Path]):
        """Groups files by the CodeFormat(s) whose extensions they match.

        Files that match no entry in CODE_FORMATS stay in self.paths but are
        never checked or fixed.
        """
        self.paths = list(files)
        self._formats: Dict[CodeFormat, List] = collections.defaultdict(list)

        # Iterate the materialized list: if files is a one-shot iterator,
        # list() above has already exhausted it.
        for path in self.paths:
            for code_format in CODE_FORMATS:
                if any(str(path).endswith(e) for e in code_format.extensions):
                    self._formats[code_format].append(path)

    def check(self) -> Dict[Path, str]:
        """Returns {path: diff} for files with incorrect formatting.

        The result is ordered by path.
        """
        errors: Dict[Path, str] = {}

        for code_format, files in self._formats.items():
            _LOG.debug('Checking %s', ', '.join(str(f) for f in files))
            errors.update(code_format.check(files))

        return collections.OrderedDict(sorted(errors.items()))

    def fix(self) -> None:
        """Fixes format errors for supported files in place."""
        for code_format, files in self._formats.items():
            code_format.fix(files)
            _LOG.info('Formatted %s',
                      plural(files, code_format.language + ' file'))
def _file_summary(files: Iterable[Path], base: Path) -> List[str]:
    """Returns file_summary output for files, with paths relative to base.

    Returns an empty list if any file lies outside base, because relative_to
    raises ValueError for such paths.
    """
    root = base.resolve()
    try:
        return file_summary(f.resolve().relative_to(root) for f in files)
    except ValueError:
        return []
def main(paths: Sequence[Path], exclude, base: str, fix: bool) -> int:
    """Checks or fixes formatting for files in a Git repo.

    Args:
      paths: paths to check. Existing files are always included; inside a
          Git repo the paths are also passed to list_git_files.
      exclude: forwarded to list_git_files (presumably patterns of paths to
          skip; confirm against pw_presubmit).
      base: optional Git commit forwarded to list_git_files; an error if
          given outside a Git repo.
      fix: if True, apply formatting fixes in place instead of failing.

    Returns:
      0 if no formatting errors were found or they were fixed, 1 otherwise.
    """
    files = [path.resolve() for path in paths if path.is_file()]

    # Directory that summary paths are shown relative to: the repo root
    # inside a Git repo, the working directory otherwise. Computed once so
    # `repo` is never referenced when it was not assigned.
    summary_base = Path.cwd()

    # If this is a Git repo, list the original paths with git ls-files or diff.
    if pw_presubmit.is_git_repo():
        repo = pw_presubmit.git_repo_path()
        summary_base = repo
        if repo.samefile(Path.cwd()):
            _LOG.info('Checking files in the %s repository', repo)
        else:
            _LOG.info(
                'Checking files in the %s subdirectory of the %s repository',
                Path.cwd().relative_to(repo), repo)

        # Add files from Git and remove duplicates.
        files = sorted(set(list_git_files(base, paths, exclude)) | set(files))
    elif base:
        _LOG.error(
            'A base commit may only be provided if running from a Git repo')
        return 1

    formatter = CodeFormatter(files)

    _LOG.info('Checking formatting for %s', plural(formatter.paths, 'file'))
    _LOG.debug('Files to format:\n%s', '\n'.join(str(f) for f in files))

    for line in _file_summary(files, summary_base):
        print(line, file=sys.stderr)

    errors = formatter.check()
    print_format_check(errors, show_fix_commands=(not fix))

    if errors:
        if fix:
            formatter.fix()
            # TODO: This should perhaps check that the fixes were successful.
            _LOG.info('Formatting fixes applied successfully')
            return 0

        _LOG.error('Formatting errors found')
        return 1

    _LOG.info('Congratulations! No formatting changes needed')
    return 0
def argument_parser(parser=None) -> argparse.ArgumentParser:
    """Adds this tool's command-line arguments to an argument parser.

    Args:
      parser: the parser to extend. If None, a new parser described by this
          module's docstring is created.

    Returns:
      The parser. Its parsed arguments map one-to-one onto main()'s
      parameters (paths, exclude, base, fix).
    """
    if parser is None:
        parser = argparse.ArgumentParser(description=__doc__)

    # NOTE(review): the path-selection arguments are reconstructed from
    # main()'s signature. Confirm that list_git_files expects compiled
    # patterns for `exclude`.
    parser.add_argument('paths',
                        nargs='*',
                        type=Path,
                        help='Paths to check or format.')
    parser.add_argument(
        '-e',
        '--exclude',
        metavar='REGEX',
        default=[],
        action='append',
        type=re.compile,
        help='Exclude paths matching any of these regular expressions.')
    parser.add_argument(
        '-b',
        '--base',
        metavar='COMMIT',
        help='Git revision against which to diff for changed files.')
    parser.add_argument('--fix',
                        action='store_true',
                        help='Apply formatting fixes in place.')

    return parser
if __name__ == '__main__':
    logging.basicConfig(format='%(message)s', level=logging.INFO)
    # The parsed argument names match main()'s parameters one-to-one.
    sys.exit(main(**vars(argument_parser().parse_args())))