feat: add interpreter_version_info to py_runtime (#1671)

Adds an `interpreter_version_info` attribute to the `py_runtime` rule and
its associated provider that maps to the `sys.version_info` values. This
allows the version of the interpreter to be known statically, which can
be useful for rule sets that depend on the interpreter and need to
build environments / paths that contain version info (virtualenvs, for
example).

---------

Co-authored-by: Richard Levasseur <rlevasseur@google.com>
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d38a7a4..b861902 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -53,6 +53,13 @@
   method to make imports more ergonomic. Users should only need to import the
   `Runfiles` object to locate runfiles.
 
+* (toolchains) `PyRuntimeInfo` now includes an `interpreter_version_info` field
+  that contains the static version information for the given interpreter.
+  This can be set via `py_runtime` when registering an interpreter toolchain,
+  and will be done automatically for the builtin interpreter versions registered
+  via `python_register_toolchains`.
+  Note that this is only available on the Starlark implementation of the provider.
+
 ## [0.28.0] - 2024-01-07
 
 [0.28.0]: https://github.com/bazelbuild/rules_python/releases/tag/0.28.0
diff --git a/python/private/common/providers.bzl b/python/private/common/providers.bzl
index 38a7054..f36a2d1 100644
--- a/python/private/common/providers.bzl
+++ b/python/private/common/providers.bzl
@@ -45,7 +45,8 @@
         coverage_files = None,
         python_version,
         stub_shebang = None,
-        bootstrap_template = None):
+        bootstrap_template = None,
+        interpreter_version_info = None):
     if (interpreter_path and interpreter) or (not interpreter_path and not interpreter):
         fail("exactly one of interpreter or interpreter_path must be specified")
 
@@ -82,6 +83,24 @@
     if not stub_shebang:
         stub_shebang = DEFAULT_STUB_SHEBANG
 
+    if interpreter_version_info:
+        if not ("major" in interpreter_version_info and "minor" in interpreter_version_info):
+            fail("interpreter_version_info must have at least two keys, 'major' and 'minor'")
+
+        _interpreter_version_info = dict(**interpreter_version_info)
+        interpreter_version_info = struct(
+            major = int(_interpreter_version_info.pop("major")),
+            minor = int(_interpreter_version_info.pop("minor")),
+            micro = int(_interpreter_version_info.pop("micro")) if "micro" in _interpreter_version_info else None,
+            releaselevel = str(_interpreter_version_info.pop("releaselevel")) if "releaselevel" in _interpreter_version_info else None,
+            serial = int(_interpreter_version_info.pop("serial")) if "serial" in _interpreter_version_info else None,
+        )
+
+        if len(_interpreter_version_info.keys()) > 0:
+            fail("unexpected keys {} in interpreter_version_info".format(
+                str(_interpreter_version_info.keys()),
+            ))
+
     return {
         "bootstrap_template": bootstrap_template,
         "coverage_files": coverage_files,
@@ -89,6 +108,7 @@
         "files": files,
         "interpreter": interpreter,
         "interpreter_path": interpreter_path,
+        "interpreter_version_info": interpreter_version_info,
         "python_version": python_version,
         "stub_shebang": stub_shebang,
     }
@@ -136,6 +156,18 @@
             "filesystem path to the interpreter on the target platform. " +
             "Otherwise, this is `None`."
         ),
+        "interpreter_version_info": (
+            "Version information about the interpreter this runtime provides. " +
+            "It should match the format given by `sys.version_info`, however " +
+            "for simplicity, the micro, releaselevel, and serial values are " +
+            "optional. " +
+            "A struct with the following fields:\n" +
+            "  * major: int, the major version number\n" +
+            "  * minor: int, the minor version number\n" +
+            "  * micro: optional int, the micro version number\n" +
+            "  * releaselevel: optional str, the release level\n" +
+            "  * serial: optional int, the serial number of the release"
+        ),
         "python_version": (
             "Indicates whether this runtime uses Python major version 2 or 3. " +
             "Valid values are (only) `\"PY2\"` and " +
diff --git a/python/private/common/py_runtime_rule.bzl b/python/private/common/py_runtime_rule.bzl
index 9d53543..190158b 100644
--- a/python/private/common/py_runtime_rule.bzl
+++ b/python/private/common/py_runtime_rule.bzl
@@ -79,6 +79,8 @@
 
     python_version = ctx.attr.python_version
 
+    interpreter_version_info = ctx.attr.interpreter_version_info
+
     # TODO: Uncomment this after --incompatible_python_disable_py2 defaults to true
     # if ctx.fragments.py.disable_py2 and python_version == "PY2":
     #     fail("Using Python 2 is not supported and disabled; see " +
@@ -93,10 +95,16 @@
         python_version = python_version,
         stub_shebang = ctx.attr.stub_shebang,
         bootstrap_template = ctx.file.bootstrap_template,
+        interpreter_version_info = interpreter_version_info,
     )
     builtin_py_runtime_info_kwargs = dict(py_runtime_info_kwargs)
