chore: use python.defaults to set rules_python default python version (#3301)

python.defaults is the modern way to set the default.

Setting it this way also helps avoid a bug where if a root module has a
single `python.toolchain()` call (which are implicitly treated as
`is_default=True`) and also sets the default using `python.defaults()`,
some validation logic gives an error about using both ways to set a
default.
diff --git a/MODULE.bazel b/MODULE.bazel
index 6251ed4..a29a898 100644
--- a/MODULE.bazel
+++ b/MODULE.bazel
@@ -42,8 +42,12 @@
 # NOTE: This is not a stable version. It is provided for convenience, but will
 # change frequently to track the most recent Python version.
 # NOTE: The root module can override this.
+# NOTE: There must be a corresponding `python.toolchain()` call for the version
+# specified here.
+python.defaults(
+    python_version = "3.11",
+)
 python.toolchain(
-    is_default = True,
     python_version = "3.11",
 )
 use_repo(
diff --git a/python/private/python.bzl b/python/private/python.bzl
index faad53f..a1fe80e 100644
--- a/python/private/python.bzl
+++ b/python/private/python.bzl
@@ -85,46 +85,7 @@
 
     config = _get_toolchain_config(modules = module_ctx.modules, _fail = _fail)
 
-    default_python_version = None
-    for mod in module_ctx.modules:
-        defaults_attr_structs = _create_defaults_attr_structs(mod = mod)
-        default_python_version_env = None
-        default_python_version_file = None
-
-        # Only the root module and rules_python are allowed to specify the default
-        # toolchain for a couple reasons:
-        # * It prevents submodules from specifying different defaults and only
-        #   one of them winning.
-        # * rules_python needs to set a soft default in case the root module doesn't,
-        #   e.g. if the root module doesn't use Python itself.
-        # * The root module is allowed to override the rules_python default.
-        if mod.is_root or (mod.name == "rules_python" and not default_python_version):
-            for defaults_attr in defaults_attr_structs:
-                default_python_version = _one_or_the_same(
-                    default_python_version,
-                    defaults_attr.python_version,
-                    onerror = _fail_multiple_defaults_python_version,
-                )
-                default_python_version_env = _one_or_the_same(
-                    default_python_version_env,
-                    defaults_attr.python_version_env,
-                    onerror = _fail_multiple_defaults_python_version_env,
-                )
-                default_python_version_file = _one_or_the_same(
-                    default_python_version_file,
-                    defaults_attr.python_version_file,
-                    onerror = _fail_multiple_defaults_python_version_file,
-                )
-            if default_python_version_file:
-                default_python_version = _one_or_the_same(
-                    default_python_version,
-                    module_ctx.read(default_python_version_file, watch = "yes").strip(),
-                )
-            if default_python_version_env:
-                default_python_version = module_ctx.getenv(
-                    default_python_version_env,
-                    default_python_version,
-                )
+    default_python_version = _compute_default_python_version(module_ctx)
 
     seen_versions = {}
     for mod in module_ctx.modules:
@@ -152,13 +113,7 @@
                 # * rules_python needs to set a soft default in case the root module doesn't,
                 #   e.g. if the root module doesn't use Python itself.
                 # * The root module is allowed to override the rules_python default.
-                if default_python_version:
-                    is_default = default_python_version == toolchain_version
-                    if toolchain_attr.is_default and not is_default:
-                        fail("The 'is_default' attribute doesn't work if you set " +
-                             "the default Python version with the `defaults` tag.")
-                else:
-                    is_default = toolchain_attr.is_default
+                is_default = default_python_version == toolchain_version
 
                 # Also only the root module should be able to decide ignore_root_user_error.
                 # Modules being depended upon don't know the final environment, so they aren't
@@ -169,15 +124,15 @@
                     fail("Toolchains in the root module must have consistent 'ignore_root_user_error' attributes")
 
                 ignore_root_user_error = toolchain_attr.ignore_root_user_error
-            elif mod.name == "rules_python" and not default_toolchain and not default_python_version:
-                # We don't do the len() check because we want the default that rules_python
-                # sets to be clearly visible.
-                is_default = toolchain_attr.is_default
+            elif mod.name == "rules_python" and not default_toolchain:
+                # This branch handles when the root module doesn't declare a
+                # Python toolchain
+                is_default = default_python_version == toolchain_version
             else:
                 is_default = False
 
             if is_default and default_toolchain != None:
-                _fail_multiple_default_toolchains(
+                _fail_multiple_default_toolchains_chosen(
                     first = default_toolchain.name,
                     second = toolchain_name,
                 )
@@ -577,14 +532,24 @@
         second = second,
     ))
 
-def _fail_multiple_default_toolchains(first, second):
+def _fail_multiple_default_toolchains_chosen(first, second):
     fail(("Multiple default toolchains: only one toolchain " +
-          "can have is_default=True. First default " +
+          "can be chosen as a default. First default " +
           "was toolchain '{first}'. Second was '{second}'").format(
         first = first,
         second = second,
     ))
 
+def _fail_multiple_default_toolchains_in_module(mod, toolchain_attrs):
+    fail(("Multiple default toolchains: only one toolchain " +
+          "can have is_default=True.\n" +
+          "Module '{module}' contains {count} toolchains with " +
+          "is_default=True: {versions}").format(
+        module = mod.name,
+        count = len(toolchain_attrs),
+        versions = ", ".join(sorted([v.python_version for v in toolchain_attrs])),
+    ))
+
 def _validate_version(version_str, *, _fail = fail):
     v = version.parse(version_str, strict = True, _fail = _fail)
     if v == None:
@@ -880,6 +845,72 @@
         register_all_versions = register_all_versions,
     )
 
