Add support for requirement extras (#28)
diff --git a/extract_wheels/__init__.py b/extract_wheels/__init__.py index ac3d72c..390dcd0 100644 --- a/extract_wheels/__init__.py +++ b/extract_wheels/__init__.py
@@ -11,7 +11,7 @@ import subprocess import sys -from extract_wheels.lib import bazel +from extract_wheels.lib import bazel, requirements def configure_reproducible_wheels() -> None: @@ -70,8 +70,10 @@ [sys.executable, "-m", "pip", "wheel", "-r", args.requirements] ) + extras = requirements.parse_extras(args.requirements) + targets = [ - '"%s%s"' % (args.repo, bazel.extract_wheel(whl, [])) + '"%s%s"' % (args.repo, bazel.extract_wheel(whl, extras)) for whl in glob.glob("*.whl") ]
diff --git a/extract_wheels/lib/BUILD b/extract_wheels/lib/BUILD index 1ff0e45..ef6e0ab 100644 --- a/extract_wheels/lib/BUILD +++ b/extract_wheels/lib/BUILD
@@ -7,6 +7,7 @@ "bazel.py", "namespace_pkgs.py", "purelib.py", + "requirements.py", "wheel.py", ], deps = [ @@ -26,3 +27,15 @@ ":lib", ], ) + +py_test( + name = "requirements_test", + size = "small", + srcs = [ + "requirements_test.py", + ], + tags = ["unit"], + deps = [ + ":lib", + ], +)
diff --git a/extract_wheels/lib/bazel.py b/extract_wheels/lib/bazel.py index ed40038..2d4c61e 100644 --- a/extract_wheels/lib/bazel.py +++ b/extract_wheels/lib/bazel.py
@@ -1,7 +1,7 @@ """Utility functions to manipulate Bazel files""" import os import textwrap -from typing import Iterable, List +from typing import Iterable, List, Dict, Set from extract_wheels.lib import namespace_pkgs, wheel, purelib @@ -113,7 +113,7 @@ namespace_pkgs.add_pkgutil_style_namespace_pkg_init(ns_pkg_dir) -def extract_wheel(wheel_file: str, extras: List[str]) -> str: +def extract_wheel(wheel_file: str, extras: Dict[str, Set[str]]) -> str: """Extracts wheel into given directory and creates a py_library target. Args: @@ -134,16 +134,17 @@ purelib.spread_purelib_into_root(directory) setup_namespace_pkg_compatibility(directory) + extras_requested = extras[whl.name] if whl.name in extras else set() + + sanitised_dependencies = [ + '"//%s"' % sanitise_name(d) for d in sorted(whl.dependencies(extras_requested)) + ] + with open(os.path.join(directory, "BUILD"), "w") as build_file: - build_file.write( - generate_build_file_contents( - sanitise_name(whl.name), - [ - '"//%s"' % sanitise_name(d) - for d in sorted(whl.dependencies(extras_requested=extras)) - ], - ) + contents = generate_build_file_contents( + sanitise_name(whl.name), sanitised_dependencies, ) + build_file.write(contents) os.remove(whl.path)
diff --git a/extract_wheels/lib/namespace_pkgs_test.py b/extract_wheels/lib/namespace_pkgs_test.py index 658625f..d4e2a4c 100644 --- a/extract_wheels/lib/namespace_pkgs_test.py +++ b/extract_wheels/lib/namespace_pkgs_test.py
@@ -1,7 +1,5 @@ -import os import pathlib import shutil -import sys import tempfile import unittest @@ -133,17 +131,5 @@ self.assertEqual(actual, set()) -def main(): - loader = unittest.TestLoader() - cur_dir = os.path.dirname(os.path.realpath(__file__)) - - suite = loader.discover(cur_dir) - - runner = unittest.TextTestRunner() - result = runner.run(suite) - if result.errors or result.failures: - sys.exit(1) - - if __name__ == "__main__": - main() + unittest.main()
diff --git a/extract_wheels/lib/requirements.py b/extract_wheels/lib/requirements.py new file mode 100644 index 0000000..e246379 --- /dev/null +++ b/extract_wheels/lib/requirements.py
@@ -0,0 +1,45 @@ +import re +from typing import Dict, Set, Tuple, Optional + + +def parse_extras(requirements_path: str) -> Dict[str, Set[str]]: + """Parse over the requirements.txt file to find extras requested. + + Args: + requirements_path: The filepath for the requirements.txt file to parse. + + Returns: + A dictionary mapping the requirement name to a set of extras requested. + """ + + extras_requested = {} + with open(requirements_path, "r") as requirements: + # Merge all backslash line continuations so we parse each requirement as a single line. + for line in requirements.read().replace("\\\n", "").split("\n"): + requirement, extras = _parse_requirement_for_extra(line) + if requirement and extras: + extras_requested[requirement] = extras + + return extras_requested + + +def _parse_requirement_for_extra( + requirement: str, +) -> Tuple[Optional[str], Optional[Set[str]]]: + """Given a requirement string, returns the requirement name and set of extras, if extras specified. + Else, returns (None, None) + """ + + # https://www.python.org/dev/peps/pep-0508/#grammar + extras_pattern = re.compile( + r"^\s*([0-9A-Za-z][0-9A-Za-z_.\-]*)\s*\[\s*([0-9A-Za-z][0-9A-Za-z_.\-]*(?:\s*,\s*[0-9A-Za-z][0-9A-Za-z_.\-]*)*)\s*\]" + ) + + matches = extras_pattern.match(requirement) + if matches: + return ( + matches.group(1), + {extra.strip() for extra in matches.group(2).split(",")}, + ) + + return None, None
diff --git a/extract_wheels/lib/requirements_test.py b/extract_wheels/lib/requirements_test.py new file mode 100644 index 0000000..2b96a75 --- /dev/null +++ b/extract_wheels/lib/requirements_test.py
@@ -0,0 +1,32 @@ +import unittest + +from extract_wheels.lib import requirements + + +class TestRequirementExtrasParsing(unittest.TestCase): + def test_parses_requirement_for_extra(self) -> None: + cases = [ + ("name[foo]", ("name", frozenset(["foo"]))), + ("name[ Foo123 ]", ("name", frozenset(["Foo123"]))), + (" name1[ foo ] ", ("name1", frozenset(["foo"]))), + ( + "name [fred,bar] @ http://foo.com ; python_version=='2.7'", + ("name", frozenset(["fred", "bar"])), + ), + ( + "name[quux, strange];python_version<'2.7' and platform_version=='2'", + ("name", frozenset(["quux", "strange"])), + ), + ("name; (os_name=='a' or os_name=='b') and os_name=='c'", (None, None),), + ("name@http://foo.com", (None, None),), + ] + + for case, expected in cases: + with self.subTest(): + self.assertTupleEqual( + requirements._parse_requirement_for_extra(case), expected + ) + + +if __name__ == "__main__": + unittest.main()
diff --git a/extract_wheels/lib/wheel.py b/extract_wheels/lib/wheel.py index 6737ee8..5fe74a1 100644 --- a/extract_wheels/lib/wheel.py +++ b/extract_wheels/lib/wheel.py
@@ -2,7 +2,7 @@ import glob import os import zipfile -from typing import Dict, Optional, List, Set +from typing import Dict, Optional, Set import pkg_resources import pkginfo @@ -26,7 +26,7 @@ def metadata(self) -> pkginfo.Wheel: return pkginfo.get_metadata(self.path) - def dependencies(self, extras_requested: Optional[List[str]] = None) -> Set[str]: + def dependencies(self, extras_requested: Optional[Set[str]] = None) -> Set[str]: dependency_set = set() for wheel_req in self.metadata.requires_dist: