feat: pyo3 support module prefix + naming (#3726)

Adds support for specifying the module name + module prefix for pyo3
extensions.

---------

Co-authored-by: UebelAndre <github@uebelandre.com>
diff --git a/extensions/pyo3/private/pyo3.bzl b/extensions/pyo3/private/pyo3.bzl
index 0e087b5..b16871e 100644
--- a/extensions/pyo3/private/pyo3.bzl
+++ b/extensions/pyo3/private/pyo3.bzl
@@ -87,10 +87,43 @@
     is_windows = extension.basename.endswith(".dll")
 
     # https://pyo3.rs/v0.26.0/building-and-distribution#manual-builds
-    ext = ctx.actions.declare_file("{}{}".format(
-        ctx.label.name,
-        ".pyd" if is_windows else ".so",
-    ))
+    #
+    # Determine the on-disk and logical Python module layout.
+    #
+    # `module` is a full dotted module path (e.g. "foo.bar"). We split on the
+    # last "." such that:
+    #   - module_prefix == "foo"
+    #   - module_name == "bar"
+    #
+    # `module_name` must match the `#[pymodule] fn <name>(...)` in the Rust code
+    # and is also what we pass to the stub generator.
+    module_path = ctx.attr.module_name if ctx.attr.module_name else ctx.label.name.replace("/", ".")
+
+    if module_path.startswith(".") or module_path.endswith(".") or ".." in module_path:
+        fail("Invalid `module` value '{}': expected a dotted module path like 'foo.bar'.".format(module_path))
+
+    last_dot = module_path.rfind(".")
+    if last_dot == -1:
+        module_prefix = None
+        module_name = module_path
+    else:
+        module_prefix = module_path[:last_dot]
+        module_name = module_path[last_dot + 1:]
+
+    if not module_name:
+        fail("Invalid `module` value '{}': module name may not be empty.".format(module_path))
+
+    # Convert module_prefix (e.g. "foo.bar") into a path ("foo/bar") and place
+    # the extension and stubs in the corresponding directory.
+    if module_prefix:
+        module_prefix_path = module_prefix.replace(".", "/")
+        module_relpath = "{}/{}.{}".format(module_prefix_path, module_name, "pyd" if is_windows else "so")
+        stub_relpath = "{}/{}.pyi".format(module_prefix_path, module_name)
+    else:
+        module_relpath = "{}.{}".format(module_name, "pyd" if is_windows else "so")
+        stub_relpath = "{}.pyi".format(module_name)
+
+    ext = ctx.actions.declare_file(module_relpath)
     ctx.actions.symlink(
         output = ext,
         target_file = extension,
@@ -99,10 +132,10 @@
 
     stub = None
     if _stubs_enabled(ctx.attr.stubs, toolchain):
-        stub = ctx.actions.declare_file("{}.pyi".format(ctx.label.name))
+        stub = ctx.actions.declare_file(stub_relpath)
 
         args = ctx.actions.args()
-        args.add(ctx.label.name, format = "--module_name=%s")
+        args.add(module_name, format = "--module_name=%s")
         args.add(ext, format = "--module_path=%s")
         args.add(stub, format = "--output=%s")
         ctx.actions.run(
@@ -180,6 +213,9 @@
         "imports": attr.string_list(
             doc = "List of import directories to be added to the `PYTHONPATH`.",
         ),
+        "module_name": attr.string(
+            doc = "A full dotted Python module path implemented by this extension (e.g. `foo.bar`).",
+        ),
         "stubs": attr.int(
             doc = "Whether or not to generate stubs. `-1` will default to the global config, `0` will never generate, and `1` will always generate stubs.",
             default = -1,
@@ -218,6 +254,7 @@
         stubs = None,
         version = None,
         compilation_mode = "opt",
+        module_name = None,
         **kwargs):
     """Define a PyO3 python extension module.
 
@@ -259,6 +296,7 @@
             For more details see [rust_shared_library][rsl].
         compilation_mode (str, optional): The [compilation_mode](https://bazel.build/reference/command-line-reference#flag--compilation_mode)
             value to build the extension for. If set to `"current"`, the current configuration will be used.
+        module_name (str, optional): A full dotted Python module path implemented by this extension (e.g. `foo.bar`).
         **kwargs (dict): Additional keyword arguments.
     """
     tags = kwargs.pop("tags", [])
@@ -318,6 +356,7 @@
         compilation_mode = compilation_mode,
         stubs = stubs_int,
         imports = imports,
+        module_name = module_name,
         tags = tags,
         visibility = visibility,
         **kwargs
diff --git a/extensions/pyo3/test/module_prefix/BUILD.bazel b/extensions/pyo3/test/module_prefix/BUILD.bazel
new file mode 100644
index 0000000..72a830e
--- /dev/null
+++ b/extensions/pyo3/test/module_prefix/BUILD.bazel
@@ -0,0 +1,15 @@
+load("@rules_python//python:defs.bzl", "py_test")
+load("//:defs.bzl", "pyo3_extension")
+
+pyo3_extension(
+    name = "module_prefix",
+    srcs = ["bar.rs"],
+    edition = "2021",
+    module_name = "foo.bar",
+)
+
+py_test(
+    name = "module_prefix_import_test",
+    srcs = ["module_prefix_import_test.py"],
+    deps = [":module_prefix"],
+)
diff --git a/extensions/pyo3/test/module_prefix/bar.rs b/extensions/pyo3/test/module_prefix/bar.rs
new file mode 100644
index 0000000..cfddb13
--- /dev/null
+++ b/extensions/pyo3/test/module_prefix/bar.rs
@@ -0,0 +1,12 @@
+use pyo3::prelude::*;
+
+#[pyfunction]
+fn thing() -> PyResult<&'static str> {
+    Ok("hello from rust")
+}
+
+#[pymodule]
+fn bar(m: &Bound<'_, PyModule>) -> PyResult<()> {
+    m.add_function(wrap_pyfunction!(thing, m)?)?;
+    Ok(())
+}
diff --git a/extensions/pyo3/test/module_prefix/module_prefix_import_test.py b/extensions/pyo3/test/module_prefix/module_prefix_import_test.py
new file mode 100644
index 0000000..1175fa9
--- /dev/null
+++ b/extensions/pyo3/test/module_prefix/module_prefix_import_test.py
@@ -0,0 +1,18 @@
+"""Tests that a pyo3 extension can be imported via a module prefix."""
+
+import unittest
+from test.module_prefix.foo import bar
+
+
+class ModulePrefixImportTest(unittest.TestCase):
+    """Test Class."""
+
+    def test_import_and_call(self) -> None:
+        """Test that a pyo3 extension can be imported via a module prefix."""
+
+        result = bar.thing()
+        self.assertEqual("hello from rust", result)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/extensions/pyo3/test/module_prefix_imports/BUILD.bazel b/extensions/pyo3/test/module_prefix_imports/BUILD.bazel
new file mode 100644
index 0000000..6005fed
--- /dev/null
+++ b/extensions/pyo3/test/module_prefix_imports/BUILD.bazel
@@ -0,0 +1,18 @@
+load("@rules_python//python:defs.bzl", "py_test")
+load("//:defs.bzl", "pyo3_extension")
+
+# Variant that relies on the `imports` attribute to put this package root on
+# `PYTHONPATH` (so you can import `foo.bar` directly).
+pyo3_extension(
+    name = "module_prefix_imports",
+    srcs = ["bar.rs"],
+    edition = "2021",
+    imports = ["."],
+    module_name = "foo.bar",
+)
+
+py_test(
+    name = "module_prefix_imports_test",
+    srcs = ["module_prefix_imports_test.py"],
+    deps = [":module_prefix_imports"],
+)
diff --git a/extensions/pyo3/test/module_prefix_imports/bar.rs b/extensions/pyo3/test/module_prefix_imports/bar.rs
new file mode 100644
index 0000000..fcd79c5
--- /dev/null
+++ b/extensions/pyo3/test/module_prefix_imports/bar.rs
@@ -0,0 +1,13 @@
+use pyo3::prelude::*;
+
+#[pyfunction]
+fn thing() -> PyResult<&'static str> {
+    Ok("hello from rust")
+}
+
+#[pymodule]
+fn bar(m: &Bound<'_, PyModule>) -> PyResult<()> {
+    m.add_function(wrap_pyfunction!(thing, m)?)?;
+    Ok(())
+}
+
diff --git a/extensions/pyo3/test/module_prefix_imports/module_prefix_imports_test.py b/extensions/pyo3/test/module_prefix_imports/module_prefix_imports_test.py
new file mode 100644
index 0000000..8a245f5
--- /dev/null
+++ b/extensions/pyo3/test/module_prefix_imports/module_prefix_imports_test.py
@@ -0,0 +1,19 @@
+"""Tests importing a pyo3 extension via `imports = ["."]`."""
+
+import unittest
+
+from foo import bar  # type: ignore[import-untyped]
+
+
+class ModulePrefixImportsTest(unittest.TestCase):
+    """Test Class."""
+
+    def test_import_and_call(self) -> None:
+        """Test that a pyo3 extension can be imported via a module prefix."""
+
+        result = bar.thing()
+        self.assertEqual("hello from rust", result)
+
+
+if __name__ == "__main__":
+    unittest.main()