+def _compute_default_python_version(mctx):
+    default_python_version = None
+    for mod in mctx.modules:
+        # Only the root module and rules_python are allowed to specify the default
+        # toolchain for a couple reasons:
+        # * It prevents submodules from specifying different defaults and only
+        #   one of them winning.
+        # * rules_python needs to set a soft default in case the root module doesn't,
+        #   e.g. if the root module doesn't use Python itself.
+        # * The root module is allowed to override the rules_python default.
+        if not (mod.is_root or mod.name == "rules_python"):
+            continue
+
+        defaults_attr_structs = _create_defaults_attr_structs(mod = mod)
+        default_python_version_env = None
+        default_python_version_file = None
+
+        for defaults_attr in defaults_attr_structs:
+            default_python_version = _one_or_the_same(
+                default_python_version,
+                defaults_attr.python_version,
+                onerror = _fail_multiple_defaults_python_version,
+            )
+            default_python_version_env = _one_or_the_same(
+                default_python_version_env,
+                defaults_attr.python_version_env,
+                onerror = _fail_multiple_defaults_python_version_env,
+            )
+            default_python_version_file = _one_or_the_same(
+                default_python_version_file,
+                defaults_attr.python_version_file,
+                onerror = _fail_multiple_defaults_python_version_file,
+            )
+        if default_python_version_file:
+            default_python_version = _one_or_the_same(
+                default_python_version,
+                mctx.read(default_python_version_file, watch = "yes").strip(),
+            )
+        if default_python_version_env:
+            default_python_version = mctx.getenv(
+                default_python_version_env,
+                default_python_version,
+            )
+
+        if default_python_version:
+            break
+
+        # Otherwise, look at legacy python.toolchain() calls for a default
+        toolchain_attrs = mod.tags.toolchain
+
+        # Convenience: if one python.toolchain() call exists, treat it as
+        # the default.
+        if len(toolchain_attrs) == 1:
+            default_python_version = toolchain_attrs[0].python_version
+        else:
+            sets_default = [v for v in toolchain_attrs if v.is_default]
+            if len(sets_default) == 1:
+                default_python_version = sets_default[0].python_version
+            elif len(sets_default) > 1:
+                _fail_multiple_default_toolchains_in_module(mod, toolchain_attrs)
+
+        if default_python_version:
+            break
+
+    return default_python_version
+
 def _create_defaults_attr_structs(*, mod):
     arg_structs = []
 
@@ -915,7 +946,11 @@
 
     return arg_structs
 
-def _create_toolchain_attrs_struct(*, tag = None, python_version = None, toolchain_tag_count = None):
+def _create_toolchain_attrs_struct(
+        *,
+        tag = None,
+        python_version = None,
+        toolchain_tag_count = None):
     if tag and python_version:
         fail("Only one of tag and python version can be specified")
     if tag:
diff --git a/tests/python/python_tests.bzl b/tests/python/python_tests.bzl
index 9081a0e..96d78d1 100644
--- a/tests/python/python_tests.bzl
+++ b/tests/python/python_tests.bzl
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-""
+"""Unit tests for //python/extensions:python.bzl bzlmod extension."""
 
 load("@pythons_hub//:versions.bzl", "MINOR_MAPPING")
 load("@rules_testing//lib:test_suite.bzl", "test_suite")
@@ -43,7 +43,8 @@
         ],
     )
 
