refactor(pip_install): make the API surface of deps parsing smaller (#1657)

With this change we also fix a minor bug where the self edge
dependencies
within the wheels are correctly handled as per unit tests.

Also refactor the code to use the fact that the selects are chosen
based on the specificity of the condition. This can make the addition
of deps incrementally correct without the need for extra processing in
the `build` method. This should hopefully make the implementation more
future-proof as it makes use of these generic concepts and is likely to
help to extend this code to work with selecting on multiple Python
versions.

I have added a test to ensure coverage of the new code.
diff --git a/python/pip_install/tools/wheel_installer/wheel.py b/python/pip_install/tools/wheel_installer/wheel.py
index efd916d..a67756e 100644
--- a/python/pip_install/tools/wheel_installer/wheel.py
+++ b/python/pip_install/tools/wheel_installer/wheel.py
@@ -22,7 +22,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union
 
 import installer
 from packaging.requirements import Requirement
@@ -86,6 +86,17 @@
             )
         ]
 
+    def all_specializations(self) -> Iterator["Platform"]:
+        """Return the platform itself and all its unambiguous specializations.
+
+        For more info about specializations see
+        https://bazel.build/docs/configurable-attributes
+        """
+        yield self
+        if self.arch is None:
+            for arch in Arch:
+                yield Platform(os=self.os, arch=arch)
+
     def __lt__(self, other: Any) -> bool:
         """Add a comparison method, so that `sorted` returns the most specialized platforms first."""
         if not isinstance(other, Platform) or other is None:
@@ -228,57 +239,94 @@
     def __init__(
         self,
         name: str,
+        *,
+        requires_dist: Optional[List[str]],
         extras: Optional[Set[str]] = None,
         platforms: Optional[Set[Platform]] = None,
     ):
         self.name: str = Deps._normalize(name)
+        self._platforms: Set[Platform] = platforms or set()
+
+        # Sort so that the dictionary order in the FrozenDeps is deterministic
+        # without the final sort because Python retains insertion order. That way
+        # the sorting by platform is limited within the Platform class itself and
+        # the unit-tests for the Deps can be simpler.
+        reqs = sorted(
+            (Requirement(wheel_req) for wheel_req in requires_dist),
+            key=lambda x: f"{x.name}:{sorted(x.extras)}",
+        )
+
+        want_extras = self._resolve_extras(reqs, extras)
+
+        # Then add all of the requirements in order
         self._deps: Set[str] = set()
         self._select: Dict[Platform, Set[str]] = defaultdict(set)
-        self._want_extras: Set[str] = extras or {""}  # empty strings means no extras
-        self._platforms: Set[Platform] = platforms or set()
+        for req in reqs:
+            self._add_req(req, want_extras)
 
     def _add(self, dep: str, platform: Optional[Platform]):
         dep = Deps._normalize(dep)
 
-        # Packages may create dependency cycles when specifying optional-dependencies / 'extras'.
-        # Example: github.com/google/etils/blob/a0b71032095db14acf6b33516bca6d885fe09e35/pyproject.toml#L32.
+        # Self-edges are processed in _resolve_extras
         if dep == self.name:
             return
 
-        if platform:
-            self._select[platform].add(dep)
-        else:
+        if not platform:
             self._deps.add(dep)
+            return
+
+        # Add the platform-specific dep
+        self._select[platform].add(dep)
+
+        # Add the dep to specializations of the given platform if they
+        # exist in the select statement.
+        for p in platform.all_specializations():
+            if p not in self._select:
+                continue
+
+            self._select[p].add(dep)
+
+        if len(self._select[platform]) != 1:
+            return
+
+        # We are adding a new item to the select and we need to ensure that
+        # existing dependencies from less specialized platforms are propagated
+        # to the newly added dependency set.
+        for p, deps in self._select.items():
+            # Check if the existing platform overlaps with the given platform
+            if p == platform or platform not in p.all_specializations():
+                continue
+
+            self._select[platform].update(self._select[p])
 
     @staticmethod
     def _normalize(name: str) -> str:
         return re.sub(r"[-_.]+", "_", name).lower()
 
