feat: add whl_filegroup (#1904)

This adds a `whl_filegroup` rule which can take a wheel as input and
extract it, making the
output available to consumers. A pattern is allowed to better control
what files
are extracted.

This is handy to, e.g. get the numpy headers from the numpy wheel so
they can be
used in a cc_library. e.g., 

Closes #1903 
Fixes #1311
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 63ece30..cc5ba4d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -80,6 +80,9 @@
 
 [precompile-docs]: /precompiling
 
+* (whl_filegroup) Added a new `whl_filegroup` rule to extract files from a wheel file.
+  This is useful to extract headers for use in a `cc_library`.
+
 ## [0.32.2] - 2024-05-14
 
 [0.32.2]: https://github.com/bazelbuild/rules_python/releases/tag/0.32.2
diff --git a/python/BUILD.bazel b/python/BUILD.bazel
index 5140470..3ab390d 100644
--- a/python/BUILD.bazel
+++ b/python/BUILD.bazel
@@ -101,6 +101,7 @@
         "//python/private:bzlmod_enabled_bzl",
         "//python/private:full_version_bzl",
         "//python/private:render_pkg_aliases_bzl",
+        "//python/private/whl_filegroup:whl_filegroup_bzl",
     ],
 )
 
diff --git a/python/pip.bzl b/python/pip.bzl
index aeedf57..c7cdbb2 100644
--- a/python/pip.bzl
+++ b/python/pip.bzl
@@ -25,10 +25,12 @@
 load("//python/private:full_version.bzl", "full_version")
 load("//python/private:normalize_name.bzl", "normalize_name")
 load("//python/private:render_pkg_aliases.bzl", "NO_MATCH_ERROR_MESSAGE_TEMPLATE")
+load("//python/private/whl_filegroup:whl_filegroup.bzl", _whl_filegroup = "whl_filegroup")
 
 compile_pip_requirements = _compile_pip_requirements
 package_annotation = _package_annotation
 pip_parse = pip_repository
+whl_filegroup = _whl_filegroup
 
 def _multi_pip_parse_impl(rctx):
     rules_python = rctx.attr._rules_python_workspace.workspace_name
diff --git a/python/pip_install/tools/wheel_installer/wheel.py b/python/pip_install/tools/wheel_installer/wheel.py
index d355bfe..b84c214 100644
--- a/python/pip_install/tools/wheel_installer/wheel.py
+++ b/python/pip_install/tools/wheel_installer/wheel.py
@@ -575,7 +575,7 @@
         self._path = path
 
     @property
-    def path(self) -> str:
+    def path(self) -> Path:
         return self._path
 
     @property
diff --git a/python/private/BUILD.bazel b/python/private/BUILD.bazel
index 45f50ef..3e56208 100644
--- a/python/private/BUILD.bazel
+++ b/python/private/BUILD.bazel
@@ -30,6 +30,7 @@
         "//python/private/bzlmod:distribution",
         "//python/private/common:distribution",
         "//python/private/proto:distribution",
+        "//python/private/whl_filegroup:distribution",
         "//tools/build_defs/python/private:distribution",
     ],
     visibility = ["//python:__pkg__"],