-def _mod(*, name, defaults = [], toolchain = [], override = [], single_version_override = [], single_version_platform_override = [], is_root = True):
+# todo: change is_root to false by default. most modules aren't root
+def _mod(*, name, defaults = [], toolchain = [], override = [], single_version_override = [], single_version_platform_override = [], is_root = False):
     return struct(
         name = name,
         tags = struct(
@@ -88,6 +89,15 @@
         register_all_versions = register_all_versions,
     )
 
+def _rules_python_module(is_root = False):
+    """A mock of what the real rules_python MODULE.bazel looks like."""
+    return _mod(
+        name = "rules_python",
+        defaults = [_defaults(python_version = "3.11")],
+        toolchain = [_toolchain("3.11")],
+        is_root = is_root,
+    )
+
 def _single_version_override(
         python_version = "",
         sha256 = {},
@@ -138,10 +148,11 @@
         arch = "",
     )
 
-def _test_default(env):
+def _test_default_from_rules_python_when_rules_python_is_root(env):
+    """Verify that rules_python (as root module) default is applied."""
     py = parse_modules(
         module_ctx = _mock_mctx(
-            _mod(name = "rules_python", toolchain = [_toolchain("3.11")]),
+            _rules_python_module(is_root = True),
         ),
         logger = repo_utils.logger(verbosity_level = 0, name = "python"),
     )
@@ -167,12 +178,13 @@
     )
     env.expect.that_collection(py.toolchains).contains_exactly([want_toolchain])
 
-_tests.append(_test_default)
+_tests.append(_test_default_from_rules_python_when_rules_python_is_root)
 
-def _test_default_some_module(env):
+def _test_default_from_rules_python_when_rules_python_is_not_root(env):
+    """Verify that rules_python default applies when rules_python is not the root module."""
     py = parse_modules(
         module_ctx = _mock_mctx(
-            _mod(name = "rules_python", toolchain = [_toolchain("3.11")], is_root = False),
+            _rules_python_module(),
         ),
         logger = repo_utils.logger(verbosity_level = 0, name = "python"),
     )
@@ -186,12 +198,13 @@
     )
     env.expect.that_collection(py.toolchains).contains_exactly([want_toolchain])
 
-_tests.append(_test_default_some_module)
+_tests.append(_test_default_from_rules_python_when_rules_python_is_not_root)
 
 def _test_default_with_patch_version(env):
     py = parse_modules(
         module_ctx = _mock_mctx(
-            _mod(name = "rules_python", toolchain = [_toolchain("3.11.2")]),
+            _mod(name = "alpha", toolchain = [_toolchain("3.11.2")], is_root = True),
+            _rules_python_module(is_root = True),
         ),
         logger = repo_utils.logger(verbosity_level = 0, name = "python"),
     )
@@ -203,39 +216,19 @@
         python_version = "3.11.2",
         register_coverage_tool = False,
     )
-    env.expect.that_collection(py.toolchains).contains_exactly([want_toolchain])
+    env.expect.that_collection(py.toolchains).contains_at_least([want_toolchain])
 
 _tests.append(_test_default_with_patch_version)
 
-def _test_default_non_rules_python(env):
-    py = parse_modules(
-        module_ctx = _mock_mctx(
-            # NOTE @aignas 2024-09-06: the first item in the module_ctx.modules
-            # could be a non-root module, which is the case if the root module
-            # does not make any calls to the extension.
-            _mod(name = "rules_python", toolchain = [_toolchain("3.11")], is_root = False),
-        ),
-        logger = repo_utils.logger(verbosity_level = 0, name = "python"),
-    )
-
-    env.expect.that_str(py.default_python_version).equals("3.11")
-    rules_python_toolchain = struct(
-        name = "python_3_11",
-        python_version = "3.11",
-        register_coverage_tool = False,
-    )
-    env.expect.that_collection(py.toolchains).contains_exactly([rules_python_toolchain])
-
-_tests.append(_test_default_non_rules_python)
-
 def _test_default_non_rules_python_ignore_root_user_error(env):
     py = parse_modules(
         module_ctx = _mock_mctx(
             _mod(
                 name = "my_module",
                 toolchain = [_toolchain("3.12", ignore_root_user_error = False)],
+                is_root = True,
             ),
-            _mod(name = "rules_python", toolchain = [_toolchain("3.11")]),
+            _rules_python_module(),
         ),
         logger = repo_utils.logger(verbosity_level = 0, name = "python"),
     )
@@ -261,11 +254,12 @@
 _tests.append(_test_default_non_rules_python_ignore_root_user_error)
 
 def _test_default_non_rules_python_ignore_root_user_error_non_root_module(env):
+    """Verify a non-root intermediate module has its ignore_root_user_error setting ignored."""
     py = parse_modules(
         module_ctx = _mock_mctx(
-            _mod(name = "my_module", toolchain = [_toolchain("3.13")]),
+            _mod(name = "my_module", is_root = True, toolchain = [_toolchain("3.13")]),
             _mod(name = "some_module", toolchain = [_toolchain("3.12", ignore_root_user_error = False)]),
-            _mod(name = "rules_python", toolchain = [_toolchain("3.11")]),
+            _rules_python_module(),
         ),
         logger = repo_utils.logger(verbosity_level = 0, name = "python"),
     )
