refactor(gazelle): Generate a modules map per wheel, then merge (#3415)

This change internally splits modules mapping generation to be
per-wheel, with a final quick "merge" action at the end.

The idea is to make this process both concurrent and cached (courtesy of
Bazel), which can be ideal for codebases with a large set of
requirements, as many monorepos end up having.

Note that the `generator.py` interface changed. This seemed internal, so
I didn't mark it breaking (but this change could actually just leave the
generator alone, since the current implementation is fine with 1 wheel).

I ran this on the work repo and saw no change in output; since I edited
only a single requirement, the overall process was fast ⚡.
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5a71e5f..7d3723c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -62,6 +62,7 @@
 {#v0-0-0-changed}
 ### Changed
 * (toolchains) Use toolchains from the [20251031] release.
+* (gazelle) Internally split modules mapping generation to be per-wheel for concurrency and caching.
 
 {#v0-0-0-fixed}
 ### Fixed
diff --git a/gazelle/modules_mapping/BUILD.bazel b/gazelle/modules_mapping/BUILD.bazel
index 3a9a8a4..3423f34 100644
--- a/gazelle/modules_mapping/BUILD.bazel
+++ b/gazelle/modules_mapping/BUILD.bazel
@@ -9,6 +9,12 @@
     visibility = ["//visibility:public"],
 )
 
+py_binary(
+    name = "merger",
+    srcs = ["merger.py"],
+    visibility = ["//visibility:public"],
+)
+
 copy_file(
     name = "pytest_wheel",
     src = "@pytest//file",
@@ -33,6 +39,18 @@
     deps = [":generator"],
 )
 
+py_test(
+    name = "test_merger",
+    srcs = ["test_merger.py"],
+    data = [
+        "django_types_wheel",
+        "pytest_wheel",
+    ],
+    imports = ["."],
+    main = "test_merger.py",
+    deps = [":merger"],
+)
+
 filegroup(
     name = "distribution",
     srcs = glob(["**"]),
diff --git a/gazelle/modules_mapping/def.bzl b/gazelle/modules_mapping/def.bzl
index 48a5477..74d3c9e 100644
--- a/gazelle/modules_mapping/def.bzl
+++ b/gazelle/modules_mapping/def.bzl
@@ -30,25 +30,39 @@
         transitive = [dep[DefaultInfo].files for dep in ctx.attr.wheels] + [dep[DefaultInfo].data_runfiles.files for dep in ctx.attr.wheels],
     )
 
-    args = ctx.actions.args()
+    # Run the generator once per-wheel (to leverage caching)
+    per_wheel_outputs = []
+    for idx, whl in enumerate(all_wheels.to_list()):
+        wheel_modules_mapping = ctx.actions.declare_file("{}.{}".format(modules_mapping.short_path, idx))
+        args = ctx.actions.args()
+        args.add("--output_file", wheel_modules_mapping.path)
+        if ctx.attr.include_stub_packages:
+            args.add("--include_stub_packages")
+        args.add_all("--exclude_patterns", ctx.attr.exclude_patterns)
+        args.add("--wheel", whl.path)
 
-    # Spill parameters to a file prefixed with '@'. Note, the '@' prefix is the same
-    # prefix as used in the `generator.py` in `fromfile_prefix_chars` attribute.
-    args.use_param_file(param_file_arg = "@%s")
-    args.set_param_file_format(format = "multiline")
-    if ctx.attr.include_stub_packages:
-        args.add("--include_stub_packages")
-    args.add("--output_file", modules_mapping)
-    args.add_all("--exclude_patterns", ctx.attr.exclude_patterns)
-    args.add_all("--wheels", all_wheels)
+        ctx.actions.run(
+            inputs = [whl],
+            outputs = [wheel_modules_mapping],
+            executable = ctx.executable._generator,
+            arguments = [args],
+            use_default_shell_env = False,
+        )
+        per_wheel_outputs.append(wheel_modules_mapping)
+
+    # Then merge the individual JSONs together
+    merge_args = ctx.actions.args()
+    merge_args.add("--output", modules_mapping.path)
+    merge_args.add_all("--inputs", [f.path for f in per_wheel_outputs])
 
     ctx.actions.run(
-        inputs = all_wheels,
+        inputs = per_wheel_outputs,
         outputs = [modules_mapping],
-        executable = ctx.executable._generator,
-        arguments = [args],
+        executable = ctx.executable._merger,
+        arguments = [merge_args],
         use_default_shell_env = False,
     )
+
     return [DefaultInfo(files = depset([modules_mapping]))]
 
 modules_mapping = rule(
@@ -79,6 +93,11 @@
             default = "//modules_mapping:generator",
             executable = True,
         ),
+        "_merger": attr.label(
+            cfg = "exec",
+            default = "//modules_mapping:merger",
+            executable = True,
+        ),
     },
     doc = "Creates a modules_mapping.json file for mapping module names to wheel distribution names.",
 )
diff --git a/gazelle/modules_mapping/generator.py b/gazelle/modules_mapping/generator.py
index ea11f3e..611910c 100644
--- a/gazelle/modules_mapping/generator.py
+++ b/gazelle/modules_mapping/generator.py
@@ -96,8 +96,7 @@
                 ext = "".join(pathlib.Path(root).suffixes)
             module = root[: -len(ext)].replace("/", ".")
             if not self.is_excluded(module):