diff --git a/python/private/whl_filegroup/BUILD.bazel b/python/private/whl_filegroup/BUILD.bazel
new file mode 100644
index 0000000..398b9af
--- /dev/null
+++ b/python/private/whl_filegroup/BUILD.bazel
@@ -0,0 +1,20 @@
+load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
+load("//python:defs.bzl", "py_binary")
+
+filegroup(
+    name = "distribution",
+    srcs = glob(["**"]),
+    visibility = ["//python/private:__pkg__"],
+)
+
+bzl_library(
+    name = "whl_filegroup_bzl",
+    srcs = ["whl_filegroup.bzl"],
+    visibility = ["//:__subpackages__"],
+)
+
+py_binary(
+    name = "extract_wheel_files",
+    srcs = ["extract_wheel_files.py"],
+    visibility = ["//visibility:public"],
+)
diff --git a/python/private/whl_filegroup/extract_wheel_files.py b/python/private/whl_filegroup/extract_wheel_files.py
new file mode 100644
index 0000000..e81e6a3
--- /dev/null
+++ b/python/private/whl_filegroup/extract_wheel_files.py
@@ -0,0 +1,62 @@
+"""Extract files from a wheel's RECORD."""
+
+import re
+import sys
+import zipfile
+from collections.abc import Iterable
+from pathlib import Path
+
+WhlRecord = dict[str, tuple[str, int]]
+
+
+def get_record(whl_path: Path) -> WhlRecord:
+    try:
+        zipf = zipfile.ZipFile(whl_path)
+    except zipfile.BadZipFile as ex:
+        raise RuntimeError(f"{whl_path} is not a valid zip file") from ex
+    files = zipf.namelist()
+    try:
+        (record_file,) = [name for name in files if name.endswith(".dist-info/RECORD")]
+    except ValueError:
+        raise RuntimeError(f"{whl_path} doesn't contain exactly one .dist-info/RECORD")
+    record_lines = zipf.read(record_file).decode().splitlines()
+    return {
+        file: (filehash, int(filelen))
+        for line in record_lines
+        for file, filehash, filelen in [line.split(",")]
+        if filehash  # Skip RECORD itself, which has no hash or length
+    }
+
+
+def get_files(whl_record: WhlRecord, regex_pattern: str) -> list[str]:
+    """Get files in a wheel that match a regex pattern."""
+    p = re.compile(regex_pattern)
+    return [filepath for filepath in whl_record.keys() if re.match(p, filepath)]
+
+
+def extract_files(whl_path: Path, files: Iterable[str], outdir: Path) -> None:
+    """Extract files from whl_path to outdir."""
+    zipf = zipfile.ZipFile(whl_path)
+    for file in files:
+        zipf.extract(file, outdir)
+
+
+def main() -> None:
+    if len(sys.argv) not in {3, 4}:
+        print(
+            f"Usage: {sys.argv[0]} <wheel> <out_dir> [regex_pattern]",
+            file=sys.stderr,
+        )
+        sys.exit(1)
+
+    whl_path = Path(sys.argv[1]).resolve()
+    outdir = Path(sys.argv[2])
+    regex_pattern = sys.argv[3] if len(sys.argv) == 4 else ""
+
+    whl_record = get_record(whl_path)
+    files = get_files(whl_record, regex_pattern)
+    extract_files(whl_path, files, outdir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/python/private/whl_filegroup/whl_filegroup.bzl b/python/private/whl_filegroup/whl_filegroup.bzl
new file mode 100644
index 0000000..c5f97e6
--- /dev/null
+++ b/python/private/whl_filegroup/whl_filegroup.bzl
@@ -0,0 +1,57 @@
+"""Implementation of whl_filegroup rule."""
+
+def _whl_filegroup_impl(ctx):
+    out_dir = ctx.actions.declare_directory(ctx.attr.name)
+    ctx.actions.run(
+        outputs = [out_dir],
+        inputs = [ctx.file.whl],
+        arguments = [
+            ctx.file.whl.path,
+            out_dir.path,
+            ctx.attr.pattern,
+        ],
+        executable = ctx.executable._extract_wheel_files_tool,
+        mnemonic = "PyExtractWheelFiles",
+        progress_message = "Extracting %s files from %s" % (ctx.attr.pattern, ctx.file.whl.short_path),
+    )
+    return [DefaultInfo(
+        files = depset([out_dir]),
+        runfiles = ctx.runfiles(files = [out_dir] if ctx.attr.runfiles else []),
+    )]
+
+whl_filegroup = rule(
+    _whl_filegroup_impl,
+    doc = """Extract files matching a regular expression from a wheel file.
+
+An empty pattern will match all files.
+
+Example usage:
+```starlark
+load("@rules_cc//cc:defs.bzl", "cc_library")
+load("@rules_python//python:pip.bzl", "whl_filegroup")
+
+whl_filegroup(
+    name = "numpy_includes",
+    pattern = "numpy/core/include/numpy",
+    whl = "@pypi//numpy:whl",
+)
+
+cc_library(
+    name = "numpy_headers",
+    hdrs = [":numpy_includes"],
+    includes = ["numpy_includes/numpy/core/include"],
+    deps = ["@rules_python//python/cc:current_py_cc_headers"],
+)
+```
+""",
+    attrs = {
+        "pattern": attr.string(default = "", doc = "Only file paths matching this regex pattern will be extracted."),
+        "runfiles": attr.bool(default = False, doc = "Whether to include the output TreeArtifact in this target's runfiles."),
+        "whl": attr.label(mandatory = True, allow_single_file = True, doc = "The wheel to extract files from."),
+        "_extract_wheel_files_tool": attr.label(
+            default = Label("//python/private/whl_filegroup:extract_wheel_files"),
+            cfg = "exec",
+            executable = True,
+        ),
+    },
+)
diff --git a/tests/whl_filegroup/BUILD.bazel b/tests/whl_filegroup/BUILD.bazel
new file mode 100644
index 0000000..d8b711d
--- /dev/null
+++ b/tests/whl_filegroup/BUILD.bazel
@@ -0,0 +1,69 @@
+load("@bazel_skylib//rules:write_file.bzl", "write_file")
+load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test")
+load("//python:defs.bzl", "py_library", "py_test")
+load("//python:packaging.bzl", "py_package", "py_wheel")
+load("//python:pip.bzl", "whl_filegroup")
+load(":whl_filegroup_tests.bzl", "whl_filegroup_test_suite")
+
+whl_filegroup_test_suite(name = "whl_filegroup_tests")
+
+py_test(
+    name = "extract_wheel_files_test",
+    size = "small",
+    srcs = ["extract_wheel_files_test.py"],
+    data = ["//examples/wheel:minimal_with_py_package"],
+    deps = ["//python/private/whl_filegroup:extract_wheel_files"],
+)
+
+write_file(
+    name = "header",
+    out = "include/whl_headers/header.h",
+    content = [
+        "#pragma once",
+        "#include <Python.h>",
+        "#define CUSTOM_ZERO ((Py_ssize_t) 0)",
+    ],
+)
+
+write_file(
+    name = "lib_py",
+    out = "lib.py",
+)
+
+py_library(
+    name = "lib",
+    srcs = ["lib.py"],
+    data = [":header"],
+)
+
+py_package(
+    name = "pkg",
+    deps = [":lib"],
+)
+
+py_wheel(
+    name = "wheel",
+    distribution = "wheel",
+    python_tag = "py3",
+    version = "0.0.1",
+    deps = [":pkg"],
+)
+
+whl_filegroup(
+    name = "filegroup",
+    pattern = "tests/whl_filegroup/include/.*\\.h",
+    whl = ":wheel",
+)
+
+cc_library(
+    name = "whl_headers",
+    hdrs = [":filegroup"],
+    includes = ["filegroup/tests/whl_filegroup/include"],
+    deps = ["@rules_python//python/cc:current_py_cc_headers"],
+)
+
+cc_test(
+    name = "whl_headers_test",
+    srcs = ["whl_headers_test.c"],
+    deps = [":whl_headers"],
+)
diff --git a/tests/whl_filegroup/extract_wheel_files_test.py b/tests/whl_filegroup/extract_wheel_files_test.py
new file mode 100644
index 0000000..2ea175b
--- /dev/null
+++ b/tests/whl_filegroup/extract_wheel_files_test.py
@@ -0,0 +1,63 @@
+import tempfile
+import unittest
+from pathlib import Path
+
+from python.private.whl_filegroup import extract_wheel_files
+
+_WHEEL = Path("examples/wheel/example_minimal_package-0.0.1-py3-none-any.whl")
+
+
+class WheelRecordTest(unittest.TestCase):
+    def test_get_wheel_record(self) -> None:
+        record = extract_wheel_files.get_record(_WHEEL)
+        expected = {
+            "examples/wheel/lib/data.txt": (
+                "sha256=9vJKEdfLu8bZRArKLroPZJh1XKkK3qFMXiM79MBL2Sg",
+                12,
+            ),
+            "examples/wheel/lib/module_with_data.py": (
+                "sha256=8s0Khhcqz3yVsBKv2IB5u4l4TMKh7-c_V6p65WVHPms",
+                637,
+            ),
+            "examples/wheel/lib/simple_module.py": (
+                "sha256=z2hwciab_XPNIBNH8B1Q5fYgnJvQTeYf0ZQJpY8yLLY",
+                637,
+            ),
+            "examples/wheel/main.py": (
+                "sha256=sgg5iWN_9inYBjm6_Zw27hYdmo-l24fA-2rfphT-IlY",
+                909,
+            ),
+            "example_minimal_package-0.0.1.dist-info/WHEEL": (
+                "sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us",
+                91,
+            ),
+            "example_minimal_package-0.0.1.dist-info/METADATA": (
+                "sha256=cfiQ2hFJhCKCUgbwtAwWG0fhW6NTzw4cr1uKOBcV_IM",
+                76,
+            ),
+        }
+        self.maxDiff = None
+        self.assertDictEqual(record, expected)
+
+    def test_get_files(self) -> None:
+        pattern = "(examples/wheel/lib/.*\.txt$|.*main)"
+        record = extract_wheel_files.get_record(_WHEEL)
+        files = extract_wheel_files.get_files(record, pattern)
+        expected = ["examples/wheel/lib/data.txt", "examples/wheel/main.py"]
+        self.assertEqual(files, expected)
+
+    def test_extract(self) -> None:
+        files = {"examples/wheel/lib/data.txt", "examples/wheel/main.py"}
+        with tempfile.TemporaryDirectory() as tmpdir:
+            outdir = Path(tmpdir)
+            extract_wheel_files.extract_files(_WHEEL, files, outdir)
+            extracted_files = {
+                f.relative_to(outdir).as_posix()
+                for f in outdir.glob("**/*")
+                if f.is_file()
+            }
+        self.assertEqual(extracted_files, files)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/whl_filegroup/whl_filegroup_tests.bzl b/tests/whl_filegroup/whl_filegroup_tests.bzl
new file mode 100644
index 0000000..acb9341
--- /dev/null
+++ b/tests/whl_filegroup/whl_filegroup_tests.bzl
@@ -0,0 +1,34 @@
+"""Test for py_wheel."""
+
+load("@rules_testing//lib:analysis_test.bzl", "analysis_test", "test_suite")
+load("@rules_testing//lib:util.bzl", "util")
+load("//python:pip.bzl", "whl_filegroup")
+
+def _test_runfiles(name):
+    for runfiles in [True, False]:
+        util.helper_target(
+            whl_filegroup,
+            name = name + "_subject_runfiles_{}".format(runfiles),
+            whl = ":wheel",
+            runfiles = runfiles,
+        )
+    analysis_test(
+        name = name,
+        impl = _test_runfiles_impl,
+        targets = {
+            "no_runfiles": name + "_subject_runfiles_False",
+            "with_runfiles": name + "_subject_runfiles_True",
+        },
+    )
+
+def _test_runfiles_impl(env, targets):
+    env.expect.that_target(targets.with_runfiles).runfiles().contains_exactly([env.ctx.workspace_name + "/{package}/{name}"])
+    env.expect.that_target(targets.no_runfiles).runfiles().contains_exactly([])
+
+def whl_filegroup_test_suite(name):
+    """Create the test suite.
+
+    Args:
+        name: the name of the test suite
+    """
+    test_suite(name = name, tests = [_test_runfiles])
diff --git a/tests/whl_filegroup/whl_headers_test.c b/tests/whl_filegroup/whl_headers_test.c
new file mode 100644
index 0000000..786395a
--- /dev/null
+++ b/tests/whl_filegroup/whl_headers_test.c
@@ -0,0 +1,5 @@
+#include <whl_headers/header.h>
+
+int main(int argc, char**argv) {
+    return CUSTOM_ZERO;
+}