refactor(pypi): split out more utils and start passing 'abi_os_arch' around (#2069)

This is extra preparation needed for #2059.

Summary:
- Create `pypi_repo_utils` for more ergonomic handling of Python in repo
context.
- Split the resolution of requirements files to platforms into a
separate function
to make the testing easier. This also allows more validation that was
realized
  that there is a need for in the WIP feature PR.
- Make the code more robust about the assumption of the target platform
label.

Work towards #260, #1105, #1868.
diff --git a/python/private/pypi/BUILD.bazel b/python/private/pypi/BUILD.bazel
index 08fb725..00602b2 100644
--- a/python/private/pypi/BUILD.bazel
+++ b/python/private/pypi/BUILD.bazel
@@ -161,6 +161,8 @@
     deps = [
         ":index_sources_bzl",
         ":parse_requirements_txt_bzl",
+        ":pypi_repo_utils_bzl",
+        ":requirements_files_by_platform_bzl",
         ":whl_target_platforms_bzl",
         "//python/private:normalize_name_bzl",
     ],
@@ -228,6 +230,15 @@
 )
 
 bzl_library(
+    name = "pypi_repo_utils_bzl",
+    srcs = ["pypi_repo_utils.bzl"],
+    deps = [
+        "//python:versions_bzl",
+        "//python/private:toolchains_repo_bzl",
+    ],
+)
+
+bzl_library(
     name = "render_pkg_aliases_bzl",
     srcs = ["render_pkg_aliases.bzl"],
     deps = [
@@ -241,6 +252,14 @@
 )
 
 bzl_library(
+    name = "requirements_files_by_platform_bzl",
+    srcs = ["requirements_files_by_platform.bzl"],
+    deps = [
+        ":whl_target_platforms_bzl",
+    ],
+)
+
+bzl_library(
     name = "simpleapi_download_bzl",
     srcs = ["simpleapi_download.bzl"],
     deps = [
@@ -270,13 +289,12 @@
         ":generate_whl_library_build_bazel_bzl",
         ":parse_whl_name_bzl",
         ":patch_whl_bzl",
+        ":pypi_repo_utils_bzl",
         ":whl_target_platforms_bzl",
         "//python:repositories_bzl",
-        "//python:versions_bzl",
         "//python/private:auth_bzl",
         "//python/private:envsubst_bzl",
         "//python/private:repo_utils_bzl",
-        "//python/private:toolchains_repo_bzl",
     ],
 )
 
diff --git a/python/private/pypi/extension.bzl b/python/private/pypi/extension.bzl
index 6aafc71..d837d8d 100644
--- a/python/private/pypi/extension.bzl
+++ b/python/private/pypi/extension.bzl
@@ -26,6 +26,7 @@
 load(":parse_whl_name.bzl", "parse_whl_name")
 load(":pip_repository_attrs.bzl", "ATTRS")
 load(":render_pkg_aliases.bzl", "whl_alias")
+load(":requirements_files_by_platform.bzl", "requirements_files_by_platform")
 load(":simpleapi_download.bzl", "simpleapi_download")
 load(":whl_library.bzl", "whl_library")
 load(":whl_repo_name.bzl", "whl_repo_name")
@@ -183,12 +184,16 @@
 
     requirements_by_platform = parse_requirements(
         module_ctx,
-        requirements_by_platform = pip_attr.requirements_by_platform,
-        requirements_linux = pip_attr.requirements_linux,
-        requirements_lock = pip_attr.requirements_lock,
-        requirements_osx = pip_attr.requirements_darwin,
-        requirements_windows = pip_attr.requirements_windows,
-        extra_pip_args = pip_attr.extra_pip_args,
+        requirements_by_platform = requirements_files_by_platform(
+            requirements_by_platform = pip_attr.requirements_by_platform,
+            requirements_linux = pip_attr.requirements_linux,
+            requirements_lock = pip_attr.requirements_lock,
+            requirements_osx = pip_attr.requirements_darwin,
+            requirements_windows = pip_attr.requirements_windows,
+            extra_pip_args = pip_attr.extra_pip_args,
+            python_version = major_minor,
+            logger = logger,
+        ),
         get_index_urls = get_index_urls,
         python_version = major_minor,
         logger = logger,
@@ -298,7 +303,7 @@
 
         requirement = select_requirement(
             requirements,
-            platform = repository_platform,
+            platform = None if pip_attr.download_only else repository_platform,
         )
         if not requirement:
             # Sometimes the package is not present for host platform if there
diff --git a/python/private/pypi/parse_requirements.bzl b/python/private/pypi/parse_requirements.bzl
index d52180c..5258153 100644
--- a/python/private/pypi/parse_requirements.bzl
+++ b/python/private/pypi/parse_requirements.bzl
@@ -29,7 +29,7 @@
 load("//python/private:normalize_name.bzl", "normalize_name")
 load(":index_sources.bzl", "index_sources")
 load(":parse_requirements_txt.bzl", "parse_requirements_txt")
-load(":whl_target_platforms.bzl", "select_whls", "whl_target_platforms")
+load(":whl_target_platforms.bzl", "select_whls")
 
 # This includes the vendored _translate_cpu and _translate_os from
 # @platforms//host:extension.bzl at version 0.0.9 so that we don't
@@ -80,72 +80,10 @@
     "windows_x86_64",
 ]
 
-def _default_platforms(*, filter):
-    if not filter:
-        fail("Must specific a filter string, got: {}".format(filter))
-
-    if filter.startswith("cp3"):
-        # TODO @aignas 2024-05-23: properly handle python versions in the filter.
-        # For now we are just dropping it to ensure that we don't fail.
-        _, _, filter = filter.partition("_")
-
-    sanitized = filter.replace("*", "").replace("_", "")
-    if sanitized and not sanitized.isalnum():
-        fail("The platform filter can only contain '*', '_' and alphanumerics")
-
-    if "*" in filter:
-        prefix = filter.rstrip("*")
-        if "*" in prefix:
-            fail("The filter can only contain '*' at the end of it")
-
-        if not prefix:
-            return DEFAULT_PLATFORMS
-
-        return [p for p in DEFAULT_PLATFORMS if p.startswith(prefix)]
-    else:
-        return [p for p in DEFAULT_PLATFORMS if filter in p]
-
-def _platforms_from_args(extra_pip_args):
-    platform_values = []
-
-    for arg in extra_pip_args:
-        if platform_values and platform_values[-1] == "":
-            platform_values[-1] = arg
-            continue
-
-        if arg == "--platform":
-            platform_values.append("")
-            continue
-
-        if not arg.startswith("--platform"):
-            continue
-
-        _, _, plat = arg.partition("=")
-        if not plat:
-            _, _, plat = arg.partition(" ")
-        if plat:
-            platform_values.append(plat)
-        else:
-            platform_values.append("")
-
-    if not platform_values:
-        return []
-
-    platforms = {
-        p.target_platform: None
-        for arg in platform_values
-        for p in whl_target_platforms(arg)
-    }
-    return list(platforms.keys())
-
 def parse_requirements(
         ctx,
         *,
         requirements_by_platform = {},
-        requirements_osx = None,
-        requirements_linux = None,
-        requirements_lock = None,
-        requirements_windows = None,
         extra_pip_args = [],
         get_index_urls = None,
         python_version = None,
@@ -158,10 +96,6 @@
         requirements_by_platform (label_keyed_string_dict): a way to have
             different package versions (or different packages) for different
             os, arch combinations.
-        requirements_osx (label): The requirements file for the osx OS.
-        requirements_linux (label): The requirements file for the linux OS.
-        requirements_lock (label): The requirements file for all OSes, or used as a fallback.
-        requirements_windows (label): The requirements file for windows OS.
         extra_pip_args (string list): Extra pip arguments to perform extra validations and to
             be joined with args fined in files.
         get_index_urls: Callable[[ctx, list[str]], dict], a callable to get all
@@ -186,91 +120,11 @@
 
         The second element is extra_pip_args should be passed to `whl_library`.
     """
-    if not (
-        requirements_lock or
-        requirements_linux or
-        requirements_osx or
-        requirements_windows or
-        requirements_by_platform
-    ):
-        fail_fn(
-            "A 'requirements_lock' attribute must be specified, a platform-specific lockfiles " +
-            "via 'requirements_by_platform' or an os-specific lockfiles must be specified " +
-            "via 'requirements_*' attributes",
-        )
-        return None
-
-    platforms = _platforms_from_args(extra_pip_args)
-
-    if platforms:
-        lock_files = [
-            f
-            for f in [
-                requirements_lock,
-                requirements_linux,
-                requirements_osx,
-                requirements_windows,
-            ] + list(requirements_by_platform.keys())
-            if f
-        ]
-
-        if len(lock_files) > 1:
-            # If the --platform argument is used, check that we are using
-            # a single `requirements_lock` file instead of the OS specific ones as that is
-            # the only correct way to use the API.
-            fail_fn("only a single 'requirements_lock' file can be used when using '--platform' pip argument, consider specifying it via 'requirements_lock' attribute")
-            return None
-
-        files_by_platform = [
-            (lock_files[0], platforms),
-        ]
-    else:
-        files_by_platform = {
-            file: [
-                platform
-                for filter_or_platform in specifier.split(",")
-                for platform in (_default_platforms(filter = filter_or_platform) if filter_or_platform.endswith("*") else [filter_or_platform])
-            ]
-            for file, specifier in requirements_by_platform.items()
-        }.items()
-
-        for f in [
-            # If the users need a greater span of the platforms, they should consider
-            # using the 'requirements_by_platform' attribute.
-            (requirements_linux, _default_platforms(filter = "linux_*")),
-            (requirements_osx, _default_platforms(filter = "osx_*")),
-            (requirements_windows, _default_platforms(filter = "windows_*")),
-            (requirements_lock, None),
-        ]:
-            if f[0]:
-                files_by_platform.append(f)
-
-    configured_platforms = {}
-
     options = {}
     requirements = {}
-    for file, plats in files_by_platform:
-        if plats:
-            for p in plats:
-                if p in configured_platforms:
-                    fail_fn(
-                        "Expected the platform '{}' to be map only to a single requirements file, but got multiple: '{}', '{}'".format(
-                            p,
-                            configured_platforms[p],
-                            file,
-                        ),
-                    )
-                    return None
-                configured_platforms[p] = file
-        else:
-            plats = [
-                p
-                for p in DEFAULT_PLATFORMS
-                if p not in configured_platforms
-            ]
-            for p in plats:
-                configured_platforms[p] = file
-
+    for file, plats in requirements_by_platform.items():
+        if logger:
+            logger.debug(lambda: "Using {} for {}".format(file, plats))
         contents = ctx.read(file)
 
         # Parse the requirements file directly in starlark to get the information
@@ -303,9 +157,9 @@
                 tokenized_options.append(p)
 
         pip_args = tokenized_options + extra_pip_args
-        for p in plats:
-            requirements[p] = requirements_dict
-            options[p] = pip_args
+        for plat in plats:
+            requirements[plat] = requirements_dict
+            options[plat] = pip_args
 
     requirements_by_platform = {}
     for target_platform, reqs_ in requirements.items():
@@ -325,7 +179,6 @@
                     requirement_line = requirement_line,
                     target_platforms = [],
                     extra_pip_args = extra_pip_args,
-                    download = len(platforms) > 0,
                 ),
             )
             for_req.target_platforms.append(target_platform)
@@ -353,12 +206,12 @@
             for p in r.target_platforms:
                 requirement_target_platforms[p] = None
 
-        is_exposed = len(requirement_target_platforms) == len(configured_platforms)
+        is_exposed = len(requirement_target_platforms) == len(requirements)
         if not is_exposed and logger:
-            logger.debug(lambda: "Package {} will not be exposed because it is only present on a subset of platforms: {} out of {}".format(
+            logger.debug(lambda: "Package '{}' will not be exposed because it is only present on a subset of platforms: {} out of {}".format(
                 whl_name,
                 sorted(requirement_target_platforms),
-                sorted(configured_platforms),
+                sorted(requirements),
             ))
 
         for r in sorted(reqs.values(), key = lambda r: r.requirement_line):
@@ -376,13 +229,15 @@
                     requirement_line = r.requirement_line,
                     target_platforms = sorted(r.target_platforms),
                     extra_pip_args = r.extra_pip_args,
-                    download = r.download,
                     whls = whls,
                     sdist = sdist,
                     is_exposed = is_exposed,
                 ),
             )
 
+    if logger:
+        logger.debug(lambda: "Will configure whl repos: {}".format(ret.keys()))
+
     return ret
 
 def select_requirement(requirements, *, platform):
@@ -391,8 +246,9 @@
     Args:
         requirements (list[struct]): The list of requirements as returned by
             the `parse_requirements` function above.
-        platform (str): The host platform. Usually an output of the
-        `host_platform` function.
+        platform (str or None): The host platform. Usually an output of the
+            `host_platform` function. If None, then this function will return
+            the first requirement it finds.
 
     Returns:
         None if not found or a struct returned as one of the values in the
@@ -402,7 +258,7 @@
     maybe_requirement = [
         req
         for req in requirements
-        if platform in req.target_platforms or req.download
+        if not platform or [p for p in req.target_platforms if p.endswith(platform)]
     ]
     if not maybe_requirement:
         # Sometimes the package is not present for host platform if there
diff --git a/python/private/pypi/pip_repository.bzl b/python/private/pypi/pip_repository.bzl
index a22f4d9..42622c3 100644
--- a/python/private/pypi/pip_repository.bzl
+++ b/python/private/pypi/pip_repository.bzl
@@ -21,6 +21,7 @@
 load(":parse_requirements.bzl", "host_platform", "parse_requirements", "select_requirement")
 load(":pip_repository_attrs.bzl", "ATTRS")
 load(":render_pkg_aliases.bzl", "render_pkg_aliases", "whl_alias")
+load(":requirements_files_by_platform.bzl", "requirements_files_by_platform")
 
 def _get_python_interpreter_attr(rctx):
     """A helper function for getting the `python_interpreter` attribute or it's default
@@ -71,11 +72,14 @@
 def _pip_repository_impl(rctx):
     requirements_by_platform = parse_requirements(
         rctx,
-        requirements_by_platform = rctx.attr.requirements_by_platform,
-        requirements_linux = rctx.attr.requirements_linux,
-        requirements_lock = rctx.attr.requirements_lock,
-        requirements_osx = rctx.attr.requirements_darwin,
-        requirements_windows = rctx.attr.requirements_windows,
+        requirements_by_platform = requirements_files_by_platform(
+            requirements_by_platform = rctx.attr.requirements_by_platform,
+            requirements_linux = rctx.attr.requirements_linux,
+            requirements_lock = rctx.attr.requirements_lock,
+            requirements_osx = rctx.attr.requirements_darwin,
+            requirements_windows = rctx.attr.requirements_windows,
+            extra_pip_args = rctx.attr.extra_pip_args,
+        ),
         extra_pip_args = rctx.attr.extra_pip_args,
     )
     selected_requirements = {}
@@ -84,7 +88,7 @@
     for name, requirements in requirements_by_platform.items():
         r = select_requirement(
             requirements,
-            platform = repository_platform,
+            platform = None if rctx.attr.download_only else repository_platform,
         )
         if not r:
             continue
diff --git a/python/private/pypi/pypi_repo_utils.bzl b/python/private/pypi/pypi_repo_utils.bzl
new file mode 100644
index 0000000..6e5d93b
--- /dev/null
+++ b/python/private/pypi/pypi_repo_utils.bzl
@@ -0,0 +1,94 @@
+# Copyright 2024 The Bazel Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""
+
+load("//python:versions.bzl", "WINDOWS_NAME")
+load("//python/private:toolchains_repo.bzl", "get_host_os_arch")
+
+def _get_python_interpreter_attr(ctx, *, python_interpreter = None):
+    """A helper function for getting the `python_interpreter` attribute or it's default
+
+    Args:
+        ctx (repository_ctx): Handle to the rule repository context.
+        python_interpreter (str): The python interpreter override.
+
+    Returns:
+        str: The attribute value or it's default
+    """
+    if python_interpreter:
+        return python_interpreter
+
+    if "win" in ctx.os.name:
+        return "python.exe"
+    else:
+        return "python3"
+
+def _resolve_python_interpreter(ctx, *, python_interpreter = None, python_interpreter_target = None):
+    """Helper function to find the python interpreter from the common attributes
+
+    Args:
+        ctx: Handle to the rule repository context.
+        python_interpreter: The python interpreter to use.
+        python_interpreter_target: The python interpreter to use after downloading the label.
+
+    Returns:
+        `path` object, for the resolved path to the Python interpreter.
+    """
+    python_interpreter = _get_python_interpreter_attr(ctx, python_interpreter = python_interpreter)
+
+    if python_interpreter_target != None:
+        python_interpreter = ctx.path(python_interpreter_target)
+
+        (os, _) = get_host_os_arch(ctx)
+
+        # On Windows, the symlink doesn't work because Windows attempts to find
+        # Python DLLs where the symlink is, not where the symlink points.
+        if os == WINDOWS_NAME:
+            python_interpreter = python_interpreter.realpath
+    elif "/" not in python_interpreter:
+        # It's a plain command, e.g. "python3", to look up in the environment.
+        found_python_interpreter = ctx.which(python_interpreter)
+        if not found_python_interpreter:
+            fail("python interpreter `{}` not found in PATH".format(python_interpreter))
+        python_interpreter = found_python_interpreter
+    else:
+        python_interpreter = ctx.path(python_interpreter)
+    return python_interpreter
+
+def _construct_pypath(rctx, *, entries):
+    """Helper function to construct a PYTHONPATH.
+
+    Contains entries for code in this repo as well as packages downloaded from //python/pip_install:repositories.bzl.
+    This allows us to run python code inside repository rule implementations.
+
+    Args:
+        rctx: Handle to the repository_context.
+        entries: The list of entries to add to PYTHONPATH.
+
+    Returns: String of the PYTHONPATH.
+    """
+
+    separator = ":" if not "windows" in rctx.os.name.lower() else ";"
+    pypath = separator.join([
+        str(rctx.path(entry).dirname)
+        # Use a dict as a way to remove duplicates and then sort it.
+        for entry in sorted({x: None for x in entries})
+    ])
+    return pypath
+
+pypi_repo_utils = struct(
+    resolve_python_interpreter = _resolve_python_interpreter,
+    construct_pythonpath = _construct_pypath,
+)
diff --git a/python/private/pypi/render_pkg_aliases.bzl b/python/private/pypi/render_pkg_aliases.bzl
index eb907fe..9e5158f 100644
--- a/python/private/pypi/render_pkg_aliases.bzl
+++ b/python/private/pypi/render_pkg_aliases.bzl
@@ -265,6 +265,11 @@
         config_setting = config_setting or ("//_config:is_python_" + version)
         config_setting = str(config_setting)
 
+    if target_platforms:
+        for p in target_platforms:
+            if not p.startswith("cp"):
+                fail("target_platform should start with 'cp' denoting the python version, got: " + p)
+
     return struct(
         repo = repo,
         version = version,
@@ -448,7 +453,7 @@
             parsed = parse_whl_name(a.filename)
         else:
             for plat in a.target_platforms or []:
-                target_platforms[plat] = None
+                target_platforms[_non_versioned_platform(plat)] = None
             continue
 
         for platform_tag in parsed.platform_tag.split("."):
@@ -486,6 +491,19 @@
         if v
     }
 
+def _non_versioned_platform(p, *, strict = False):
+    """A small utility function that converts 'cp311_linux_x86_64' to 'linux_x86_64'.
+
+    This is so that we can tighten the code structure later by using strict = True.
+    """
+    has_abi = p.startswith("cp")
+    if has_abi:
+        return p.partition("_")[-1]
+    elif not strict:
+        return p
+    else:
+        fail("Expected to always have a platform in the form '{{abi}}_{{os}}_{{arch}}', got: {}".format(p))
+
 def get_filename_config_settings(
         *,
         filename,
@@ -499,7 +517,7 @@
 
     Args:
         filename: the distribution filename (can be a whl or an sdist).
-        target_platforms: list[str], target platforms in "{os}_{cpu}" format.
+        target_platforms: list[str], target platforms in "{abi}_{os}_{cpu}" format.
         glibc_versions: list[tuple[int, int]], list of versions.
         muslc_versions: list[tuple[int, int]], list of versions.
         osx_versions: list[tuple[int, int]], list of versions.
@@ -541,7 +559,7 @@
 
         if parsed.platform_tag == "any":
             prefixes = ["{}_{}_any".format(py, abi)]
-            suffixes = target_platforms
+            suffixes = [_non_versioned_platform(p) for p in target_platforms or []]
         else:
             prefixes = ["{}_{}".format(py, abi)]
             suffixes = _whl_config_setting_suffixes(
@@ -553,7 +571,7 @@
             )
     else:
         prefixes = ["sdist"]
-        suffixes = target_platforms
+        suffixes = [_non_versioned_platform(p) for p in target_platforms or []]
 
     if python_default and python_version:
         prefixes += ["cp{}_{}".format(python_version, p) for p in prefixes]
diff --git a/python/private/pypi/requirements_files_by_platform.bzl b/python/private/pypi/requirements_files_by_platform.bzl
new file mode 100644
index 0000000..e3aafc0
--- /dev/null
+++ b/python/private/pypi/requirements_files_by_platform.bzl
@@ -0,0 +1,258 @@
+# Copyright 2024 The Bazel Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Get the requirement files by platform."""
+
+load(":whl_target_platforms.bzl", "whl_target_platforms")
+
+# TODO @aignas 2024-05-13: consider using the same platform tags as are used in
+# the //python:versions.bzl
+DEFAULT_PLATFORMS = [
+    "linux_aarch64",
+    "linux_arm",
+    "linux_ppc",
+    "linux_s390x",
+    "linux_x86_64",
+    "osx_aarch64",
+    "osx_x86_64",
+    "windows_x86_64",
+]
+
+def _default_platforms(*, filter):
+    if not filter:
+        fail("Must specific a filter string, got: {}".format(filter))
+
+    if filter.startswith("cp3"):
+        # TODO @aignas 2024-05-23: properly handle python versions in the filter.
+        # For now we are just dropping it to ensure that we don't fail.
+        _, _, filter = filter.partition("_")
+
+    sanitized = filter.replace("*", "").replace("_", "")
+    if sanitized and not sanitized.isalnum():
+        fail("The platform filter can only contain '*', '_' and alphanumerics")
+
+    if "*" in filter:
+        prefix = filter.rstrip("*")
+        if "*" in prefix:
+            fail("The filter can only contain '*' at the end of it")
+
+        if not prefix:
+            return DEFAULT_PLATFORMS
+
+        return [p for p in DEFAULT_PLATFORMS if p.startswith(prefix)]
+    else:
+        return [p for p in DEFAULT_PLATFORMS if filter in p]
+
+def _platforms_from_args(extra_pip_args):
+    platform_values = []
+
+    if not extra_pip_args:
+        return platform_values
+
+    for arg in extra_pip_args:
+        if platform_values and platform_values[-1] == "":
+            platform_values[-1] = arg
+            continue
+
+        if arg == "--platform":
+            platform_values.append("")
+            continue
+
+        if not arg.startswith("--platform"):
+            continue
+
+        _, _, plat = arg.partition("=")
+        if not plat:
+            _, _, plat = arg.partition(" ")
+        if plat:
+            platform_values.append(plat)
+        else:
+            platform_values.append("")
+
+    if not platform_values:
+        return []
+
+    platforms = {
+        p.target_platform: None
+        for arg in platform_values
+        for p in whl_target_platforms(arg)
+    }
+    return list(platforms.keys())
+
+def _platform(platform_string, python_version = None):
+    if not python_version or platform_string.startswith("cp3"):
+        return platform_string
+
+    _, _, tail = python_version.partition(".")
+    minor, _, _ = tail.partition(".")
+
+    return "cp3{}_{}".format(minor, platform_string)
+
+def requirements_files_by_platform(
+        *,
+        requirements_by_platform = {},
+        requirements_osx = None,
+        requirements_linux = None,
+        requirements_lock = None,
+        requirements_windows = None,
+        extra_pip_args = None,
+        python_version = None,
+        logger = None,
+        fail_fn = fail):
+    """Resolve the requirement files by target platform.
+
+    Args:
+        requirements_by_platform (label_keyed_string_dict): a way to have
+            different package versions (or different packages) for different
+            os, arch combinations.
+        requirements_osx (label): The requirements file for the osx OS.
+        requirements_linux (label): The requirements file for the linux OS.
+        requirements_lock (label): The requirements file for all OSes, or used as a fallback.
+        requirements_windows (label): The requirements file for windows OS.
+        extra_pip_args (string list): Extra pip arguments to perform extra validations and to
+            be joined with args fined in files.
+        python_version: str or None. This is needed when the get_index_urls is
+            specified. It should be of the form "3.x.x",
+        logger: repo_utils.logger or None, a simple struct to log diagnostic messages.
+        fail_fn (Callable[[str], None]): A failure function used in testing failure cases.
+
+    Returns:
+        A dict with keys as the labels to the files and values as lists of
+        platforms that the files support.
+    """
+    if not (
+        requirements_lock or
+        requirements_linux or
+        requirements_osx or
+        requirements_windows or
+        requirements_by_platform
+    ):
+        fail_fn(
+            "A 'requirements_lock' attribute must be specified, a platform-specific lockfiles " +
+            "via 'requirements_by_platform' or an os-specific lockfiles must be specified " +
+            "via 'requirements_*' attributes",
+        )
+        return None
+
+    platforms = _platforms_from_args(extra_pip_args)
+    if logger:
+        logger.debug(lambda: "Platforms from pip args: {}".format(platforms))
+
+    if platforms:
+        lock_files = [
+            f
+            for f in [
+                requirements_lock,
+                requirements_linux,
+                requirements_osx,
+                requirements_windows,
+            ] + list(requirements_by_platform.keys())
+            if f
+        ]
+
+        if len(lock_files) > 1:
+            # If the --platform argument is used, check that we are using
+            # a single `requirements_lock` file instead of the OS specific ones as that is
+            # the only correct way to use the API.
+            fail_fn("only a single 'requirements_lock' file can be used when using '--platform' pip argument, consider specifying it via 'requirements_lock' attribute")
+            return None
+
+        files_by_platform = [
+            (lock_files[0], platforms),
+        ]
+        if logger:
+            logger.debug(lambda: "Files by platform with the platform set in the args: {}".format(files_by_platform))
+    else:
+        files_by_platform = {
+            file: [
+                platform
+                for filter_or_platform in specifier.split(",")
+                for platform in (_default_platforms(filter = filter_or_platform) if filter_or_platform.endswith("*") else [filter_or_platform])
+            ]
+            for file, specifier in requirements_by_platform.items()
+        }.items()
+
+        if logger:
+            logger.debug(lambda: "Files by platform with the platform set in the attrs: {}".format(files_by_platform))
+
+        for f in [
+            # If the users need a greater span of the platforms, they should consider
+            # using the 'requirements_by_platform' attribute.
+            (requirements_linux, _default_platforms(filter = "linux_*")),
+            (requirements_osx, _default_platforms(filter = "osx_*")),
+            (requirements_windows, _default_platforms(filter = "windows_*")),
+            (requirements_lock, None),
+        ]:
+            if f[0]:
+                if logger:
+                    logger.debug(lambda: "Adding an extra item to files_by_platform: {}".format(f))
+                files_by_platform.append(f)
+
+    configured_platforms = {}
+    requirements = {}
+    for file, plats in files_by_platform:
+        if plats:
+            plats = [_platform(p, python_version) for p in plats]
+            for p in plats:
+                if p in configured_platforms:
+                    fail_fn(
+                        "Expected the platform '{}' to be map only to a single requirements file, but got multiple: '{}', '{}'".format(
+                            p,
+                            configured_platforms[p],
+                            file,
+                        ),
+                    )
+                    return None
+
+                configured_platforms[p] = file
+        else:
+            default_platforms = [_platform(p, python_version) for p in DEFAULT_PLATFORMS]
+            plats = [
+                p
+                for p in default_platforms
+                if p not in configured_platforms
+            ]
+            if logger:
+                logger.debug(lambda: "File {} will be used for the remaining platforms {} that are not in configured_platforms: {}".format(
+                    file,
+                    plats,
+                    default_platforms,
+                ))
+            for p in plats:
+                configured_platforms[p] = file
+
+        if logger:
+            logger.debug(lambda: "Configured platforms for file {} are {}".format(file, plats))
+
+        for p in plats:
+            if p in requirements:
+                # This should never happen because in the code above we should
+                # have unambiguous selection of the requirements files.
+                fail_fn("Attempting to override a requirements file '{}' with '{}' for platform '{}'".format(
+                    requirements[p],
+                    file,
+                    p,
+                ))
+                return None
+            requirements[p] = file
+
+    # Now return a dict that is similar to requirements_by_platform - where we
+    # have labels/files as keys in the dict to minimize the number of times we
+    # may parse the same file.
+
+    ret = {}
+    for plat, file in requirements.items():
+        ret.setdefault(file, []).append(plat)
+
+    return ret
diff --git a/python/private/pypi/whl_library.bzl b/python/private/pypi/whl_library.bzl
index a3fa1d8..f453f92 100644
--- a/python/private/pypi/whl_library.bzl
+++ b/python/private/pypi/whl_library.bzl
@@ -15,88 +15,21 @@
 ""
 
 load("//python:repositories.bzl", "is_standalone_interpreter")
-load("//python:versions.bzl", "WINDOWS_NAME")
 load("//python/private:auth.bzl", "AUTH_ATTRS", "get_auth")
 load("//python/private:envsubst.bzl", "envsubst")
 load("//python/private:repo_utils.bzl", "REPO_DEBUG_ENV_VAR", "repo_utils")
-load("//python/private:toolchains_repo.bzl", "get_host_os_arch")
 load(":attrs.bzl", "ATTRS", "use_isolated")
 load(":deps.bzl", "all_repo_names")
 load(":generate_whl_library_build_bazel.bzl", "generate_whl_library_build_bazel")
 load(":parse_whl_name.bzl", "parse_whl_name")
 load(":patch_whl.bzl", "patch_whl")
+load(":pypi_repo_utils.bzl", "pypi_repo_utils")
 load(":whl_target_platforms.bzl", "whl_target_platforms")
 
 _CPPFLAGS = "CPPFLAGS"
 _COMMAND_LINE_TOOLS_PATH_SLUG = "commandlinetools"
 _WHEEL_ENTRY_POINT_PREFIX = "rules_python_wheel_entry_point"
 
-def _construct_pypath(rctx):
-    """Helper function to construct a PYTHONPATH.
-
-    Contains entries for code in this repo as well as packages downloaded from //python/pip_install:repositories.bzl.
-    This allows us to run python code inside repository rule implementations.
-
-    Args:
-        rctx: Handle to the repository_context.
-
-    Returns: String of the PYTHONPATH.
-    """
-
-    separator = ":" if not "windows" in rctx.os.name.lower() else ";"
-    pypath = separator.join([
-        str(rctx.path(entry).dirname)
-        for entry in rctx.attr._python_path_entries
-    ])
-    return pypath
-
-def _get_python_interpreter_attr(rctx):
-    """A helper function for getting the `python_interpreter` attribute or it's default
-
-    Args:
-        rctx (repository_ctx): Handle to the rule repository context.
-
-    Returns:
-        str: The attribute value or it's default
-    """
-    if rctx.attr.python_interpreter:
-        return rctx.attr.python_interpreter
-
-    if "win" in rctx.os.name:
-        return "python.exe"
-    else:
-        return "python3"
-
-def _resolve_python_interpreter(rctx):
-    """Helper function to find the python interpreter from the common attributes
-
-    Args:
-        rctx: Handle to the rule repository context.
-
-    Returns:
-        `path` object, for the resolved path to the Python interpreter.
-    """
-    python_interpreter = _get_python_interpreter_attr(rctx)
-
-    if rctx.attr.python_interpreter_target != None:
-        python_interpreter = rctx.path(rctx.attr.python_interpreter_target)
-
-        (os, _) = get_host_os_arch(rctx)
-
-        # On Windows, the symlink doesn't work because Windows attempts to find
-        # Python DLLs where the symlink is, not where the symlink points.
-        if os == WINDOWS_NAME:
-            python_interpreter = python_interpreter.realpath
-    elif "/" not in python_interpreter:
-        # It's a plain command, e.g. "python3", to look up in the environment.
-        found_python_interpreter = rctx.which(python_interpreter)
-        if not found_python_interpreter:
-            fail("python interpreter `{}` not found in PATH".format(python_interpreter))
-        python_interpreter = found_python_interpreter
-    else:
-        python_interpreter = rctx.path(python_interpreter)
-    return python_interpreter
-
 def _get_xcode_location_cflags(rctx):
     """Query the xcode sdk location to update cflags
 
@@ -230,14 +163,21 @@
     cppflags.extend(_get_toolchain_unix_cflags(rctx, python_interpreter))
 
     env = {
-        "PYTHONPATH": _construct_pypath(rctx),
+        "PYTHONPATH": pypi_repo_utils.construct_pythonpath(
+            rctx,
+            entries = rctx.attr._python_path_entries,
+        ),
         _CPPFLAGS: " ".join(cppflags),
     }
 
     return env
 
 def _whl_library_impl(rctx):
-    python_interpreter = _resolve_python_interpreter(rctx)
+    python_interpreter = pypi_repo_utils.resolve_python_interpreter(
+        rctx,
+        python_interpreter = rctx.attr.python_interpreter,
+        python_interpreter_target = rctx.attr.python_interpreter_target,
+    )
     args = [
         python_interpreter,
         "-m",
diff --git a/tests/pypi/parse_requirements/parse_requirements_tests.bzl b/tests/pypi/parse_requirements/parse_requirements_tests.bzl
index 5c33dd8..1a7143b 100644
--- a/tests/pypi/parse_requirements/parse_requirements_tests.bzl
+++ b/tests/pypi/parse_requirements/parse_requirements_tests.bzl
@@ -52,33 +52,17 @@
 
 _tests = []
 
-def _test_fail_no_requirements(env):
-    errors = []
-    parse_requirements(
-        ctx = _mock_ctx(),
-        fail_fn = errors.append,
-    )
-    env.expect.that_str(errors[0]).equals("""\
-A 'requirements_lock' attribute must be specified, a platform-specific lockfiles via 'requirements_by_platform' or an os-specific lockfiles must be specified via 'requirements_*' attributes""")
-
-_tests.append(_test_fail_no_requirements)
-
 def _test_simple(env):
     got = parse_requirements(
         ctx = _mock_ctx(),
-        requirements_lock = "requirements_lock",
-    )
-    got_alternative = parse_requirements(
-        ctx = _mock_ctx(),
         requirements_by_platform = {
-            "requirements_lock": "*",
+            "requirements_lock": ["linux_x86_64", "windows_x86_64"],
         },
     )
     env.expect.that_dict(got).contains_exactly({
         "foo": [
             struct(
                 distribution = "foo",
-                download = False,
                 extra_pip_args = [],
                 requirement_line = "foo[extra]==0.0.1 --hash=sha256:deadbeef",
                 srcs = struct(
@@ -87,13 +71,7 @@
                     version = "0.0.1",
                 ),
                 target_platforms = [
-                    "linux_aarch64",
-                    "linux_arm",
-                    "linux_ppc",
-                    "linux_s390x",
                     "linux_x86_64",
-                    "osx_aarch64",
-                    "osx_x86_64",
                     "windows_x86_64",
                 ],
                 whls = [],
@@ -102,68 +80,26 @@
             ),
         ],
     })
-    env.expect.that_dict(got).contains_exactly(got_alternative)
     env.expect.that_str(
         select_requirement(
             got["foo"],
-            platform = "linux_ppc",
+            platform = "linux_x86_64",
         ).srcs.version,
     ).equals("0.0.1")
 
 _tests.append(_test_simple)
 
-def _test_platform_markers_with_python_version(env):
-    got = parse_requirements(
-        ctx = _mock_ctx(),
-        requirements_by_platform = {
-            "requirements_lock": "cp39_linux_*",
-        },
-    )
-    got_alternative = parse_requirements(
-        ctx = _mock_ctx(),
-        requirements_by_platform = {
-            "requirements_lock": "linux_*",
-        },
-    )
-    env.expect.that_dict(got).contains_exactly({
-        "foo": [
-            struct(
-                distribution = "foo",
-                download = False,
-                extra_pip_args = [],
-                requirement_line = "foo[extra]==0.0.1 --hash=sha256:deadbeef",
-                srcs = struct(
-                    requirement = "foo[extra]==0.0.1",
-                    shas = ["deadbeef"],
-                    version = "0.0.1",
-                ),
-                target_platforms = [
-                    "linux_aarch64",
-                    "linux_arm",
-                    "linux_ppc",
-                    "linux_s390x",
-                    "linux_x86_64",
-                ],
-                whls = [],
-                sdist = None,
-                is_exposed = True,
-            ),
-        ],
-    })
-    env.expect.that_dict(got).contains_exactly(got_alternative)
-
-_tests.append(_test_platform_markers_with_python_version)
-
 def _test_dupe_requirements(env):
     got = parse_requirements(
         ctx = _mock_ctx(),
-        requirements_lock = "requirements_lock_dupe",
+        requirements_by_platform = {
+            "requirements_lock_dupe": ["linux_x86_64"],
+        },
     )
     env.expect.that_dict(got).contains_exactly({
         "foo": [
             struct(
                 distribution = "foo",
-                download = False,
                 extra_pip_args = [],
                 requirement_line = "foo[extra,extra_2]==0.0.1 --hash=sha256:deadbeef",
                 srcs = struct(
@@ -171,16 +107,7 @@
                     shas = ["deadbeef"],
                     version = "0.0.1",
                 ),
-                target_platforms = [
-                    "linux_aarch64",
-                    "linux_arm",
-                    "linux_ppc",
-                    "linux_s390x",
-                    "linux_x86_64",
-                    "osx_aarch64",
-                    "osx_x86_64",
-                    "windows_x86_64",
-                ],
+                target_platforms = ["linux_x86_64"],
                 whls = [],
                 sdist = None,
                 is_exposed = True,
@@ -193,18 +120,9 @@
 def _test_multi_os(env):
     got = parse_requirements(
         ctx = _mock_ctx(),
-        requirements_linux = "requirements_linux",
-        requirements_osx = "requirements_osx",
-        requirements_windows = "requirements_windows",
-    )
-
-    # This is an alternative way to express the same intent
-    got_alternative = parse_requirements(
-        ctx = _mock_ctx(),
         requirements_by_platform = {
-            "requirements_linux": "linux_*",
-            "requirements_osx": "osx_*",
-            "requirements_windows": "windows_*",
+            "requirements_linux": ["linux_x86_64"],
+            "requirements_windows": ["windows_x86_64"],
         },
     )
 
@@ -212,7 +130,6 @@
         "bar": [
             struct(
                 distribution = "bar",
-                download = False,
                 extra_pip_args = [],
                 requirement_line = "bar==0.0.1 --hash=sha256:deadb00f",
                 srcs = struct(
@@ -229,7 +146,6 @@
         "foo": [
             struct(
                 distribution = "foo",
-                download = False,
                 extra_pip_args = [],
                 requirement_line = "foo==0.0.3 --hash=sha256:deadbaaf",
                 srcs = struct(
@@ -237,22 +153,13 @@
                     shas = ["deadbaaf"],
                     version = "0.0.3",
                 ),
-                target_platforms = [
-                    "linux_aarch64",
-                    "linux_arm",
-                    "linux_ppc",
-                    "linux_s390x",
-                    "linux_x86_64",
-                    "osx_aarch64",
-                    "osx_x86_64",
-                ],
+                target_platforms = ["linux_x86_64"],
                 whls = [],
                 sdist = None,
                 is_exposed = True,
             ),
             struct(
                 distribution = "foo",
-                download = False,
                 extra_pip_args = [],
                 requirement_line = "foo[extra]==0.0.2 --hash=sha256:deadbeef",
                 srcs = struct(
@@ -267,7 +174,6 @@
             ),
         ],
     })
-    env.expect.that_dict(got).contains_exactly(got_alternative)
     env.expect.that_str(
         select_requirement(
             got["foo"],
@@ -277,168 +183,27 @@
 
 _tests.append(_test_multi_os)
 
-def _test_fail_duplicate_platforms(env):
-    errors = []
-    parse_requirements(
-        ctx = _mock_ctx(),
-        requirements_by_platform = {
-            "requirements_linux": "linux_x86_64",
-            "requirements_lock": "*",
-        },
-        fail_fn = errors.append,
-    )
-    env.expect.that_collection(errors).has_size(1)
-    env.expect.that_str(",".join(errors)).equals("Expected the platform 'linux_x86_64' to be map only to a single requirements file, but got multiple: 'requirements_linux', 'requirements_lock'")
-
-_tests.append(_test_fail_duplicate_platforms)
-
-def _test_multi_os_download_only_platform(env):
-    got = parse_requirements(
-        ctx = _mock_ctx(),
-        requirements_lock = "requirements_linux",
-        extra_pip_args = [
-            "--platform",
-            "manylinux_2_27_x86_64",
-            "--platform=manylinux_2_12_x86_64",
-            "--platform manylinux_2_5_x86_64",
-        ],
-    )
-    env.expect.that_dict(got).contains_exactly({
-        "foo": [
+def _test_select_requirement_none_platform(env):
+    got = select_requirement(
+        [
             struct(
-                distribution = "foo",
-                download = True,
-                extra_pip_args = [
-                    "--platform",
-                    "manylinux_2_27_x86_64",
-                    "--platform=manylinux_2_12_x86_64",
-                    "--platform manylinux_2_5_x86_64",
-                ],
-                requirement_line = "foo==0.0.3 --hash=sha256:deadbaaf",
-                srcs = struct(
-                    requirement = "foo==0.0.3",
-                    shas = ["deadbaaf"],
-                    version = "0.0.3",
-                ),
+                some_attr = "foo",
                 target_platforms = ["linux_x86_64"],
-                whls = [],
-                sdist = None,
-                is_exposed = True,
             ),
         ],
-    })
-    env.expect.that_str(
-        select_requirement(
-            got["foo"],
-            platform = "windows_x86_64",
-        ).srcs.version,
-    ).equals("0.0.3")
-
-_tests.append(_test_multi_os_download_only_platform)
-
-def _test_fail_download_only_bad_attr(env):
-    errors = []
-    parse_requirements(
-        ctx = _mock_ctx(),
-        requirements_linux = "requirements_linux",
-        requirements_osx = "requirements_osx",
-        extra_pip_args = [
-            "--platform",
-            "manylinux_2_27_x86_64",
-            "--platform=manylinux_2_12_x86_64",
-            "--platform manylinux_2_5_x86_64",
-        ],
-        fail_fn = errors.append,
+        platform = None,
     )
-    env.expect.that_str(errors[0]).equals("only a single 'requirements_lock' file can be used when using '--platform' pip argument, consider specifying it via 'requirements_lock' attribute")
+    env.expect.that_str(got.some_attr).equals("foo")
 
-_tests.append(_test_fail_download_only_bad_attr)
-
-def _test_os_arch_requirements_with_default(env):
-    got = parse_requirements(
-        ctx = _mock_ctx(),
-        requirements_by_platform = {
-            "requirements_direct": "linux_super_exotic",
-            "requirements_linux": "linux_x86_64,linux_aarch64",
-        },
-        requirements_lock = "requirements_lock",
-    )
-    env.expect.that_dict(got).contains_exactly({
-        "foo": [
-            struct(
-                distribution = "foo",
-                download = False,
-                extra_pip_args = [],
-                requirement_line = "foo==0.0.3 --hash=sha256:deadbaaf",
-                srcs = struct(
-                    requirement = "foo==0.0.3",
-                    shas = ["deadbaaf"],
-                    version = "0.0.3",
-                ),
-                target_platforms = ["linux_aarch64", "linux_x86_64"],
-                whls = [],
-                sdist = None,
-                is_exposed = True,
-            ),
-            struct(
-                distribution = "foo",
-                download = False,
-                extra_pip_args = [],
-                requirement_line = "foo[extra] @ https://some-url",
-                srcs = struct(
-                    requirement = "foo[extra] @ https://some-url",
-                    shas = [],
-                    version = "",
-                ),
-                target_platforms = ["linux_super_exotic"],
-                whls = [],
-                sdist = None,
-                is_exposed = True,
-            ),
-            struct(
-                distribution = "foo",
-                download = False,
-                extra_pip_args = [],
-                requirement_line = "foo[extra]==0.0.1 --hash=sha256:deadbeef",
-                srcs = struct(
-                    requirement = "foo[extra]==0.0.1",
-                    shas = ["deadbeef"],
-                    version = "0.0.1",
-                ),
-                target_platforms = [
-                    "linux_arm",
-                    "linux_ppc",
-                    "linux_s390x",
-                    "osx_aarch64",
-                    "osx_x86_64",
-                    "windows_x86_64",
-                ],
-                whls = [],
-                sdist = None,
-                is_exposed = True,
-            ),
-        ],
-    })
-    env.expect.that_str(
-        select_requirement(
-            got["foo"],
-            platform = "windows_x86_64",
-        ).srcs.version,
-    ).equals("0.0.1")
-    env.expect.that_str(
-        select_requirement(
-            got["foo"],
-            platform = "linux_x86_64",
-        ).srcs.version,
-    ).equals("0.0.3")
-
-_tests.append(_test_os_arch_requirements_with_default)
+_tests.append(_test_select_requirement_none_platform)
 
 def _test_fail_no_python_version(env):
     errors = []
     parse_requirements(
         ctx = _mock_ctx(),
-        requirements_lock = "requirements_lock",
+        requirements_by_platform = {
+            "requirements_lock": [""],
+        },
         get_index_urls = lambda _, __: {},
         fail_fn = errors.append,
     )
diff --git a/tests/pypi/render_pkg_aliases/render_pkg_aliases_test.bzl b/tests/pypi/render_pkg_aliases/render_pkg_aliases_test.bzl
index 0d4c75e..09a0631 100644
--- a/tests/pypi/render_pkg_aliases/render_pkg_aliases_test.bzl
+++ b/tests/pypi/render_pkg_aliases/render_pkg_aliases_test.bzl
@@ -517,7 +517,7 @@
 
 _tests.append(_test_get_python_versions_from_filenames)
 
-def _test_target_platforms_from_alias_target_platforms(env):
+def _test_get_flag_versions_from_alias_target_platforms(env):
     got = get_whl_flag_versions(
         aliases = [
             whl_alias(
@@ -534,7 +534,7 @@
                 version = "3.3",
                 filename = "foo-0.0.0-py3-none-any.whl",
                 target_platforms = [
-                    "linux_x86_64",
+                    "cp33_linux_x86_64",
                 ],
             ),
         ],
@@ -548,7 +548,7 @@
     }
     env.expect.that_dict(got).contains_exactly(want)
 
-_tests.append(_test_target_platforms_from_alias_target_platforms)
+_tests.append(_test_get_flag_versions_from_alias_target_platforms)
 
 def _test_config_settings(
         env,
@@ -820,8 +820,8 @@
             filename = "foo-0.0.2-py3-none-any.whl",
             version = "3.1",
             target_platforms = [
-                "linux_x86_64",
-                "linux_aarch64",
+                "cp31_linux_x86_64",
+                "cp31_linux_aarch64",
             ],
         ),
     ]
diff --git a/tests/pypi/requirements_files_by_platform/BUILD.bazel b/tests/pypi/requirements_files_by_platform/BUILD.bazel
new file mode 100644
index 0000000..d78d459
--- /dev/null
+++ b/tests/pypi/requirements_files_by_platform/BUILD.bazel
@@ -0,0 +1,3 @@
+load(":requirements_files_by_platform_tests.bzl", "requirements_files_by_platform_test_suite")
+
+requirements_files_by_platform_test_suite(name = "requirements_files_by_platform_tests")
diff --git a/tests/pypi/requirements_files_by_platform/requirements_files_by_platform_tests.bzl b/tests/pypi/requirements_files_by_platform/requirements_files_by_platform_tests.bzl
new file mode 100644
index 0000000..b729b0e
--- /dev/null
+++ b/tests/pypi/requirements_files_by_platform/requirements_files_by_platform_tests.bzl
@@ -0,0 +1,205 @@
+# Copyright 2024 The Bazel Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""
+
+load("@rules_testing//lib:test_suite.bzl", "test_suite")
+load("//python/private/pypi:requirements_files_by_platform.bzl", "requirements_files_by_platform")  # buildifier: disable=bzl-visibility
+
+_tests = []
+
+def _test_fail_no_requirements(env):
+    errors = []
+    requirements_files_by_platform(
+        fail_fn = errors.append,
+    )
+    env.expect.that_str(errors[0]).equals("""\
+A 'requirements_lock' attribute must be specified, a platform-specific lockfiles via 'requirements_by_platform' or an os-specific lockfiles must be specified via 'requirements_*' attributes""")
+
+_tests.append(_test_fail_no_requirements)
+
+def _test_fail_duplicate_platforms(env):
+    errors = []
+    requirements_files_by_platform(
+        requirements_by_platform = {
+            "requirements_linux": "linux_x86_64",
+            "requirements_lock": "*",
+        },
+        fail_fn = errors.append,
+    )
+    env.expect.that_collection(errors).has_size(1)
+    env.expect.that_str(",".join(errors)).equals("Expected the platform 'linux_x86_64' to be map only to a single requirements file, but got multiple: 'requirements_linux', 'requirements_lock'")
+
+_tests.append(_test_fail_duplicate_platforms)
+
+def _test_fail_download_only_bad_attr(env):
+    errors = []
+    requirements_files_by_platform(
+        requirements_linux = "requirements_linux",
+        requirements_osx = "requirements_osx",
+        extra_pip_args = [
+            "--platform",
+            "manylinux_2_27_x86_64",
+            "--platform=manylinux_2_12_x86_64",
+            "--platform manylinux_2_5_x86_64",
+        ],
+        fail_fn = errors.append,
+    )
+    env.expect.that_str(errors[0]).equals("only a single 'requirements_lock' file can be used when using '--platform' pip argument, consider specifying it via 'requirements_lock' attribute")
+
+_tests.append(_test_fail_download_only_bad_attr)
+
+def _test_simple(env):
+    for got in [
+        requirements_files_by_platform(
+            requirements_lock = "requirements_lock",
+        ),
+        requirements_files_by_platform(
+            requirements_by_platform = {
+                "requirements_lock": "*",
+            },
+        ),
+    ]:
+        env.expect.that_dict(got).contains_exactly({
+            "requirements_lock": [
+                "linux_aarch64",
+                "linux_arm",
+                "linux_ppc",
+                "linux_s390x",
+                "linux_x86_64",
+                "osx_aarch64",
+                "osx_x86_64",
+                "windows_x86_64",
+            ],
+        })
+
+_tests.append(_test_simple)
+
+def _test_simple_with_python_version(env):
+    for got in [
+        requirements_files_by_platform(
+            requirements_lock = "requirements_lock",
+            python_version = "3.11",
+        ),
+        requirements_files_by_platform(
+            requirements_by_platform = {
+                "requirements_lock": "*",
+            },
+            python_version = "3.11",
+        ),
+        # TODO @aignas 2024-07-15: consider supporting this way of specifying
+        # the requirements without the need of the `python_version` attribute
+        # setting. However, this might need more tweaks, hence only leaving a
+        # comment in the test.
+        # requirements_files_by_platform(
+        #     requirements_by_platform = {
+        #         "requirements_lock": "cp311_*",
+        #     },
+        # ),
+    ]:
+        env.expect.that_dict(got).contains_exactly({
+            "requirements_lock": [
+                "cp311_linux_aarch64",
+                "cp311_linux_arm",
+                "cp311_linux_ppc",
+                "cp311_linux_s390x",
+                "cp311_linux_x86_64",
+                "cp311_osx_aarch64",
+                "cp311_osx_x86_64",
+                "cp311_windows_x86_64",
+            ],
+        })
+
+_tests.append(_test_simple_with_python_version)
+
+def _test_multi_os(env):
+    for got in [
+        requirements_files_by_platform(
+            requirements_linux = "requirements_linux",
+            requirements_osx = "requirements_osx",
+            requirements_windows = "requirements_windows",
+        ),
+        requirements_files_by_platform(
+            requirements_by_platform = {
+                "requirements_linux": "linux_*",
+                "requirements_osx": "osx_*",
+                "requirements_windows": "windows_*",
+            },
+        ),
+    ]:
+        env.expect.that_dict(got).contains_exactly({
+            "requirements_linux": [
+                "linux_aarch64",
+                "linux_arm",
+                "linux_ppc",
+                "linux_s390x",
+                "linux_x86_64",
+            ],
+            "requirements_osx": [
+                "osx_aarch64",
+                "osx_x86_64",
+            ],
+            "requirements_windows": [
+                "windows_x86_64",
+            ],
+        })
+
+_tests.append(_test_multi_os)
+
+def _test_multi_os_download_only_platform(env):
+    got = requirements_files_by_platform(
+        requirements_lock = "requirements_linux",
+        extra_pip_args = [
+            "--platform",
+            "manylinux_2_27_x86_64",
+            "--platform=manylinux_2_12_x86_64",
+            "--platform manylinux_2_5_x86_64",
+        ],
+    )
+    env.expect.that_dict(got).contains_exactly({
+        "requirements_linux": ["linux_x86_64"],
+    })
+
+_tests.append(_test_multi_os_download_only_platform)
+
+def _test_os_arch_requirements_with_default(env):
+    got = requirements_files_by_platform(
+        requirements_by_platform = {
+            "requirements_exotic": "linux_super_exotic",
+            "requirements_linux": "linux_x86_64,linux_aarch64",
+        },
+        requirements_lock = "requirements_lock",
+    )
+    env.expect.that_dict(got).contains_exactly({
+        "requirements_exotic": ["linux_super_exotic"],
+        "requirements_linux": ["linux_x86_64", "linux_aarch64"],
+        "requirements_lock": [
+            "linux_arm",
+            "linux_ppc",
+            "linux_s390x",
+            "osx_aarch64",
+            "osx_x86_64",
+            "windows_x86_64",
+        ],
+    })
+
+_tests.append(_test_os_arch_requirements_with_default)
+
+def requirements_files_by_platform_test_suite(name):
+    """Create the test suite.
+
+    Args:
+        name: the name of the test suite
+    """
+    test_suite(name = name, basic_tests = _tests)