-    def add(self, *wheel_reqs: str) -> None:
-        reqs = [Requirement(wheel_req) for wheel_req in wheel_reqs]
-
-        # Resolve any extra extras due to self-edges
-        self._want_extras = self._resolve_extras(reqs)
-
-        # process self-edges first to resolve the extras used
-        for req in reqs:
-            self._add_req(req)
-
-    def _resolve_extras(self, reqs: List[Requirement]) -> Set[str]:
+    def _resolve_extras(
+        self, reqs: List[Requirement], extras: Optional[Set[str]]
+    ) -> Set[str]:
         """Resolve extras which are due to depending on self[some_other_extra].
 
         Some packages may have cyclic dependencies resulting from extras being used, one example is
-        `elint`, where we have one set of extras as aliases for other extras
+        `etils`, where we have one set of extras as aliases for other extras
         and we have an extra called 'all' that includes all other extras.
 
+        Example: github.com/google/etils/blob/a0b71032095db14acf6b33516bca6d885fe09e35/pyproject.toml#L32.
+
         When the `requirements.txt` is generated by `pip-tools`, then it is likely that
         this step is not needed, but for other `requirements.txt` files this may be useful.
 
-        NOTE @aignas 2023-12-08: the extra resolution is not platform dependent, but
-        in order for it to become platform dependent we would have to have separate targets for each extra in
-        self._want_extras.
+        NOTE @aignas 2023-12-08: the extra resolution is not platform dependent,
+        but in order for it to become platform dependent we would have to have
+        separate targets for each extra in extras.
         """
-        extras = self._want_extras
+
+        # Resolve any extra extras due to self-edges, empty string means no
+        # extras The empty string in the set is just a way to make the handling
+        # of no extras and a single extra easier and having a set of {"", "foo"}
+        # is equivalent to having {"foo"}.
+        extras = extras or {""}
 
         self_reqs = []
         for req in reqs:
@@ -311,29 +359,34 @@
 
         return extras
 
-    def _add_req(self, req: Requirement) -> None:
-        extras = self._want_extras
-
+    def _add_req(self, req: Requirement, extras: Set[str]) -> None:
         if req.marker is None:
             self._add(req.name, None)
             return
 
         marker_str = str(req.marker)
 
+        if not self._platforms:
+            if any(req.marker.evaluate({"extra": extra}) for extra in extras):
+                self._add(req.name, None)
+            return
+
         # NOTE @aignas 2023-12-08: in order to have reasonable select statements
         # we do have to have some parsing of the markers, so it begs the question
         # if packaging should be reimplemented in Starlark to have the best solution
         # for now we will implement it in Python and see what the best parsing result
         # can be before making this decision.
-        if not self._platforms or not any(
+        match_os = any(
             tag in marker_str
             for tag in [
                 "os_name",
                 "sys_platform",
-                "platform_machine",
                 "platform_system",
             ]
-        ):
+        )
+        match_arch = "platform_machine" in marker_str
+
+        if not (match_os or match_arch):
             if any(req.marker.evaluate({"extra": extra}) for extra in extras):
                 self._add(req.name, None)
             return
@@ -344,34 +397,15 @@
             ):
                 continue
 
-            if "platform_machine" in marker_str:
+            if match_arch:
                 self._add(req.name, plat)
             else:
                 self._add(req.name, Platform(plat.os))
 
     def build(self) -> FrozenDeps:
-        if not self._select:
-            return FrozenDeps(
-                deps=sorted(self._deps),
-                deps_select={},
-            )
-
-        # Get all of the OS-specific dependencies applicable to all architectures
-        select = {
-            p: deps for p, deps in self._select.items() if deps and p.arch is None
-        }
-        # Now add them to all arch specific dependencies
-        select.update(
-            {
-                p: deps | select.get(Platform(p.os), set())
-                for p, deps in self._select.items()
-                if deps and p.arch is not None
-            }
-        )
-
         return FrozenDeps(
             deps=sorted(self._deps),
-            deps_select={str(p): sorted(deps) for p, deps in sorted(select.items())},
+            deps_select={str(p): sorted(deps) for p, deps in self._select.items()},
         )
 
 
@@ -429,15 +463,12 @@
         extras_requested: Set[str] = None,
         platforms: Optional[Set[Platform]] = None,
     ) -> FrozenDeps:
-        dependency_set = Deps(
+        return Deps(
             self.name,
             extras=extras_requested,
             platforms=platforms,
-        )
-        for wheel_req in self.metadata.get_all("Requires-Dist", []):
-            dependency_set.add(wheel_req)
-
-        return dependency_set.build()
+            requires_dist=self.metadata.get_all("Requires-Dist", []),
+        ).build()
 
     def unzip(self, directory: str) -> None:
         installation_schemes = {
diff --git a/python/pip_install/tools/wheel_installer/wheel_test.py b/python/pip_install/tools/wheel_installer/wheel_test.py
index 721b710..5f4b8c8 100644
--- a/python/pip_install/tools/wheel_installer/wheel_test.py
+++ b/python/pip_install/tools/wheel_installer/wheel_test.py
@@ -5,8 +5,7 @@
 
 class DepsTest(unittest.TestCase):
     def test_simple(self):
-        deps = wheel.Deps("foo")
-        deps.add("bar")
+        deps = wheel.Deps("foo", requires_dist=["bar"])
 
         got = deps.build()
 
@@ -18,13 +17,18 @@
         platforms = {
             "linux_x86_64",
             "osx_x86_64",
+            "osx_aarch64",
             "windows_x86_64",
         }
-        deps = wheel.Deps("foo", platforms=set(wheel.Platform.from_string(platforms)))
-        deps.add(
-            "bar",
-            "posix_dep; os_name=='posix'",
-            "win_dep; os_name=='nt'",
+        deps = wheel.Deps(
+            "foo",
+            requires_dist=[
+                "bar",
+                "an_osx_dep; sys_platform=='darwin'",
+                "posix_dep; os_name=='posix'",
+                "win_dep; os_name=='nt'",
+            ],
+            platforms=set(wheel.Platform.from_string(platforms)),
         )
 
         got = deps.build()
@@ -33,36 +37,56 @@
         self.assertEqual(
             {
                 "@platforms//os:linux": ["posix_dep"],
-                "@platforms//os:osx": ["posix_dep"],
+                "@platforms//os:osx": ["an_osx_dep", "posix_dep"],
                 "@platforms//os:windows": ["win_dep"],
             },
             got.deps_select,
         )
 
-    def test_can_add_platform_specific_deps(self):
+    def test_deps_are_added_to_more_specialized_platforms(self):
         platforms = {
-            "linux_x86_64",
             "osx_x86_64",
             "osx_aarch64",
-            "windows_x86_64",
         }
-        deps = wheel.Deps("foo", platforms=set(wheel.Platform.from_string(platforms)))
-        deps.add(
-            "bar",
-            "posix_dep; os_name=='posix'",
-            "m1_dep; sys_platform=='darwin' and platform_machine=='arm64'",
-            "win_dep; os_name=='nt'",
+        got = wheel.Deps(
+            "foo",
+            requires_dist=[
+                "m1_dep; sys_platform=='darwin' and platform_machine=='arm64'",
+                "mac_dep; sys_platform=='darwin'",
+            ],
+            platforms=set(wheel.Platform.from_string(platforms)),
+        ).build()
+
+        self.assertEqual(
+            wheel.FrozenDeps(
+                deps=[],
+                deps_select={
+                    "osx_aarch64": ["m1_dep", "mac_dep"],
+                    "@platforms//os:osx": ["mac_dep"],
+                },
+            ),
+            got,
         )
 
-        got = deps.build()
+    def test_deps_from_more_specialized_platforms_are_propagated(self):
+        platforms = {
+            "osx_x86_64",
+            "osx_aarch64",
+        }
+        got = wheel.Deps(
+            "foo",
+            requires_dist=[
+                "a_mac_dep; sys_platform=='darwin'",
+                "m1_dep; sys_platform=='darwin' and platform_machine=='arm64'",
+            ],
+            platforms=set(wheel.Platform.from_string(platforms)),
+        ).build()
 
-        self.assertEqual(["bar"], got.deps)
+        self.assertEqual([], got.deps)
         self.assertEqual(
             {
-                "osx_aarch64": ["m1_dep", "posix_dep"],
-                "@platforms//os:linux": ["posix_dep"],
-                "@platforms//os:osx": ["posix_dep"],
-                "@platforms//os:windows": ["win_dep"],
+                "osx_aarch64": ["a_mac_dep", "m1_dep"],
+                "@platforms//os:osx": ["a_mac_dep"],
             },
             got.deps_select,
         )
@@ -74,14 +98,15 @@
             "osx_aarch64",
             "windows_x86_64",
         }
-        deps = wheel.Deps("foo", platforms=set(wheel.Platform.from_string(platforms)))
-        deps.add(
-            "bar",
-            "baz; implementation_name=='cpython'",
-            "m1_dep; sys_platform=='darwin' and platform_machine=='arm64'",
-        )
-
-        got = deps.build()
+        got = wheel.Deps(
+            "foo",
+            requires_dist=[
+                "bar",
+                "baz; implementation_name=='cpython'",
+                "m1_dep; sys_platform=='darwin' and platform_machine=='arm64'",
+            ],
+            platforms=set(wheel.Platform.from_string(platforms)),
+        ).build()
 
         self.assertEqual(["bar", "baz"], got.deps)
         self.assertEqual(
@@ -92,12 +117,15 @@
         )
 
     def test_self_is_ignored(self):
-        deps = wheel.Deps("foo", extras={"ssl"})
-        deps.add(
-            "bar",
-            "req_dep; extra == 'requests'",
-            "foo[requests]; extra == 'ssl'",
-            "ssl_lib; extra == 'ssl'",
+        deps = wheel.Deps(
+            "foo",
+            requires_dist=[
+                "bar",
+                "req_dep; extra == 'requests'",
+                "foo[requests]; extra == 'ssl'",
+                "ssl_lib; extra == 'ssl'",
+            ],
+            extras={"ssl"},
         )
 
         got = deps.build()
@@ -105,96 +133,22 @@
         self.assertEqual(["bar", "req_dep", "ssl_lib"], got.deps)
         self.assertEqual({}, got.deps_select)
 
-    def test_handle_etils(self):
-        deps = wheel.Deps("etils", extras={"all"})
-        requires = """
-etils[array-types] ; extra == "all"
-etils[eapp] ; extra == "all"
-etils[ecolab] ; extra == "all"
-etils[edc] ; extra == "all"
-etils[enp] ; extra == "all"
-etils[epath] ; extra == "all"
-etils[epath-gcs] ; extra == "all"
-etils[epath-s3] ; extra == "all"
-etils[epy] ; extra == "all"
-etils[etqdm] ; extra == "all"
-etils[etree] ; extra == "all"
-etils[etree-dm] ; extra == "all"
-etils[etree-jax] ; extra == "all"
-etils[etree-tf] ; extra == "all"
-etils[enp] ; extra == "array-types"
-pytest ; extra == "dev"
-pytest-subtests ; extra == "dev"
-pytest-xdist ; extra == "dev"
-pyink ; extra == "dev"
-pylint>=2.6.0 ; extra == "dev"
-chex ; extra == "dev"
-torch ; extra == "dev"
-optree ; extra == "dev"
-dataclass_array ; extra == "dev"
-sphinx-apitree[ext] ; extra == "docs"
-etils[dev,all] ; extra == "docs"
-absl-py ; extra == "eapp"
-simple_parsing ; extra == "eapp"
-etils[epy] ; extra == "eapp"
-jupyter ; extra == "ecolab"
-numpy ; extra == "ecolab"
-mediapy ; extra == "ecolab"
-packaging ; extra == "ecolab"
-etils[enp] ; extra == "ecolab"
-etils[epy] ; extra == "ecolab"
-etils[epy] ; extra == "edc"
-numpy ; extra == "enp"
-etils[epy] ; extra == "enp"
-fsspec ; extra == "epath"
-importlib_resources ; extra == "epath"
-typing_extensions ; extra == "epath"
-zipp ; extra == "epath"
-etils[epy] ; extra == "epath"
-gcsfs ; extra == "epath-gcs"
-etils[epath] ; extra == "epath-gcs"
-s3fs ; extra == "epath-s3"
-etils[epath] ; extra == "epath-s3"
-typing_extensions ; extra == "epy"
-absl-py ; extra == "etqdm"
-tqdm ; extra == "etqdm"
-etils[epy] ; extra == "etqdm"
-etils[array_types] ; extra == "etree"
-etils[epy] ; extra == "etree"
-etils[enp] ; extra == "etree"
-etils[etqdm] ; extra == "etree"
-dm-tree ; extra == "etree-dm"
-etils[etree] ; extra == "etree-dm"
-jax[cpu] ; extra == "etree-jax"
-etils[etree] ; extra == "etree-jax"
-tensorflow ; extra == "etree-tf"
-etils[etree] ; extra == "etree-tf"
-etils[ecolab] ; extra == "lazy-imports"
-"""
-
-        deps.add(*requires.strip().split("\n"))
+    def test_self_dependencies_can_come_in_any_order(self):
+        deps = wheel.Deps(
+            "foo",
+            requires_dist=[
+                "bar",
+                "baz; extra == 'feat'",
+                "foo[feat2]; extra == 'all'",
+                "foo[feat]; extra == 'feat2'",
+                "zdep; extra == 'all'",
+            ],
+            extras={"all"},
+        )
 
         got = deps.build()
-        want = [
-            "absl_py",
-            "dm_tree",
-            "fsspec",
-            "gcsfs",
-            "importlib_resources",
-            "jax",
-            "jupyter",
-            "mediapy",
-            "numpy",
-            "packaging",
-            "s3fs",
-            "simple_parsing",
-            "tensorflow",
-            "tqdm",
-            "typing_extensions",
-            "zipp",
-        ]
 
-        self.assertEqual(want, got.deps)
+        self.assertEqual(["bar", "baz", "zdep"], got.deps)
         self.assertEqual({}, got.deps_select)