blob: 2b9d573309b8ae2de39536ade5f556039b907c41 [file] [log] [blame]
# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
from __future__ import annotations
import configparser
import dataclasses
import io
import re
from typing import Callable, Generator, Sequence, TYPE_CHECKING
from PB.recipe_modules.pigweed.bazel.options import Options as BazelOptions
from recipe_engine import recipe_api
if TYPE_CHECKING: # pragma: no cover
from PB.recipe_modules.pigweed.bazel_roll.git_repository import (
GitRepository,
)
from recipe_engine import config_types
from RECIPE_MODULES.fuchsia.git_roll_util import api as git_roll_util_api
from RECIPE_MODULES.pigweed.checkout import api as checkout_api
def nwise(iterable, n):
    """Yield every run of ``n`` consecutive items of ``iterable`` as a tuple.

    nwise('ABCDEFG', 3) → ABC BCD CDE DEF EFG

    Nothing is yielded when ``iterable`` has fewer than ``n`` items.

    See also
    https://docs.python.org/3/library/itertools.html#itertools.pairwise
    """
    source = iter(iterable)

    # Seed the window with one placeholder followed by the first n - 1 items.
    # The placeholder is shifted out before the first window is yielded, so it
    # never reaches the caller.
    window = (None, *(next(source, None) for _ in range(1, n)))

    for item in source:
        window = window[1:] + (item,)
        yield window
def trim_nearby_lines(lines: Sequence[LineProxy]) -> list[LineProxy]:
    """Trim a window of lines to the blank-line-delimited run around its center.

    The center of the window is the item at index ``len(lines) // 2``. A line
    counts as blank when its ``text`` is ``None``, empty, or only whitespace.

    Args:
        lines: Window of line objects exposing a ``text`` attribute.

    Returns:
        A slice of ``lines``. Everything from the first blank line at or after
        the center is dropped. Everything before the closest blank line
        preceding the center is dropped too; that blank line itself is kept.
        The item at index 0 is never treated as a cut point.
    """
    # Capture the center before trimming. Cutting the tail shortens the
    # window, so recomputing len(lines) // 2 afterwards would start the
    # backwards scan too early and miss blank lines just before the center.
    center = len(lines) // 2

    # If there's a blank line in the second half, remove any following lines.
    for i in range(center, len(lines)):
        text = lines[i].text or ''
        if not text.strip():
            lines = lines[0:i]
            break

    # If there's a blank line in the first half, remove any preceding lines.
    # Clamp the start in case the center itself was blank and got trimmed off.
    for i in range(min(center, len(lines) - 1), 0, -1):
        text = lines[i].text or ''
        if not text.strip():
            lines = lines[i:]
            break

    return lines
def proximity_sort_nearby_lines(
    lines: Sequence[LineProxy],
    center: int,
) -> list[LineProxy]:
    """Return ``lines`` reordered from the center outwards.

    Args:
        lines: Line objects exposing an ``idx`` attribute.
        center: Index that the ordering radiates out from.

    Returns:
        A new list ordered by distance of ``idx`` from ``center``. Lines at
        the same distance are ordered by ascending ``idx``.
    """
    # Decorate each line with its sort key, sort, then strip the key again.
    decorated = []
    for line in lines:
        distance = abs(center - line.idx)
        decorated.append((distance, line.idx, line))
    decorated.sort()
    return [entry[-1] for entry in decorated]
@dataclasses.dataclass
class UpdateCommitHashResult:
    """Outcome of rewriting a pinned commit hash in a Bazel file."""

    # Commit hash that was pinned before the update.
    old_revision: str

    # Name(s) found next to the updated entry; several names are joined
    # with ', ' by the producer of this result.
    project_name: str | None

    # Zero-argument callback that writes the pending changes to disk. It is
    # only populated when the write was deferred; otherwise it stays None.
    finalize: Callable[[], None] | None = dataclasses.field(default=None)
class LineProxy:
    """Read/write handle on one position of a shared list of lines.

    The proxy stores the backing list and an index rather than a copy of the
    text, so reads see later changes to the list and writes go straight into
    it.
    """

    def __init__(self, lines: list[str | None], idx: int) -> None:
        self._idx = idx
        self._lines = lines

    @property
    def idx(self) -> int:
        """Position of this line within the backing list."""
        return self._idx

    @property
    def text(self) -> str | None:
        """Current content of the backing list at ``idx``."""
        position = self._idx
        return self._lines[position]

    @text.setter
    def text(self, value: str | None) -> None:
        position = self._idx
        self._lines[position] = value

    def __repr__(self) -> str:
        # Rendered as '<idx> <text>'; this form is what ends up in step logs.
        return '{} {}'.format(self.idx, self.text)
