# blob: cd1f049904d13b3fedb18cc5bc52b35f41922f19 [file] [log] [blame]
# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Bazel-related functions."""
from __future__ import annotations
import dataclasses
import shlex
import re
from typing import Any, Sequence, TYPE_CHECKING
from PB.recipe_modules.pigweed.bazel.options import Options
from recipe_engine import recipe_api
if TYPE_CHECKING: # pragma: no cover
from recipe_engine import config_types
from RECIPE_MODULES.pigweed.checkout import api as checkout_api
@dataclasses.dataclass
class BazelRunner:
    """Runs Bazel commands inside a checkout.

    Commands come from two places: the explicit ``options.invocations`` and
    the named programs listed in the checkout's JSON config file
    (``options.config_path``, defaulting to ``pigweed.json``).
    """

    # Recipe API object used to reach other recipe modules (cipd, path, ...).
    api: recipe_api.RecipeApi
    # Checkout whose root is used as the working directory for Bazel.
    checkout: checkout_api.CheckoutContext
    # Module options (bazelisk_version, config_path, invocations, program).
    options: Options
    # Cached path to the bazelisk executable; filled in by ensure().
    _bazel: config_types.Path | None = None

    def _ensure_bazelisk(self) -> config_types.Path:
        """Fetch bazelisk from CIPD into a temp directory.

        Returns:
            Path to the ``bazelisk`` entry inside the freshly created
            directory.
        """
        ensure_file = self.api.cipd.EnsureFile()
        ensure_file.add_package(
            'fuchsia/third_party/bazelisk/${platform}',
            self.options.bazelisk_version or 'latest',
        )

        root = self.api.path.mkdtemp()
        self.api.cipd.ensure(root, ensure_file, name='ensure bazelisk')
        return root / 'bazelisk'

    def ensure(self) -> config_types.Path:
        """Return the bazelisk path, installing it on first use.

        The first call also runs a 'bazel version' step; later calls return
        the cached path without running any step.
        """
        if self._bazel:
            return self._bazel

        self._bazel = self._ensure_bazelisk()

        self.api.step('bazel version', [self._bazel, 'version'])
        return self._bazel

    def _override_args(self) -> list[str]:
        """Build one override flag per entry in checkout.bazel_overrides.

        Returns:
            A list of ``<flag>=<repo>=<path>`` strings, where the flag is
            ``--override_module`` if MODULE.bazel exists at the checkout
            root and ``--override_repository`` otherwise.
        """
        if self.api.path.exists(self.checkout.root / 'MODULE.bazel'):
            # We're in a bzlmod-managed workspace.
            flag = "--override_module"  # pragma: no cover
        else:
            # We're in a traditional workspace.
            flag = "--override_repository"

        return [
            f'{flag}={repo}={path}'
            for repo, path in self.checkout.bazel_overrides.items()
        ]

    def run(self, **kwargs) -> None:
        """Run all configured Bazel invocations and programs.

        Each command is launched through the deferral context obtained from
        ``self.api.defer.context()``, with the checkout root as the working
        directory.

        Args:
            **kwargs: Forwarded unchanged to every ``self.api.step`` call.
        """
        config_name = self.options.config_path or 'pigweed.json'
        config_path = self.checkout.root / config_name
        # NOTE(review): mock_add_file looks like a test-simulation hook that
        # makes the isfile() check below succeed under test -- confirm.
        self.api.path.mock_add_file(config_path)

        config = {}
        if self.api.path.isfile(config_path):
            config = self.api.file.read_json(
                f'read {config_name}',
                config_path,
                test_data={
                    'pw': {
                        'bazel_presubmit': {
                            'remote_cache': True,
                            'upload_local_results': True,
                            'programs': {
                                'default': [
                                    ['build', '//...'],
                                    ['test', '//...'],
                                ],
                            },
                        },
                    },
                },
            )
            # Unwrap the optional 'pw' and 'bazel_presubmit' layers; when a
            # layer is missing the whole file is used at that level instead.
            config = config.get('pw', config).get('bazel_presubmit', config)

        base_args: list[str] = []

        if config.get('remote_cache'):
            base_args.append('--config=remote_cache')

        if config.get('upload_local_results'):
            if self.api.buildbucket_util.is_tryjob:
                self.api.step.empty(
                    'ignoring upload_local_results because this is a tryjob'
                )
            elif not config.get('remote_cache'):
                self.api.step.empty(
                    'ignoring upload_local_results since remote_cache is False'
                )
            else:
                base_args.append('--remote_upload_local_results=true')

        base_args.extend(self._override_args())

        with (
            self.api.context(cwd=self.checkout.root),
            self.api.defer.context() as defer,
        ):
            for invocation in self.options.invocations:
                assert invocation.args

                name: str = ' '.join(['bazel'] + list(invocation.args))
                defer(
                    self.api.step,
                    name,
                    [self.ensure(), *invocation.args, *base_args],
                    **kwargs,
                )

            programs = config.get('programs', {})
            for program in self.options.program:
                with self.api.step.nest(program):
                    assert program in programs, f'{program} not in {programs}'
                    assert programs[program]
                    for args in programs[program]:
                        cmd = [self.ensure(), *args, *base_args]
                        defer(
                            self.api.step,
                            shlex.join(args),
                            cmd,
                            **kwargs,
                        )