@@ -310,8 +304,9 @@
                     _toolchain("3.11.10"),
                     _toolchain("3.11.13", is_default = True),
                 ],
+                is_root = True,
             ),
-            _mod(name = "rules_python", toolchain = [_toolchain("3.11")]),
+            _rules_python_module(),
         ),
         logger = repo_utils.logger(verbosity_level = 0, name = "python"),
     )
@@ -433,12 +428,68 @@
 
 _tests.append(_test_default_from_defaults_file)
 
+def _test_default_from_single_toolchain(env):
+    py = parse_modules(
+        module_ctx = _mock_mctx(
+            _mod(
+                name = "my_root_module",
+                toolchain = [_toolchain("3.12")],
+                is_root = True,
+            ),
+            _rules_python_module(),
+        ),
+        logger = repo_utils.logger(verbosity_level = 0, name = "python"),
+    )
+    env.expect.that_str(py.default_python_version).equals("3.12")
+
+_tests.append(_test_default_from_single_toolchain)
+
+def _test_defaults_overrides_single_toolchain(env):
+    py = parse_modules(
+        module_ctx = _mock_mctx(
+            _mod(
+                name = "my_root_module",
+                defaults = [
+                    # This relies on rules_python registering 3.11
+                    _defaults(python_version = "3.11"),
+                ],
+                toolchain = [_toolchain("3.12")],
+                is_root = True,
+            ),
+            _rules_python_module(),
+        ),
+        logger = repo_utils.logger(verbosity_level = 0, name = "python"),
+    )
+    env.expect.that_str(py.default_python_version).equals("3.11")
+
+_tests.append(_test_defaults_overrides_single_toolchain)
+
+def _test_defaults_overrides_toolchains_setting_is_default(env):
+    py = parse_modules(
+        module_ctx = _mock_mctx(
+            _mod(
+                name = "my_root_module",
+                defaults = [_defaults(python_version = "3.13")],
+                toolchain = [
+                    _toolchain("3.13"),
+                    _toolchain("3.12", is_default = True),
+                ],
+                is_root = True,
+            ),
+            _rules_python_module(),
+        ),
+        logger = repo_utils.logger(verbosity_level = 0, name = "python"),
+    )
+    env.expect.that_str(py.default_python_version).equals("3.13")
+
+_tests.append(_test_defaults_overrides_toolchains_setting_is_default)
+
 def _test_first_occurance_of_the_toolchain_wins(env):
     py = parse_modules(
         module_ctx = _mock_mctx(
-            _mod(name = "my_module", toolchain = [_toolchain("3.12")]),
+            _mod(name = "my_module", is_root = True, toolchain = [_toolchain("3.12")]),
             _mod(name = "some_module", toolchain = [_toolchain("3.12", configure_coverage_tool = True)]),
-            _mod(name = "rules_python", toolchain = [_toolchain("3.11")]),
+            _rules_python_module(),
             environ = {
                 "RULES_PYTHON_BZLMOD_DEBUG": "1",
             },
@@ -486,8 +537,9 @@
                         auth_patterns = {"foo": "bar"},
                     ),
                 ],
+                is_root = True,
             ),
-            _mod(name = "rules_python", toolchain = [_toolchain("3.11")]),
+            _rules_python_module(),
         ),
         logger = repo_utils.logger(verbosity_level = 0, name = "python"),
     )
@@ -521,6 +573,7 @@
         module_ctx = _mock_mctx(
             _mod(
                 name = "my_module",
+                is_root = True,
                 toolchain = [_toolchain("3.13")],
                 single_version_override = [
                     _single_version_override(
@@ -601,6 +654,7 @@
         module_ctx = _mock_mctx(
             _mod(
                 name = "my_module",
+                is_root = True,
                 toolchain = [_toolchain("3.13")],
                 single_version_override = [
                     _single_version_override(
@@ -666,6 +720,7 @@
         module_ctx = _mock_mctx(
             _mod(
                 name = "my_module",
+                is_root = True,
                 toolchain = [_toolchain("3.13")],
                 single_version_override = [
                     _single_version_override(
@@ -744,6 +799,7 @@
         module_ctx = _mock_mctx(
             _mod(
                 name = "my_module",
+                is_root = True,
                 toolchain = [_toolchain("3.13")],
                 override = [
                     _override(base_url = "foo"),
@@ -775,6 +831,7 @@
             module_ctx = _mock_mctx(
                 _mod(
                     name = "my_module",
+                    is_root = True,
                     toolchain = [_toolchain("3.13")],
                     single_version_override = test.overrides,
                 ),
@@ -815,6 +872,7 @@
                     name = "my_module",
                     toolchain = [_toolchain("3.13")],
                     single_version_platform_override = test.overrides,
+                    is_root = True,
                 ),
             ),
             _fail = lambda *a: errors.append(" ".join(a)),