refactor(gazelle): Generate a modules map per wheel, then merge (#3415)
This change internally splits modules mapping generation to be
per-wheel, with a final quick "merge" action at the end.
The idea is to make this process both concurrent and cacheable (courtesy
of Bazel), which can be ideal for codebases with a large set of
requirements (as many monorepos end up having).
Note that the `generator.py` interface changed. This seemed internal, so
I didn't mark the change as breaking. (Alternatively, this change could
leave the generator untouched, since its current implementation already
handles the single-wheel case.)
I ran this on a work repository and saw no change in the generated
output. Since only the wheel for the single requirement I edited was
reprocessed, the overall run was fast ⚡.
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5a71e5f..7d3723c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -62,6 +62,7 @@
{#v0-0-0-changed}
### Changed
* (toolchains) Use toolchains from the [20251031] release.
+* (gazelle) Internally split modules mapping generation to be per-wheel for concurrency and caching.
{#v0-0-0-fixed}
### Fixed
diff --git a/gazelle/modules_mapping/BUILD.bazel b/gazelle/modules_mapping/BUILD.bazel
index 3a9a8a4..3423f34 100644
--- a/gazelle/modules_mapping/BUILD.bazel
+++ b/gazelle/modules_mapping/BUILD.bazel
@@ -9,6 +9,12 @@
visibility = ["//visibility:public"],
)
+py_binary(
+ name = "merger",
+ srcs = ["merger.py"],
+ visibility = ["//visibility:public"],
+)
+
copy_file(
name = "pytest_wheel",
src = "@pytest//file",
@@ -33,6 +39,18 @@
deps = [":generator"],
)
+py_test(
+ name = "test_merger",
+ srcs = ["test_merger.py"],
+ data = [
+ "django_types_wheel",
+ "pytest_wheel",
+ ],
+ imports = ["."],
+ main = "test_merger.py",
+ deps = [":merger"],
+)
+
filegroup(
name = "distribution",
srcs = glob(["**"]),
diff --git a/gazelle/modules_mapping/def.bzl b/gazelle/modules_mapping/def.bzl
index 48a5477..74d3c9e 100644
--- a/gazelle/modules_mapping/def.bzl
+++ b/gazelle/modules_mapping/def.bzl
@@ -30,25 +30,39 @@
transitive = [dep[DefaultInfo].files for dep in ctx.attr.wheels] + [dep[DefaultInfo].data_runfiles.files for dep in ctx.attr.wheels],
)
- args = ctx.actions.args()
+    # Run the generator once per wheel, so Bazel can cache and parallelize the actions.
+ per_wheel_outputs = []
+ for idx, whl in enumerate(all_wheels.to_list()):
+ wheel_modules_mapping = ctx.actions.declare_file("{}.{}".format(modules_mapping.short_path, idx))
+ args = ctx.actions.args()
+ args.add("--output_file", wheel_modules_mapping.path)
+ if ctx.attr.include_stub_packages:
+ args.add("--include_stub_packages")
+ args.add_all("--exclude_patterns", ctx.attr.exclude_patterns)
+ args.add("--wheel", whl.path)
- # Spill parameters to a file prefixed with '@'. Note, the '@' prefix is the same
- # prefix as used in the `generator.py` in `fromfile_prefix_chars` attribute.
- args.use_param_file(param_file_arg = "@%s")
- args.set_param_file_format(format = "multiline")
- if ctx.attr.include_stub_packages:
- args.add("--include_stub_packages")
- args.add("--output_file", modules_mapping)
- args.add_all("--exclude_patterns", ctx.attr.exclude_patterns)
- args.add_all("--wheels", all_wheels)
+ ctx.actions.run(
+ inputs = [whl],
+ outputs = [wheel_modules_mapping],
+ executable = ctx.executable._generator,
+ arguments = [args],
+ use_default_shell_env = False,
+ )
+ per_wheel_outputs.append(wheel_modules_mapping)
+
+    # Merge the per-wheel JSON mappings into the final modules_mapping.json.
+ merge_args = ctx.actions.args()
+ merge_args.add("--output", modules_mapping.path)
+ merge_args.add_all("--inputs", [f.path for f in per_wheel_outputs])
ctx.actions.run(
- inputs = all_wheels,
+ inputs = per_wheel_outputs,
outputs = [modules_mapping],
- executable = ctx.executable._generator,
- arguments = [args],
+ executable = ctx.executable._merger,
+ arguments = [merge_args],
use_default_shell_env = False,
)
+
return [DefaultInfo(files = depset([modules_mapping]))]
modules_mapping = rule(
@@ -79,6 +93,11 @@
default = "//modules_mapping:generator",
executable = True,
),
+ "_merger": attr.label(
+ cfg = "exec",
+ default = "//modules_mapping:merger",
+ executable = True,
+ ),
},
doc = "Creates a modules_mapping.json file for mapping module names to wheel distribution names.",
)
diff --git a/gazelle/modules_mapping/generator.py b/gazelle/modules_mapping/generator.py
index ea11f3e..611910c 100644
--- a/gazelle/modules_mapping/generator.py
+++ b/gazelle/modules_mapping/generator.py
@@ -96,8 +96,7 @@
ext = "".join(pathlib.Path(root).suffixes)
module = root[: -len(ext)].replace("/", ".")
if not self.is_excluded(module):
- if not self.is_excluded(module):
- self.mapping[module] = wheel_name
+ self.mapping[module] = wheel_name
def is_excluded(self, module):
for pattern in self.excluded_patterns:
@@ -105,14 +104,20 @@
return True
return False
- # run is the entrypoint for the generator.
- def run(self, wheels):
- for whl in wheels:
- try:
- self.dig_wheel(whl)
- except AssertionError as error:
- print(error, file=self.stderr)
- return 1
+ def run(self, wheel: pathlib.Path) -> int:
+ """
+ Entrypoint for the generator.
+
+ Args:
+ wheel: The path to the wheel file (`.whl`)
+ Returns:
+ Exit code (for `sys.exit`)
+ """
+ try:
+ self.dig_wheel(wheel)
+ except AssertionError as error:
+ print(error, file=self.stderr)
+ return 1
self.simplify()
mapping_json = json.dumps(self.mapping)
with open(self.output_file, "w") as f:
@@ -152,16 +157,13 @@
parser = argparse.ArgumentParser(
prog="generator",
description="Generates the modules mapping used by the Gazelle manifest.",
- # Automatically read parameters from a file. Note, the '@' is the same prefix
- # as set in the 'args.use_param_file' in the bazel rule.
- fromfile_prefix_chars="@",
)
parser.add_argument("--output_file", type=str)
parser.add_argument("--include_stub_packages", action="store_true")
parser.add_argument("--exclude_patterns", nargs="+", default=[])
- parser.add_argument("--wheels", nargs="+", default=[])
+ parser.add_argument("--wheel", type=pathlib.Path)
args = parser.parse_args()
generator = Generator(
sys.stderr, args.output_file, args.exclude_patterns, args.include_stub_packages
)
- sys.exit(generator.run(args.wheels))
+ sys.exit(generator.run(args.wheel))
diff --git a/gazelle/modules_mapping/merger.py b/gazelle/modules_mapping/merger.py
new file mode 100644
index 0000000..deb0cb2
--- /dev/null
+++ b/gazelle/modules_mapping/merger.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+"""Merges multiple modules_mapping.json files into a single file."""
+
+import argparse
+import json
+from pathlib import Path
+
+
+def merge_modules_mappings(input_files: list[Path], output_file: Path) -> None:
+ """Merge multiple modules_mapping.json files into one.
+
+ Args:
+ input_files: List of paths to input JSON files to merge
+ output_file: Path where the merged output should be written
+ """
+ merged_mapping = {}
+ for input_file in input_files:
+ mapping = json.loads(input_file.read_text())
+        # Merge the mappings; on conflicting module keys, entries from
+        # later input files overwrite those from earlier ones.
+ merged_mapping.update(mapping)
+
+ output_file.write_text(json.dumps(merged_mapping))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Merge multiple modules_mapping.json files"
+ )
+ parser.add_argument(
+ "--output",
+ required=True,
+ type=Path,
+ help="Output file path for merged mapping",
+ )
+ parser.add_argument(
+ "--inputs",
+ required=True,
+ nargs="+",
+ type=Path,
+ help="Input JSON files to merge",
+ )
+
+ args = parser.parse_args()
+ merge_modules_mappings(args.inputs, args.output)
diff --git a/gazelle/modules_mapping/test_merger.py b/gazelle/modules_mapping/test_merger.py
new file mode 100644
index 0000000..6260fdd
--- /dev/null
+++ b/gazelle/modules_mapping/test_merger.py
@@ -0,0 +1,61 @@
+import pathlib
+import unittest
+import json
+import tempfile
+
+from merger import merge_modules_mappings
+
+
+class MergerTest(unittest.TestCase):
+ _tmpdir: tempfile.TemporaryDirectory
+
+ def setUp(self) -> None:
+ super().setUp()
+ self._tmpdir = tempfile.TemporaryDirectory()
+
+ def tearDown(self) -> None:
+ super().tearDown()
+ self._tmpdir.cleanup()
+ del self._tmpdir
+
+ @property
+ def tmppath(self) -> pathlib.Path:
+ return pathlib.Path(self._tmpdir.name)
+
+ def make_input(self, mapping: dict[str, str]) -> pathlib.Path:
+ _fd, file = tempfile.mkstemp(suffix=".json", dir=self._tmpdir.name)
+ path = pathlib.Path(file)
+ path.write_text(json.dumps(mapping))
+ return path
+
+ def test_merger(self):
+ output_path = self.tmppath / "output.json"
+ merge_modules_mappings(
+ [
+ self.make_input(
+ {
+ "_pytest": "pytest",
+ "_pytest.__init__": "pytest",
+ "_pytest._argcomplete": "pytest",
+ "_pytest.config.argparsing": "pytest",
+ }
+ ),
+ self.make_input({"django_types": "django_types"}),
+ ],
+ output_path,
+ )
+
+ self.assertEqual(
+ {
+ "_pytest": "pytest",
+ "_pytest.__init__": "pytest",
+ "_pytest._argcomplete": "pytest",
+ "_pytest.config.argparsing": "pytest",
+ "django_types": "django_types",
+ },
+ json.loads(output_path.read_text()),
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()