-                if not self.is_excluded(module):
-                    self.mapping[module] = wheel_name
+                self.mapping[module] = wheel_name
 
     def is_excluded(self, module):
         for pattern in self.excluded_patterns:
@@ -105,14 +104,20 @@
                 return True
         return False
 
-    # run is the entrypoint for the generator.
-    def run(self, wheels):
-        for whl in wheels:
-            try:
-                self.dig_wheel(whl)
-            except AssertionError as error:
-                print(error, file=self.stderr)
-                return 1
+    def run(self, wheel: pathlib.Path) -> int:
+        """
+        Entrypoint for the generator.
+
+        Args:
+            wheel: The path to the wheel file (`.whl`)
+        Returns:
+            Exit code (for `sys.exit`)
+        """
+        try:
+            self.dig_wheel(wheel)
+        except AssertionError as error:
+            print(error, file=self.stderr)
+            return 1
         self.simplify()
         mapping_json = json.dumps(self.mapping)
         with open(self.output_file, "w") as f:
@@ -152,16 +157,13 @@
     parser = argparse.ArgumentParser(
         prog="generator",
         description="Generates the modules mapping used by the Gazelle manifest.",
-        # Automatically read parameters from a file. Note, the '@' is the same prefix
-        # as set in the 'args.use_param_file' in the bazel rule.
-        fromfile_prefix_chars="@",
     )
     parser.add_argument("--output_file", type=str)
     parser.add_argument("--include_stub_packages", action="store_true")
     parser.add_argument("--exclude_patterns", nargs="+", default=[])
-    parser.add_argument("--wheels", nargs="+", default=[])
+    parser.add_argument("--wheel", type=pathlib.Path)
     args = parser.parse_args()
     generator = Generator(
         sys.stderr, args.output_file, args.exclude_patterns, args.include_stub_packages
     )
-    sys.exit(generator.run(args.wheels))
+    sys.exit(generator.run(args.wheel))
diff --git a/gazelle/modules_mapping/merger.py b/gazelle/modules_mapping/merger.py
new file mode 100644
index 0000000..deb0cb2
--- /dev/null
+++ b/gazelle/modules_mapping/merger.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+"""Merges multiple modules_mapping.json files into a single file."""
+
+import argparse
+import json
+from pathlib import Path
+
+
+def merge_modules_mappings(input_files: list[Path], output_file: Path) -> None:
+    """Merge multiple modules_mapping.json files into one.
+
+    Args:
+        input_files: List of paths to input JSON files to merge
+        output_file: Path where the merged output should be written
+    """
+    merged_mapping = {}
+    for input_file in input_files:
+        mapping = json.loads(input_file.read_text())
+        # Merge the mappings, with later files overwriting earlier ones
+        # if there are conflicts
+        merged_mapping.update(mapping)
+
+    output_file.write_text(json.dumps(merged_mapping))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Merge multiple modules_mapping.json files"
+    )
+    parser.add_argument(
+        "--output",
+        required=True,
+        type=Path,
+        help="Output file path for merged mapping",
+    )
+    parser.add_argument(
+        "--inputs",
+        required=True,
+        nargs="+",
+        type=Path,
+        help="Input JSON files to merge",
+    )
+
+    args = parser.parse_args()
+    merge_modules_mappings(args.inputs, args.output)
diff --git a/gazelle/modules_mapping/test_merger.py b/gazelle/modules_mapping/test_merger.py
new file mode 100644
index 0000000..6260fdd
--- /dev/null
+++ b/gazelle/modules_mapping/test_merger.py
@@ -0,0 +1,61 @@
+import pathlib
+import unittest
+import json
+import tempfile
+
+from merger import merge_modules_mappings
+
+
+class MergerTest(unittest.TestCase):
+    _tmpdir: tempfile.TemporaryDirectory
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._tmpdir = tempfile.TemporaryDirectory()
+
+    def tearDown(self) -> None:
+        super().tearDown()
+        self._tmpdir.cleanup()
+        del self._tmpdir
+
+    @property
+    def tmppath(self) -> pathlib.Path:
+        return pathlib.Path(self._tmpdir.name)
+
+    def make_input(self, mapping: dict[str, str]) -> pathlib.Path:
+        _fd, file = tempfile.mkstemp(suffix=".json", dir=self._tmpdir.name)
+        path = pathlib.Path(file)
+        path.write_text(json.dumps(mapping))
+        return path
+
+    def test_merger(self):
+        output_path = self.tmppath / "output.json"
+        merge_modules_mappings(
+            [
+                self.make_input(
+                    {
+                        "_pytest": "pytest",
+                        "_pytest.__init__": "pytest",
+                        "_pytest._argcomplete": "pytest",
+                        "_pytest.config.argparsing": "pytest",
+                    }
+                ),
+                self.make_input({"django_types": "django_types"}),
+            ],
+            output_path,
+        )
+
+        self.assertEqual(
+            {
+                "_pytest": "pytest",
+                "_pytest.__init__": "pytest",
+                "_pytest._argcomplete": "pytest",
+                "_pytest.config.argparsing": "pytest",
+                "django_types": "django_types",
+            },
+            json.loads(output_path.read_text()),
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()