def nwise(iterable, n):
    """Yield overlapping n-tuples of consecutive items from iterable.

    nwise('ABCDEFG', 3) → ABC BCD CDE DEF EFG

    Nothing is yielded when the iterable holds fewer than n items. See also
    https://docs.python.org/3/library/itertools.html#itertools.pairwise
    """
    source = iter(iterable)
    # Prime the window with a throwaway leading slot plus the first n - 1
    # items, so that the first shift below produces the first full window.
    window = (None, *(next(source, None) for _ in range(1, n)))
    for item in source:
        window = window[1:] + (item,)
        yield window
def trim_nearby_lines(lines: Sequence[LineProxy]) -> list[LineProxy]:
    """Cut a window of lines down to the blank-line-delimited part around its center.

    Args:
        lines: Window of lines whose middle element (index ``len // 2``) is
            the line of interest. A ``text`` of None is treated as blank.

    Returns:
        The slice of ``lines`` that ends just before the first blank line at
        or after the center and starts at the nearest blank line before the
        center (that blank line itself is kept).
    """
    # Record the center before trimming. Cutting the tail shortens the
    # sequence, so recomputing len(lines) // 2 afterwards would start the
    # backward scan too early and miss blank lines next to the center.
    center = len(lines) // 2

    # If there's a blank line in the second half, remove any following lines.
    for i in range(center, len(lines)):
        text = lines[i].text or ''
        if not text.strip():
            lines = lines[0:i]
            break

    # If there's a blank line in the first half, remove any preceding lines.
    # The start is clamped because the center itself may have been cut above.
    for i in range(min(center, len(lines) - 1), 0, -1):
        text = lines[i].text or ''
        if not text.strip():
            lines = lines[i:]
            break

    return lines
def proximity_sort_nearby_lines(
    lines: Sequence[LineProxy],
    center: int,
) -> list[LineProxy]:
    """Reorder lines so those closest to index ``center`` come first.

    Lines at the same distance are ordered by ascending ``idx``.
    """
    # Decorate each line with (distance, idx) so sorting goes center-out
    # instead of ascending, then strip the decoration again.
    decorated = sorted(
        (abs(center - line.idx), line.idx, line) for line in lines
    )
    return [line for _, _, line in decorated]
@dataclasses.dataclass
class UpdateCommitHashResult:
    """Outcome of a successful BazelApi.update_commit_hash() call."""

    # Commit hash that was replaced in the file.
    old_revision: str
    # Name(s) found next to the updated entry; update_commit_hash() joins
    # multiple names with ', '.
    project_name: str | None
class LineProxy:
    """Read/write handle for one position in a shared list of lines.

    The proxy stores the list and an index rather than the text itself, so
    reads always see the list's current content and writes go straight back
    into the list.
    """

    def __init__(self, lines, idx):
        self._idx = idx
        self._lines = lines

    @property
    def idx(self):
        """Index of this line within the underlying list."""
        return self._idx

    @property
    def text(self):
        """Current content of the underlying list at this index."""
        return self._lines[self._idx]

    @text.setter
    def text(self, value):
        self._lines[self._idx] = value

    def __repr__(self):
        return '{} {}'.format(self.idx, self.text)
def proxy(lines):
    """Return one LineProxy per position in ``lines``, in order."""
    return [LineProxy(lines, position) for position, _ in enumerate(lines)]