+
+    # Pop this property as it does not exist on BuiltinPyRuntimeInfo
+    builtin_py_runtime_info_kwargs.pop("interpreter_version_info")
+
     if not IS_BAZEL_7_OR_HIGHER:
         builtin_py_runtime_info_kwargs.pop("bootstrap_template")
+
     return [
         PyRuntimeInfo(**py_runtime_info_kwargs),
         # Return the builtin provider for better compatibility.
@@ -232,6 +240,19 @@
 For a platform runtime, this is the absolute path of a Python interpreter on
 the target platform. For an in-build runtime this attribute must not be set.
 """),
+        "interpreter_version_info": attr.string_dict(
+            doc = """
+Version information about the interpreter this runtime provides. The
+supported keys match the names for `sys.version_info`. While the input
+values are strings, most are converted to ints. The supported keys are:
+  * major: int, the major version number
+  * minor: int, the minor version number
+  * micro: optional int, the micro version number
+  * releaselevel: optional str, the release level
+  * serial: optional int, the serial number of the release
+            """,
+            mandatory = False,
+        ),
         "python_version": attr.string(
             default = "PY3",
             values = ["PY2", "PY3"],
diff --git a/python/repositories.bzl b/python/repositories.bzl
index 7422a50..bfba86d 100644
--- a/python/repositories.bzl
+++ b/python/repositories.bzl
@@ -102,7 +102,8 @@
 
     platform = rctx.attr.platform
     python_version = rctx.attr.python_version
-    python_short_version = python_version.rpartition(".")[0]
+    python_version_info = python_version.split(".")
+    python_short_version = "{0}.{1}".format(*python_version_info)
     release_filename = rctx.attr.release_filename
     urls = rctx.attr.urls or [rctx.attr.url]
     auth = get_auth(rctx, urls)
@@ -335,6 +336,11 @@
     files = [":files"],
 {coverage_attr}
     interpreter = "{python_path}",
+    interpreter_version_info = {{
+        "major": "{interpreter_version_info_major}",
+        "minor": "{interpreter_version_info_minor}",
+        "micro": "{interpreter_version_info_micro}",
+    }},
     python_version = "PY3",
 )
 
@@ -356,6 +362,9 @@
         python_version = python_short_version,
         python_version_nodot = python_short_version.replace(".", ""),
         coverage_attr = coverage_attr_text,
+        interpreter_version_info_major = python_version_info[0],
+        interpreter_version_info_minor = python_version_info[1],
+        interpreter_version_info_micro = python_version_info[2],
     )
     rctx.delete("python")
     rctx.symlink(python_bin, "python")
diff --git a/tests/py_runtime/py_runtime_tests.bzl b/tests/py_runtime/py_runtime_tests.bzl
index 9fa5e2a..b47923d 100644
--- a/tests/py_runtime/py_runtime_tests.bzl
+++ b/tests/py_runtime/py_runtime_tests.bzl
@@ -413,6 +413,121 @@
 
 _tests.append(_test_system_interpreter_must_be_absolute)
 
+def _interpreter_version_info_test(name, interpreter_version_info, impl, expect_failure = True):
+    if config.enable_pystar:
+        py_runtime_kwargs = {
+            "interpreter_version_info": interpreter_version_info,
+        }
+        attr_values = {}
+    else:
+        py_runtime_kwargs = {}
+        attr_values = _SKIP_TEST
+
+    rt_util.helper_target(
+        py_runtime,
+        name = name + "_subject",
+        python_version = "PY3",
+        interpreter_path = "/py",
+        **py_runtime_kwargs
+    )
+    analysis_test(
+        name = name,
+        target = name + "_subject",
+        impl = impl,
+        expect_failure = expect_failure,
+        attr_values = attr_values,
+    )
+
+def _test_interpreter_version_info_must_define_major_and_minor_only_major(name):
+    _interpreter_version_info_test(
+        name,
+        {
+            "major": "3",
+        },
+        lambda env, target: (
+            env.expect.that_target(target).failures().contains_predicate(
+                matching.str_matches("must have at least two keys, 'major' and 'minor'"),
+            )
+        ),
+    )
+
+_tests.append(_test_interpreter_version_info_must_define_major_and_minor_only_major)
+
+def _test_interpreter_version_info_must_define_major_and_minor_only_minor(name):
+    _interpreter_version_info_test(
+        name,
+        {
+            "minor": "3",
+        },
+        lambda env, target: (
+            env.expect.that_target(target).failures().contains_predicate(
+                matching.str_matches("must have at least two keys, 'major' and 'minor'"),
+            )
+        ),
+    )
+
+_tests.append(_test_interpreter_version_info_must_define_major_and_minor_only_minor)
+
+def _test_interpreter_version_info_no_extraneous_keys(name):
+    _interpreter_version_info_test(
+        name,
+        {
+            "major": "3",
+            "minor": "3",
+            "something": "foo",
+        },
+        lambda env, target: (
+            env.expect.that_target(target).failures().contains_predicate(
+                matching.str_matches("unexpected keys [\"something\"]"),
+            )
+        ),
+    )
+
+_tests.append(_test_interpreter_version_info_no_extraneous_keys)
+
+def _test_interpreter_version_info_sets_values_to_none_if_not_given(name):
+    _interpreter_version_info_test(
+        name,
+        {
+            "major": "3",
+            "micro": "10",
+            "minor": "3",
+        },
+        lambda env, target: (
+            env.expect.that_target(target).provider(
+                PyRuntimeInfo,
+                factory = py_runtime_info_subject,
+            ).interpreter_version_info().serial().equals(None)
+        ),
+        expect_failure = False,
+    )
+
+_tests.append(_test_interpreter_version_info_sets_values_to_none_if_not_given)
+
+def _test_interpreter_version_info_parses_values_to_struct(name):
+    _interpreter_version_info_test(
+        name,
+        {
+            "major": "3",
+            "micro": "10",
+            "minor": "6",
+            "releaselevel": "alpha",
+            "serial": "1",
+        },
+        impl = _test_interpreter_version_info_parses_values_to_struct_impl,
+        expect_failure = False,
+    )
+
+def _test_interpreter_version_info_parses_values_to_struct_impl(env, target):
+    version_info = env.expect.that_target(target).provider(PyRuntimeInfo, factory = py_runtime_info_subject).interpreter_version_info()
+    version_info.major().equals(3)
+    version_info.minor().equals(6)
+    version_info.micro().equals(10)
+    version_info.releaselevel().equals("alpha")
+    version_info.serial().equals(1)
+
+_tests.append(_test_interpreter_version_info_parses_values_to_struct)
+
 def py_runtime_test_suite(name):
     test_suite(
         name = name,
diff --git a/tests/py_runtime_info_subject.bzl b/tests/py_runtime_info_subject.bzl
index 219719f..541d4d9 100644
--- a/tests/py_runtime_info_subject.bzl
+++ b/tests/py_runtime_info_subject.bzl
@@ -38,6 +38,7 @@
         files = lambda *a, **k: _py_runtime_info_subject_files(self, *a, **k),
         interpreter = lambda *a, **k: _py_runtime_info_subject_interpreter(self, *a, **k),
         interpreter_path = lambda *a, **k: _py_runtime_info_subject_interpreter_path(self, *a, **k),
+        interpreter_version_info = lambda *a, **k: _py_runtime_info_subject_interpreter_version_info(self, *a, **k),
         python_version = lambda *a, **k: _py_runtime_info_subject_python_version(self, *a, **k),
         stub_shebang = lambda *a, **k: _py_runtime_info_subject_stub_shebang(self, *a, **k),
         # go/keep-sorted end
@@ -100,3 +101,16 @@
         self.actual.stub_shebang,
         meta = self.meta.derive("stub_shebang()"),
     )
+
+def _py_runtime_info_subject_interpreter_version_info(self):
+    return subjects.struct(
+        self.actual.interpreter_version_info,
+        attrs = dict(
+            major = subjects.int,
+            minor = subjects.int,
+            micro = subjects.int,
+            releaselevel = subjects.str,
+            serial = subjects.int,
+        ),
+        meta = self.meta.derive("interpreter_version_info()"),
+    )