fix(pypi): handle more URL patterns for requirement sources (#2843)
Summary:
- Better handle git references for sdists.
- Better handle direct whl references.
- Add an extra test that turned out not to be needed for the fix, but is
kept because it increases code coverage.
Work towards #2363
Fixes #2828
diff --git a/python/private/pypi/parse_requirements.bzl b/python/private/pypi/parse_requirements.bzl
index 5633328..1583c89 100644
--- a/python/private/pypi/parse_requirements.bzl
+++ b/python/private/pypi/parse_requirements.bzl
@@ -285,12 +285,17 @@
if requirement.srcs.url:
url = requirement.srcs.url
_, _, filename = url.rpartition("/")
+ filename, _, _ = filename.partition("#sha256=")
if "." not in filename:
# detected filename has no extension, it might be an sdist ref
# TODO @aignas 2025-04-03: should be handled if the following is fixed:
# https://github.com/bazel-contrib/rules_python/issues/2363
return [], None
+ if "@" in filename:
+ # this is most likely foo.git@git_sha, skip special handling of these
+ return [], None
+
direct_url_dist = struct(
url = url,
filename = filename,
diff --git a/tests/pypi/index_sources/index_sources_tests.bzl b/tests/pypi/index_sources/index_sources_tests.bzl
index ffeed87..9d12bc6 100644
--- a/tests/pypi/index_sources/index_sources_tests.bzl
+++ b/tests/pypi/index_sources/index_sources_tests.bzl
@@ -21,38 +21,50 @@
def _test_no_simple_api_sources(env):
inputs = {
+ "foo @ git+https://github.com/org/foo.git@deadbeef": struct(
+ requirement = "foo @ git+https://github.com/org/foo.git@deadbeef",
+ marker = "",
+ url = "git+https://github.com/org/foo.git@deadbeef",
+ shas = [],
+ version = "",
+ ),
"foo==0.0.1": struct(
requirement = "foo==0.0.1",
marker = "",
url = "",
+ version = "0.0.1",
),
"foo==0.0.1 @ https://someurl.org": struct(
requirement = "foo==0.0.1 @ https://someurl.org",
marker = "",
url = "https://someurl.org",
+ version = "0.0.1",
),
"foo==0.0.1 @ https://someurl.org/package.whl": struct(
requirement = "foo==0.0.1 @ https://someurl.org/package.whl",
marker = "",
url = "https://someurl.org/package.whl",
+ version = "0.0.1",
),
"foo==0.0.1 @ https://someurl.org/package.whl --hash=sha256:deadbeef": struct(
requirement = "foo==0.0.1 @ https://someurl.org/package.whl --hash=sha256:deadbeef",
marker = "",
url = "https://someurl.org/package.whl",
shas = ["deadbeef"],
+ version = "0.0.1",
),
"foo==0.0.1 @ https://someurl.org/package.whl; python_version < \"2.7\"\\ --hash=sha256:deadbeef": struct(
requirement = "foo==0.0.1 @ https://someurl.org/package.whl --hash=sha256:deadbeef",
marker = "python_version < \"2.7\"",
url = "https://someurl.org/package.whl",
shas = ["deadbeef"],
+ version = "0.0.1",
),
}
for input, want in inputs.items():
got = index_sources(input)
env.expect.that_collection(got.shas).contains_exactly(want.shas if hasattr(want, "shas") else [])
- env.expect.that_str(got.version).equals("0.0.1")
+ env.expect.that_str(got.version).equals(want.version)
env.expect.that_str(got.requirement).equals(want.requirement)
env.expect.that_str(got.requirement_line).equals(got.requirement)
env.expect.that_str(got.marker).equals(want.marker)
diff --git a/tests/pypi/parse_requirements/parse_requirements_tests.bzl b/tests/pypi/parse_requirements/parse_requirements_tests.bzl
index 723bb60..c5b2487 100644
--- a/tests/pypi/parse_requirements/parse_requirements_tests.bzl
+++ b/tests/pypi/parse_requirements/parse_requirements_tests.bzl
@@ -30,6 +30,7 @@
bar @ https://example.org/bar-1.0.whl --hash=sha256:deadbeef
baz @ https://test.com/baz-2.0.whl; python_version < "3.8" --hash=sha256:deadb00f
qux @ https://example.org/qux-1.0.tar.gz --hash=sha256:deadbe0f
+torch @ https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl#sha256=5b6ae523bfb67088a17ca7734d131548a2e60346c622621e4248ed09dd0790cc
""",
"requirements_extra_args": """\
--index-url=example.org
@@ -37,6 +38,9 @@
foo[extra]==0.0.1 \
--hash=sha256:deadbeef
""",
+ "requirements_git": """
+foo @ git+https://github.com/org/foo.git@deadbeef
+""",
"requirements_linux": """\
foo==0.0.3 --hash=sha256:deadbaaf
""",
@@ -232,6 +236,31 @@
whls = [],
),
],
+ "torch": [
+ struct(
+ distribution = "torch",
+ extra_pip_args = [],
+ is_exposed = True,
+ sdist = None,
+ srcs = struct(
+ marker = "",
+ requirement = "torch @ https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl#sha256=5b6ae523bfb67088a17ca7734d131548a2e60346c622621e4248ed09dd0790cc",
+ requirement_line = "torch @ https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl#sha256=5b6ae523bfb67088a17ca7734d131548a2e60346c622621e4248ed09dd0790cc",
+ shas = [],
+ url = "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl#sha256=5b6ae523bfb67088a17ca7734d131548a2e60346c622621e4248ed09dd0790cc",
+ version = "",
+ ),
+ target_platforms = ["linux_x86_64"],
+ whls = [
+ struct(
+ filename = "torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl",
+ sha256 = "",
+ url = "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl#sha256=5b6ae523bfb67088a17ca7734d131548a2e60346c622621e4248ed09dd0790cc",
+ yanked = False,
+ ),
+ ],
+ ),
+ ],
})
_tests.append(_test_direct_urls)
@@ -623,6 +652,36 @@
_tests.append(_test_optional_hash)
+def _test_git_sources(env):
+ got = parse_requirements(
+ ctx = _mock_ctx(),
+ requirements_by_platform = {
+ "requirements_git": ["linux_x86_64"],
+ },
+ )
+ env.expect.that_dict(got).contains_exactly({
+ "foo": [
+ struct(
+ distribution = "foo",
+ extra_pip_args = [],
+ is_exposed = True,
+ sdist = None,
+ srcs = struct(
+ marker = "",
+ requirement = "foo @ git+https://github.com/org/foo.git@deadbeef",
+ requirement_line = "foo @ git+https://github.com/org/foo.git@deadbeef",
+ shas = [],
+ url = "git+https://github.com/org/foo.git@deadbeef",
+ version = "",
+ ),
+ target_platforms = ["linux_x86_64"],
+ whls = [],
+ ),
+ ],
+ })
+
+_tests.append(_test_git_sources)
+
def parse_requirements_test_suite(name):
"""Create the test suite.