fix: use runfiles strategy to get runfiles root
diff --git a/python/runfiles/runfiles.py b/python/runfiles/runfiles.py
index ffa2473..1b3e8eb 100644
--- a/python/runfiles/runfiles.py
+++ b/python/runfiles/runfiles.py
@@ -95,6 +95,9 @@
raise TypeError()
self._runfiles_root = path
+ def _GetRunfilesDir(self) -> str:
+ return self._runfiles_root
+
def RlocationChecked(self, path: str) -> str:
# Use posixpath instead of os.path, because Bazel only creates a runfiles
# tree on Unix platforms, so `Create()` will only create a directory-based
@@ -118,7 +121,7 @@
def __init__(self, strategy: Union[_ManifestBased, _DirectoryBased]) -> None:
self._strategy = strategy
- self._python_runfiles_root = _FindPythonRunfilesRoot()
+ self._python_runfiles_root = self._strategy._GetRunfilesDir()
self._repo_mapping = _ParseRepoMapping(
strategy.RlocationChecked("_repo_mapping")
)
@@ -321,19 +324,6 @@
# Support legacy imports by defining a private symbol.
_Runfiles = Runfiles
-
-def _FindPythonRunfilesRoot() -> str:
- """Finds the root of the Python runfiles tree."""
- root = __file__
- # Walk up our own runfiles path to the root of the runfiles tree from which
- # the current file is being run. This path coincides with what the Bazel
- # Python stub sets up as sys.path[0]. Since that entry can be changed at
- # runtime, we rederive it here.
- for _ in range("rules_python/python/runfiles/runfiles.py".count("/") + 1):
- root = os.path.dirname(root)
- return root
-
-
def _ParseRepoMapping(repo_mapping_path: Optional[str]) -> Dict[Tuple[str, str], str]:
"""Parses the repository mapping manifest."""
# If the repository mapping file can't be found, that is not an error: We
diff --git a/tests/runfiles/runfiles_test.py b/tests/runfiles/runfiles_test.py
index 03350f3..c6f454f 100644
--- a/tests/runfiles/runfiles_test.py
+++ b/tests/runfiles/runfiles_test.py
@@ -527,7 +527,7 @@
expected = ""
else:
expected = "rules_python"
- r = runfiles.Create({"RUNFILES_DIR": "whatever"})
+ r = runfiles.Create({"RUNFILES_DIR": os.environ.get("RUNFILES_DIR")})
assert r is not None # mypy doesn't understand the unittest api.
self.assertEqual(r.CurrentRepository(), expected)