class BazelApi(recipe_api.RecipeApi):
    """Bazel utilities."""

    BazelRunner = BazelRunner
    UpdateCommitHashResult = UpdateCommitHashResult

    def new_runner(
        self,
        checkout: checkout_api.CheckoutContext,
        options: Options | None,
    ) -> BazelRunner:
        """Create a BazelRunner bound to this module's dependencies."""
        return BazelRunner(self.m, checkout, options)

    def _read(self, path: config_types.Path, num_nearby_lines: int):
        """Read ``path`` as a list of lines padded with blank lines.

        ``num_nearby_lines`` empty strings are added before and after the
        file content so a sliding window centered on any real line is always
        full-sized.
        """
        lines = [''] * num_nearby_lines
        lines.extend(
            self.m.file.read_text(
                f'read old {path.name}',
                path,
                test_data=self.m.bazel.test_api.TEST_WORKSPACE_FILE,
            )
            .strip()
            .splitlines()
        )
        lines.extend([''] * num_nearby_lines)
        return lines

    def _get_matching_groups(
        self,
        checkout: checkout_api.CheckoutContext,
        lines: Sequence[str],
        num_nearby_lines: int,
        project_remote: str,
        replace_remote: bool = False,
    ) -> list[tuple[LineProxy, list[LineProxy]]]:
        """Find ``remote = "..."`` lines equivalent to ``project_remote``.

        Args:
            checkout: Provides ``remotes_equivalent()`` for the comparison.
            lines: Padded file content as returned by ``_read()``.
            num_nearby_lines: Number of lines kept on each side of a match.
            project_remote: Remote URL being searched for.
            replace_remote: If true, an equivalent but differently spelled
                remote is rewritten in ``lines`` to ``project_remote``.

        Returns:
            One ``(remote_line, window)`` pair per equivalent remote, where
            ``window`` is the ``2 * num_nearby_lines + 1`` lines centered on
            the remote line.
        """
        matching_groups: list[tuple[LineProxy, list[LineProxy]]] = []

        for nearby_lines in nwise(proxy(lines), num_nearby_lines * 2 + 1):
            curr = nearby_lines[len(nearby_lines) // 2]
            match = re.search(
                r'^\s*remote\s*=\s*"(?P<remote>[^"]+)",?\s*$', curr.text
            )
            if not match:
                continue

            match_remote = match.group('remote')

            if checkout.remotes_equivalent(match_remote, project_remote):
                pres = self.m.step.empty(
                    f'found equivalent remote {match_remote!r}'
                ).presentation
                pres.logs['lines'] = [repr(x) for x in nearby_lines]
                matching_groups.append((curr, nearby_lines))

                if replace_remote and match_remote != project_remote:
                    curr.text = curr.text.replace(match_remote, project_remote)

            else:
                pres = self.m.step.empty(
                    f'found other remote {match_remote!r}'
                ).presentation
                pres.logs['lines'] = [repr(x) for x in nearby_lines]

        return matching_groups

    def _process_nearby_lines(self, nearby_lines):
        """Trim a window to its entry and order it nearest-to-center first.

        Emits a 'lines' step whose logs record the center index and the
        window before trimming, after trimming, and after sorting.
        """
        pres = self.m.step.empty('lines').presentation
        center = nearby_lines[len(nearby_lines) // 2].idx

        pres.logs['0_center'] = [str(center)]
        pres.logs['1_orig'] = [repr(x) for x in nearby_lines]

        nearby_lines = trim_nearby_lines(nearby_lines)
        pres.logs['2_trimmed'] = [repr(x) for x in nearby_lines]

        nearby_lines = proximity_sort_nearby_lines(nearby_lines, center)
        pres.logs['3_sorted'] = [repr(x) for x in nearby_lines]

        return nearby_lines

    def retrieve_git_repository_attributes(
        self,
        checkout: checkout_api.CheckoutContext,
        project_remote: str,
        num_nearby_lines: int = 10,
        path: config_types.Path | None = None,
    ) -> list[dict[str, str]]:
        """Collect ``key = value,`` attributes near each matching remote.

        Args:
            checkout: Checkout providing the root and remote comparison.
            project_remote: Remote URL identifying the entries of interest.
            num_nearby_lines: Number of lines examined on each side of the
                remote line.
            path: File to read; defaults to ``WORKSPACE`` at the checkout
                root.

        Returns:
            One dict per matching entry that has at least one attribute.
            Surrounding double quotes are stripped from values, and a
            ``module_name`` attribute also fills ``name`` when no ``name``
            has been recorded yet.
        """
        if not path:
            path = checkout.root / 'WORKSPACE'

        lines = self._read(path, num_nearby_lines)

        matching_groups = self._get_matching_groups(
            checkout=checkout,
            lines=lines,
            num_nearby_lines=num_nearby_lines,
            project_remote=project_remote,
        )

        results: list[dict[str, str]] = []

        for _, nearby_lines in matching_groups:
            entry = {}

            nearby_lines = self._process_nearby_lines(nearby_lines)

            for line in nearby_lines:
                if match := re.search(
                    r'^\s*([\w_]+)\s*=\s*(\S.*),\s*$',
                    line.text,
                ):
                    entry[match.group(1)] = match.group(2).strip('"')
                    if match.group(1) == 'module_name' and 'name' not in entry:
                        entry['name'] = match.group(2).strip('"')

            if entry:
                results.append(entry)

        return results

    def update_commit_hash(
        self,
        *,
        checkout: checkout_api.CheckoutContext,
        project_remote: str,
        new_revision: str,
        num_nearby_lines: int = 10,
        path: config_types.Path | None = None,
        replace_remote: bool = False,
    ) -> UpdateCommitHashResult | None:
        """Rewrite the commit hash of every entry that uses ``project_remote``.

        For each matching entry the nearest ``commit = "<40 hex chars>"``
        line is set to ``new_revision``, any ``# ROLL: `` comment lines
        directly above it are dropped, and fresh ``# ROLL: `` comments with
        the date and build URL are inserted. The file is written back only
        if every matching entry was updated.

        Args:
            checkout: Checkout providing the root and remote comparison.
            project_remote: Remote URL identifying the entries to update.
            new_revision: Commit hash to write.
            num_nearby_lines: Number of lines examined on each side of the
                remote line.
            path: File to update; defaults to ``WORKSPACE`` at the checkout
                root.
            replace_remote: If true, equivalent remotes are rewritten to
                ``project_remote``.

        Returns:
            None if no matching remote, or no commit line next to one, was
            found. Otherwise the previous revision (from the last entry
            processed) and the comma-joined names of the updated entries.
        """
        if not path:
            path = checkout.root / 'WORKSPACE'

        lines = self._read(path, num_nearby_lines)

        matching_groups = self._get_matching_groups(
            checkout=checkout,
            lines=lines,
            num_nearby_lines=num_nearby_lines,
            project_remote=project_remote,
            replace_remote=replace_remote,
        )

        if not matching_groups:
            self.m.step.empty(
                f'could not find remote {project_remote} in {path}',
            )
            return None

        project_names: list[str] = []

        # Loop-invariant, so compiled once instead of once per entry.
        commit_rx = re.compile(
            r'^(?P<prefix>\s*commit\s*=\s*")'
            r'(?P<commit>[0-9a-f]{40})'
            r'(?P<suffix>",?\s*)$'
        )

        for matching_line, nearby_lines in matching_groups:
            nearby_lines = self._process_nearby_lines(nearby_lines)

            for line in nearby_lines:
                # A line may already be None if an earlier entry removed it
                # as stale metadata; treat it as empty rather than crashing.
                if match := commit_rx.search(line.text or ''):
                    idx = line.idx
                    break
            else:
                self.m.step.empty(
                    'could not find commit line adjacent to '
                    f'{matching_line.text!r} in {path}',
                )
                return None

            old_revision = match.group('commit')

            prefix = match.group("prefix")
            suffix = match.group("suffix")
            lines[idx] = f'{prefix}{new_revision}{suffix}'

            # Remove all existing metadata lines in this git_repository() entry.
            # The scan is bounded at 0 so it cannot wrap around to the end of
            # the list, and tolerates slots already cleared to None.
            idx2 = idx - 1
            while idx2 >= 0 and (lines[idx2] or '').strip().startswith(
                '# ROLL: '
            ):
                lines[idx2] = None
                idx2 -= 1

            ws_prefix = re.search(r'^\s*', prefix).group(0)
            comment_prefix = f'{ws_prefix}# ROLL: '

            now = self.m.time.utcnow().strftime('%Y-%m-%d')
            comment_lines = (
                f'{comment_prefix}Warning: this entry is automatically '
                'updated.',
                f'{comment_prefix}Last updated {now}.',
                f'{comment_prefix}By {self.m.buildbucket.build_url()}.',
            )

            # The comments share the commit line's slot (joined by newlines)
            # so the indices of all other lines stay valid.
            lines[idx] = '\n'.join(comment_lines + (lines[idx],))

            for line in nearby_lines:
                if match := re.search(
                    r'^\s*(?:module_)?name\s*=\s*"(?P<name>[^"]+)",?\s*$',
                    line.text or '',
                ):
                    project_names.append(match.group('name'))
                    break

        # Drop the padding added by _read(). An explicit end index is used
        # because lines[n:-n] would be empty when n == 0.
        content_end = len(lines) - num_nearby_lines
        self.m.file.write_text(
            f'write new {path.name}',
            path,
            ''.join(
                f'{x}\n'
                for x in lines[num_nearby_lines:content_end]
                if x is not None
            ),
        )

        return UpdateCommitHashResult(
            old_revision=old_revision,
            project_name=', '.join(project_names),
        )