def proxy(lines):
    """Return one ``LineProxy`` per position of ``lines``, in order.

    Every proxy shares ``lines`` as its backing list.
    """
    return [LineProxy(lines, position) for position, _ in enumerate(lines)]
class BazelRollApi(recipe_api.RecipeApi):
    """Recipe API for rolling git_repository() commit pins in Bazel files."""

    # Re-exported so callers can reach the result type through the module.
    UpdateCommitHashResult = UpdateCommitHashResult

    def workspace_path(
        self,
        root: config_types.Path,
        value: str,
    ) -> config_types.Path:
        """Figure out the location of the WORKSPACE or MODULE.bazel file.

        If value is '', look for root / 'WORKSPACE' and root / 'MODULE.bazel'.
        If exactly one exists, return it. If not, error out.

        If root / value is a file, return it.

        If root / value is a directory, set root = root / value and apply the
        above logic. This enables applying this logic to subdirectories.

        Args:
            root: Checkout root.
            value: Relative path specified in properties.

        Returns:
            Path to the WORKSPACE or MODULE.bazel file.
        """
        if value:
            value_path = root / value
            # NOTE(review): mock_add_file looks like a test-only helper that
            # marks the path as existing for simulation — confirm.
            self.m.path.mock_add_file(value_path)
            if self.m.path.isfile(value_path):
                return value_path
            elif self.m.path.isdir(value_path):  # pragma: no cover
                root = value_path
            else:
                self.m.step.empty(  # pragma: no cover
                    f'{value_path} does not exist',
                    status='FAILURE',
                )

        workspace = root / 'WORKSPACE'
        module_bazel = root / 'MODULE.bazel'

        self.m.path.mock_add_file(workspace)

        if self.m.path.isfile(module_bazel) and self.m.path.isfile(workspace):
            self.m.step.empty(  # pragma: no cover
                f'{module_bazel} and {workspace} both exist',
                status='FAILURE',
            )

        if self.m.path.isfile(module_bazel):
            return module_bazel  # pragma: no cover

        if self.m.path.isfile(workspace):
            return workspace

        self.m.step.empty(  # pragma: no cover
            'no WORKSPACE or MODULE.bazel file found',
            status='FAILURE',
        )

    def update_git_repository(
        self,
        checkout: checkout_api.CheckoutContext,
        git_repository: GitRepository,
    ) -> list[git_roll_util_api.Roll]:
        """Roll one git_repository() entry to the tip of its branch.

        Resolves the new revision, rewrites the pinned commit, and then runs
        'bazelisk mod deps --lockfile_mode=update' from the checkout root.

        Args:
            checkout: Checkout context; supplies the root path and the
                remotes_equivalent comparison.
            git_repository: Entry to roll. Its branch defaults to 'main' and
                is written back onto the message.

        Returns:
            A one-element list holding the roll, or an empty list if
            git_roll_util reports a backwards roll. In that case the file is
            left unwritten.
        """
        project_name = git_repository.name
        if not project_name:
            # Fall back to the last component of the remote URL.
            project_name = git_repository.remote.split('/')[-1]
            project_name = project_name.removesuffix('.git')

        with self.m.step.nest(project_name):
            workspace_path = self.workspace_path(
                checkout.root,
                git_repository.workspace_path,
            )

            git_repository.branch = git_repository.branch or 'main'

            new_revision = self.m.git_roll_util.resolve_new_revision(
                git_repository.remote,
                git_repository.branch,
                checkout.remotes_equivalent,
            )

            # Defer the write so that a backwards roll leaves the file as is.
            update_result = self.update_commit_hash(
                checkout=checkout,
                project_remote=git_repository.remote,
                new_revision=new_revision,
                path=workspace_path,
                delay_write=True,
            )
            if not update_result:
                # NOTE(review): presumably a FAILURE step raises here;
                # otherwise the attribute access below would hit None.
                self.m.step.empty(
                    'failed to update commit hash',
                    status='FAILURE',
                )

            try:
                roll = self.m.git_roll_util.get_roll(
                    repo_url=git_repository.remote,
                    repo_short_name=project_name,
                    old_rev=update_result.old_revision,
                    new_rev=new_revision,
                )
            except self.m.git_roll_util.BackwardsRollError:
                return []

            update_result.finalize()

            with self.m.step.nest('update MODULE.bazel.lock'):
                runner = self.m.bazel.new_runner(
                    checkout=checkout,
                    options=BazelOptions(),
                )

                with self.m.context(cwd=checkout.root):
                    self.m.step(
                        'bazelisk mod deps --lockfile_mode=update',
                        [
                            runner.ensure(),
                            'mod',
                            'deps',
                            '--lockfile_mode=update',
                        ],
                    )

            return [roll]

    def _read(self, path: config_types.Path, num_nearby_lines: int):
        """Read ``path`` as a list of lines padded with blank lines.

        The stripped file content is split into lines, with
        ``num_nearby_lines`` empty strings added before and after. The
        padding lets every real line sit at the center of a full-width
        window.
        """
        lines = [''] * num_nearby_lines
        lines.extend(
            self.m.file.read_text(
                f'read old {path.name}',
                path,
                test_data=self.m.bazel_roll.test_api.TEST_WORKSPACE_FILE,
            )
            .strip()
            .splitlines()
        )
        lines.extend([''] * num_nearby_lines)
        return lines

    def _get_matching_groups(
        self,
        checkout: checkout_api.CheckoutContext,
        lines: Sequence[str],
        num_nearby_lines: int,
        project_remote: str,
        replace_remote: bool = False,
    ) -> list[tuple[LineProxy, list[LineProxy]]]:
        """Find 'remote = "..."' lines that refer to ``project_remote``.

        Args:
            checkout: Supplies remotes_equivalent for comparing URLs.
            lines: Padded file content; modified in place when
                ``replace_remote`` is set.
            num_nearby_lines: Number of lines of context on each side.
            project_remote: Remote URL being looked for.
            replace_remote: If true, rewrite an equivalent but differently
                spelled remote to ``project_remote``.

        Returns:
            One (remote line, surrounding window) pair per equivalent remote.
        """
        matching_groups: list[tuple[LineProxy, list[LineProxy]]] = []

        for nearby_lines in nwise(proxy(lines), num_nearby_lines * 2 + 1):
            # The line under inspection is the middle one of the window.
            curr = nearby_lines[len(nearby_lines) // 2]
            match = re.search(
                r'^\s*remote\s*=\s*"(?P<remote>[^"]+)",?\s*$', curr.text
            )
            if not match:
                continue

            match_remote = match.group('remote')

            if checkout.remotes_equivalent(match_remote, project_remote):
                pres = self.m.step.empty(
                    f'found equivalent remote {match_remote!r}'
                ).presentation
                pres.logs['lines'] = [repr(x) for x in nearby_lines]
                matching_groups.append((curr, nearby_lines))

                if replace_remote and match_remote != project_remote:
                    curr.text = curr.text.replace(match_remote, project_remote)

            else:
                pres = self.m.step.empty(
                    f'found other remote {match_remote!r}'
                ).presentation
                pres.logs['lines'] = [repr(x) for x in nearby_lines]

        return matching_groups

    def _process_nearby_lines(self, nearby_lines):
        """Trim a window to its entry and order it center-out.

        Each intermediate stage is logged on a 'lines' step.
        """
        pres = self.m.step.empty('lines').presentation
        center = nearby_lines[len(nearby_lines) // 2].idx
        pres.logs['0_center'] = [str(center)]
        pres.logs['1_orig'] = [repr(x) for x in nearby_lines]

        nearby_lines = trim_nearby_lines(nearby_lines)
        pres.logs['2_trimmed'] = [repr(x) for x in nearby_lines]

        nearby_lines = proximity_sort_nearby_lines(nearby_lines, center)
        pres.logs['3_sorted'] = [repr(x) for x in nearby_lines]

        return nearby_lines

    def retrieve_git_repository_attributes(
        self,
        checkout: checkout_api.CheckoutContext,
        project_remote: str,
        num_nearby_lines: int = 10,
        path: config_types.Path | None = None,
    ) -> list[dict[str, str]]:
        """Collect the 'key = value,' attributes around each matching remote.

        Args:
            checkout: Supplies the root path and remotes_equivalent.
            project_remote: Remote URL identifying the entries of interest.
            num_nearby_lines: Number of lines of context on each side.
            path: File to read; defaults to checkout.root / 'WORKSPACE'.

        Returns:
            One dict per matching entry that had at least one attribute.
            Values have surrounding double quotes stripped. When no 'name'
            has been recorded yet, 'module_name' is also stored as 'name'.
        """
        if not path:
            path = checkout.root / 'WORKSPACE'

        lines = self._read(path, num_nearby_lines)

        matching_groups = self._get_matching_groups(
            checkout=checkout,
            lines=lines,
            num_nearby_lines=num_nearby_lines,
            project_remote=project_remote,
        )

        results: list[dict[str, str]] = []

        for _, nearby_lines in matching_groups:
            entry = {}

            nearby_lines = self._process_nearby_lines(nearby_lines)

            for line in nearby_lines:
                if match := re.search(
                    r'^\s*([\w_]+)\s*=\s*(\S.*),\s*$',
                    line.text,
                ):
                    entry[match.group(1)] = match.group(2).strip('"')
                    if match.group(1) == 'module_name' and 'name' not in entry:
                        entry['name'] = match.group(2).strip('"')

            if entry:
                results.append(entry)

        return results

    def update_commit_hash(
        self,
        *,
        checkout: checkout_api.CheckoutContext,
        project_remote: str,
        new_revision: str,
        num_nearby_lines: int = 10,
        path: config_types.Path | None,
        replace_remote: bool = False,
        delay_write: bool = False,
    ) -> UpdateCommitHashResult | None:
        """Point every entry for ``project_remote`` at ``new_revision``.

        For each matching remote, the nearest 'commit = "<40 hex>"' line is
        rewritten. Existing '# ROLL: ' comment lines directly above it are
        dropped and a fresh set of '# ROLL: ' comments is inserted.

        Args:
            checkout: Supplies the root path and remotes_equivalent.
            project_remote: Remote URL identifying the entries to update.
            new_revision: Commit hash to pin.
            num_nearby_lines: Number of lines of context on each side.
            path: File to edit; a falsy value means
                checkout.root / 'WORKSPACE'.
            replace_remote: If true, also rewrite equivalent remotes to
                ``project_remote``.
            delay_write: If true, do not write the file; the caller must
                invoke ``finalize`` on the result instead.

        Returns:
            None if no matching remote was found, or if any match had no
            commit line nearby. Otherwise a result whose ``old_revision`` is
            the hash replaced in the last processed entry and whose
            ``project_name`` is the ', '-joined names found.
        """
        if not path:
            path = checkout.root / 'WORKSPACE'

        lines = self._read(path, num_nearby_lines)

        matching_groups = self._get_matching_groups(
            checkout=checkout,
            lines=lines,
            num_nearby_lines=num_nearby_lines,
            project_remote=project_remote,
            replace_remote=replace_remote,
        )

        if not matching_groups:
            self.m.step.empty(
                f'could not find remote {project_remote} in {path}',
            )
            return None

        # Loop-invariant, so compiled once up front.
        commit_rx = re.compile(
            r'^(?P<prefix>\s*commit\s*=\s*")'
            r'(?P<commit>[0-9a-f]{40})'
            r'(?P<suffix>",?\s*)$'
        )

        project_names: list[str] = []

        for matching_line, nearby_lines in matching_groups:
            nearby_lines = self._process_nearby_lines(nearby_lines)

            for line in nearby_lines:
                # A line may be None if an earlier group removed it.
                if match := commit_rx.search(line.text or ''):
                    idx = line.idx
                    break
            else:
                self.m.step.empty(
                    'could not find commit line adjacent to '
                    f'{matching_line.text!r} in {path}',
                )
                return None

            old_revision = match.group('commit')

            prefix = match.group("prefix")
            suffix = match.group("suffix")
            lines[idx] = f'{prefix}{new_revision}{suffix}'

            # Remove all existing metadata lines in this git_repository() entry.
            # Removed lines become None and are skipped when writing.
            idx2 = idx - 1
            while idx2 >= 0 and (lines[idx2] or '').strip().startswith(
                '# ROLL: '
            ):
                lines[idx2] = None
                idx2 -= 1

            # Indent the new comments to match the commit line.
            ws_prefix = re.search(r'^\s*', prefix).group(0)
            comment_prefix = f'{ws_prefix}# ROLL: '

            now = self.m.time.utcnow().strftime('%Y-%m-%d')
            comment_lines = (
                f'{comment_prefix}Warning: this entry is automatically '
                'updated.',
                f'{comment_prefix}Last updated {now}.',
                f'{comment_prefix}By {self.m.buildbucket.build_url()}.',
            )

            # Keep the comments in the commit line's own slot so that the
            # indices held by other proxies stay valid.
            lines[idx] = '\n'.join(comment_lines + (lines[idx],))

            for line in nearby_lines:
                if match := re.search(
                    r'^\s*(?:module_)?name\s*=\s*"(?P<name>[^"]+)",?\s*$',
                    line.text or '',
                ):
                    project_names.append(match.group('name'))
                    break

        def write():
            # Drop the padding added by _read. The end index is explicit
            # because a '-0' slice bound would select nothing.
            body = lines[num_nearby_lines : len(lines) - num_nearby_lines]
            self.m.file.write_text(
                f'write new {path.name}',
                path,
                ''.join(f'{x}\n' for x in body if x is not None),
            )

        result = UpdateCommitHashResult(
            old_revision=old_revision,
            project_name=', '.join(project_names),
        )

        if delay_write:
            result.finalize = write
        else:
            write()

        return result