Rust: remove references to `crate::` from generated code This is necessary for the generated code to work correctly even when it's placed in a module instead of directly in the crate root. Now we can finally delete the last of the gencode post-processing from the protobuf_gencode crate. To get this to work properly I had to update the `Context` object to keep track of the current module depth. This way we know how many `super::` prefixes we need to prepend to an identifier to get back to the top level. I had to do some refactoring of our naming helper functions to get everything working properly: - Deleted `GetCrateRelativeQualifiedPath()`, since it seems simpler if we just always provide unambiguous paths. I added some new `RsTypePath()` overloads as a replacement. - Made `RustModuleForContainingType()` a private implementation detail of naming.cc, since the `RustModule()` functions are more user-friendly and accomplish the same thing. - Moved the logic for prepending super:: or the foreign crate name into `RustModuleForContainingType()`. This way, all the helpers that call that function automatically pick up the behavior we want. PiperOrigin-RevId: 703555727
diff --git a/rust/release_crates/protobuf_codegen/src/lib.rs b/rust/release_crates/protobuf_codegen/src/lib.rs index 2997dc3..307a9c3 100644 --- a/rust/release_crates/protobuf_codegen/src/lib.rs +++ b/rust/release_crates/protobuf_codegen/src/lib.rs
@@ -132,21 +132,11 @@ if file_name.ends_with(".upb_minitable.c") { cc_build.file(path); } - if file_name.ends_with(".pb.rs") { - Self::fix_rs_gencode(&path); - } } } cc_build.compile(&format!("{}_upb_gen_code", std::env::var("CARGO_PKG_NAME").unwrap())); Ok(()) } - - // Adjust the generated Rust code to work with the crate structure. - fn fix_rs_gencode(path: &Path) { - let contents = fs::read_to_string(path).unwrap().replace("crate::", ""); - let mut file = OpenOptions::new().write(true).truncate(true).open(path).unwrap(); - file.write(contents.as_bytes()).unwrap(); - } } fn get_path_for_arch() -> Option<PathBuf> {
diff --git a/src/google/protobuf/compiler/rust/context.h b/src/google/protobuf/compiler/rust/context.h index 23a1e62..df5148d 100644 --- a/src/google/protobuf/compiler/rust/context.h +++ b/src/google/protobuf/compiler/rust/context.h
@@ -83,10 +83,11 @@ public: Context(const Options* opts, const RustGeneratorContext* rust_generator_context, - io::Printer* printer) + io::Printer* printer, std::vector<std::string> modules) : opts_(opts), rust_generator_context_(rust_generator_context), - printer_(printer) {} + printer_(printer), + modules_(std::move(modules)) {} Context(const Context&) = delete; Context& operator=(const Context&) = delete; @@ -105,7 +106,7 @@ io::Printer& printer() const { return *printer_; } Context WithPrinter(io::Printer* printer) const { - return Context(opts_, rust_generator_context_, printer); + return Context(opts_, rust_generator_context_, printer, modules_); } // Forwards to Emit(), which will likely be called all the time. @@ -136,10 +137,29 @@ return it->second; } + // Opening and closing modules should always be done with PushModule() and + // PopModule(). Knowing what module we are in is important, because it allows + // us to unambiguously reference other identifiers in the same crate. We + // cannot just use crate::, because when we are building with Cargo, the + // generated code does not necessarily live in the crate root. + void PushModule(absl::string_view name) { + Emit({{"mod_name", name}}, "pub mod $mod_name$ {"); + modules_.emplace_back(name); + } + + void PopModule() { + Emit({{"mod_name", modules_.back()}}, "} // pub mod $mod_name$"); + modules_.pop_back(); + } + + // Returns the current depth of module nesting. + size_t GetModuleDepth() const { return modules_.size(); } + private: const Options* opts_; const RustGeneratorContext* rust_generator_context_; io::Printer* printer_; + std::vector<std::string> modules_; }; bool IsInCurrentlyGeneratingCrate(Context& ctx, const FileDescriptor& file);
diff --git a/src/google/protobuf/compiler/rust/generator.cc b/src/google/protobuf/compiler/rust/generator.cc index 21c608b..6451337 100644 --- a/src/google/protobuf/compiler/rust/generator.cc +++ b/src/google/protobuf/compiler/rust/generator.cc
@@ -9,6 +9,7 @@ #include <memory> #include <string> +#include <utility> #include <vector> #include "absl/algorithm/container.h" @@ -44,20 +45,20 @@ std::string crate_name = GetCrateName(ctx, *dep); for (int i = 0; i < dep->message_type_count(); ++i) { auto* msg = dep->message_type(i); - auto path = GetCrateRelativeQualifiedPath(ctx, *msg); - ctx.Emit({{"crate", crate_name}, {"pkg::Msg", path}}, + auto path = RsTypePath(ctx, *msg); + ctx.Emit({{"pkg::Msg", path}}, R"rs( - pub use $crate$::$pkg::Msg$; - pub use $crate$::$pkg::Msg$View; - pub use $crate$::$pkg::Msg$Mut; + pub use $pkg::Msg$; + pub use $pkg::Msg$View; + pub use $pkg::Msg$Mut; )rs"); } for (int i = 0; i < dep->enum_type_count(); ++i) { auto* enum_ = dep->enum_type(i); - auto path = GetCrateRelativeQualifiedPath(ctx, *enum_); - ctx.Emit({{"crate", crate_name}, {"pkg::Enum", path}}, + auto path = RsTypePath(ctx, *enum_); + ctx.Emit({{"pkg::Enum", path}}, R"rs( - pub use $crate$::$pkg::Enum$; + pub use $pkg::Enum$; )rs"); } } @@ -101,14 +102,14 @@ std::string relative_mod_path = primary_relpath.Relative(RelativePath(non_primary_file_path)); ctx.Emit({{"file_path", relative_mod_path}, - {"mod_name", RustInternalModuleName(ctx, *non_primary_src)}}, + {"mod_name", RustInternalModuleName(*non_primary_src)}}, R"rs( #[path="$file_path$"] #[allow(non_snake_case)] mod $mod_name$; #[allow(unused_imports)] - pub use crate::$mod_name$::*; + pub use $mod_name$::*; )rs"); } } @@ -138,7 +139,13 @@ RustGeneratorContext rust_generator_context(&files_in_current_crate, &*import_path_to_crate_name); - Context ctx_without_printer(&*opts, &rust_generator_context, nullptr); + std::vector<std::string> modules; + if (file != files_in_current_crate[0]) { + // This is not the primary file, so its generated code will be in a module. 
+ modules.emplace_back(RustInternalModuleName(*file)); + } + Context ctx_without_printer(&*opts, &rust_generator_context, nullptr, + std::move(modules)); auto outfile = absl::WrapUnique( generator_context->Open(GetRsFile(ctx_without_printer, *file)));
diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 4962288..7225da3 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc
@@ -649,8 +649,8 @@ msg.real_oneof_decl_count() == 0) { return; } - ctx.Emit({{"mod_name", RsSafeName(CamelToSnakeCase(msg.name()))}, - {"nested_msgs", + ctx.PushModule(RsSafeName(CamelToSnakeCase(msg.name()))); + ctx.Emit({{"nested_msgs", [&] { for (int i = 0; i < msg.nested_type_count(); ++i) { GenerateRs(ctx, *msg.nested_type(i)); @@ -669,13 +669,12 @@ } }}}, R"rs( - pub mod $mod_name$ { $nested_msgs$ $nested_enums$ $oneofs$ - } // mod $mod_name$ )rs"); + ctx.PopModule(); }}, {"raw_arena_getter_for_message", [&] {
diff --git a/src/google/protobuf/compiler/rust/naming.cc b/src/google/protobuf/compiler/rust/naming.cc index aac5067..97d3840 100644 --- a/src/google/protobuf/compiler/rust/naming.cc +++ b/src/google/protobuf/compiler/rust/naming.cc
@@ -8,6 +8,7 @@ #include "google/protobuf/compiler/rust/naming.h" #include <algorithm> +#include <cstddef> #include <string> #include <vector> @@ -103,15 +104,6 @@ } template <typename Desc> -std::string GetFullyQualifiedPath(Context& ctx, const Desc& desc) { - auto rel_path = GetCrateRelativeQualifiedPath(ctx, desc); - if (IsInCurrentlyGeneratingCrate(ctx, desc)) { - return absl::StrCat("crate::", rel_path); - } - return absl::StrCat(GetCrateName(ctx, *desc.file()), "::", rel_path); -} - -template <typename Desc> std::string GetUnderscoreDelimitedFullName(Context& ctx, const Desc& desc) { return UnderscoreDelimitFullName(ctx, desc.full_name()); } @@ -144,14 +136,22 @@ case RustFieldType::STRING: return "::protobuf::ProtoString"; case RustFieldType::MESSAGE: - return GetFullyQualifiedPath(ctx, *field.message_type()); + return RsTypePath(ctx, *field.message_type()); case RustFieldType::ENUM: - return GetFullyQualifiedPath(ctx, *field.enum_type()); + return RsTypePath(ctx, *field.enum_type()); } ABSL_LOG(ERROR) << "Unknown field type: " << field.type_name(); internal::Unreachable(); } +std::string RsTypePath(Context& ctx, const Descriptor& message) { + return absl::StrCat(RustModule(ctx, message), RsSafeName(message.name())); +} + +std::string RsTypePath(Context& ctx, const EnumDescriptor& descriptor) { + return absl::StrCat(RustModule(ctx, descriptor), EnumRsName(descriptor)); +} + std::string RsViewType(Context& ctx, const FieldDescriptor& field, absl::string_view lifetime) { switch (GetRustFieldType(field)) { @@ -172,20 +172,20 @@ return absl::StrFormat("&%s ::protobuf::ProtoStr", lifetime); case RustFieldType::MESSAGE: if (lifetime.empty()) { - return absl::StrFormat( - "%sView", GetFullyQualifiedPath(ctx, *field.message_type())); + return absl::StrFormat("%sView", + RsTypePath(ctx, *field.message_type())); } else { return absl::StrFormat( - "%sView<%s>", GetFullyQualifiedPath(ctx, *field.message_type()), - lifetime); + "%sView<%s>", RsTypePath(ctx, 
*field.message_type()), lifetime); } } ABSL_LOG(FATAL) << "Unsupported field type: " << field.type_name(); internal::Unreachable(); } -std::string RustModuleForContainingType(Context& ctx, - const Descriptor* containing_type) { +static std::string RustModuleForContainingType( + Context& ctx, const Descriptor* containing_type, + const FileDescriptor& file) { std::vector<std::string> modules; // Innermost to outermost order. @@ -204,35 +204,37 @@ modules.push_back(""); } - return absl::StrJoin(modules, "::"); + std::string crate_relative = absl::StrJoin(modules, "::"); + + if (IsInCurrentlyGeneratingCrate(ctx, file)) { + std::string prefix; + for (size_t i = 0; i < ctx.GetModuleDepth(); ++i) { + prefix += "super::"; + } + return absl::StrCat(prefix, crate_relative); + } + return absl::StrCat(GetCrateName(ctx, file), "::", crate_relative); } std::string RustModule(Context& ctx, const Descriptor& msg) { - return RustModuleForContainingType(ctx, msg.containing_type()); + return RustModuleForContainingType(ctx, msg.containing_type(), *msg.file()); } std::string RustModule(Context& ctx, const EnumDescriptor& enum_) { - return RustModuleForContainingType(ctx, enum_.containing_type()); + return RustModuleForContainingType(ctx, enum_.containing_type(), + *enum_.file()); } std::string RustModule(Context& ctx, const OneofDescriptor& oneof) { - return RustModuleForContainingType(ctx, oneof.containing_type()); + return RustModuleForContainingType(ctx, oneof.containing_type(), + *oneof.file()); } -std::string RustInternalModuleName(Context& ctx, const FileDescriptor& file) { +std::string RustInternalModuleName(const FileDescriptor& file) { return RsSafeName( absl::StrReplaceAll(StripProto(file.name()), {{"_", "__"}, {"/", "_s"}})); } -std::string GetCrateRelativeQualifiedPath(Context& ctx, const Descriptor& msg) { - return absl::StrCat(RustModule(ctx, msg), RsSafeName(msg.name())); -} - -std::string GetCrateRelativeQualifiedPath(Context& ctx, - const EnumDescriptor& enum_) { - 
return absl::StrCat(RustModule(ctx, enum_), EnumRsName(enum_)); -} - std::string FieldInfoComment(Context& ctx, const FieldDescriptor& field) { absl::string_view label = field.is_repeated() ? "repeated" : "optional"; std::string comment = absl::StrCat(field.name(), ": ", label, " ",
diff --git a/src/google/protobuf/compiler/rust/naming.h b/src/google/protobuf/compiler/rust/naming.h index 2707d85..95a2d14 100644 --- a/src/google/protobuf/compiler/rust/naming.h +++ b/src/google/protobuf/compiler/rust/naming.h
@@ -38,10 +38,12 @@ std::string RawMapThunk(Context& ctx, const EnumDescriptor& desc, absl::string_view key_t, absl::string_view op); -// Returns an absolute path to the Proxied Rust type of the given field. -// The absolute path is guaranteed to work in the crate that defines the field. -// It may be crate-relative, or directly reference the owning crate of the type. +// Returns a path to the Proxied Rust type of the given field. The path will be +// relative if the type is in the same crate, or absolute if it is in a +// different crate. std::string RsTypePath(Context& ctx, const FieldDescriptor& field); +std::string RsTypePath(Context& ctx, const Descriptor& message); +std::string RsTypePath(Context& ctx, const EnumDescriptor& descriptor); // Returns the 'simple spelling' of the Rust View type for the provided field. // For example, `i32` for int32 fields and `SomeMsgView<'$lifetime$>` for @@ -84,27 +86,18 @@ // verbatim unless it is a Rust keyword that isn't a legal symbol name. std::string RsSafeName(absl::string_view name); -// Constructs a string of the Rust modules which will contain the message. +// Constructs a string of the Rust modules which will contain the entity. // // Example: Given a message 'NestedMessage' which is defined in package 'x.y' // which is inside 'ParentMessage', the message will be placed in the -// x::y::ParentMessage_ Rust module, so this function will return the string -// "x::y::ParentMessage_::". -// -// If the message has no package and no containing messages then this returns -// empty string. -std::string RustModuleForContainingType(Context& ctx, - const Descriptor* containing_type); +// x::y::parent_message Rust module, so this function will return +// "x::y::parent_message::", with the necessary prefix to make it relative to +// the current scope, or absolute if the entity is in a different crate. 
std::string RustModule(Context& ctx, const Descriptor& msg); std::string RustModule(Context& ctx, const EnumDescriptor& enum_); std::string RustModule(Context& ctx, const OneofDescriptor& oneof); -std::string RustInternalModuleName(Context& ctx, const FileDescriptor& file); -std::string GetCrateRelativeQualifiedPath(Context& ctx, const Descriptor& msg); -std::string GetCrateRelativeQualifiedPath(Context& ctx, - const EnumDescriptor& enum_); -std::string GetCrateRelativeQualifiedPath(Context& ctx, - const OneofDescriptor& oneof); +std::string RustInternalModuleName(const FileDescriptor& file); template <typename Desc> std::string GetUnderscoreDelimitedFullName(Context& ctx, const Desc& desc);
diff --git a/src/google/protobuf/compiler/rust/naming_test.cc b/src/google/protobuf/compiler/rust/naming_test.cc index b1acd7b..f256411 100644 --- a/src/google/protobuf/compiler/rust/naming_test.cc +++ b/src/google/protobuf/compiler/rust/naming_test.cc
@@ -5,19 +5,12 @@ #include <gtest/gtest.h> #include "absl/container/flat_hash_map.h" -#include "google/protobuf/compiler/rust/context.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream_impl_lite.h" using google::protobuf::compiler::rust::CamelToSnakeCase; -using google::protobuf::compiler::rust::Context; -using google::protobuf::compiler::rust::Kernel; -using google::protobuf::compiler::rust::Options; -using google::protobuf::compiler::rust::RustGeneratorContext; using google::protobuf::compiler::rust::RustInternalModuleName; using google::protobuf::compiler::rust::ScreamingSnakeToUpperCamelCase; -using google::protobuf::io::Printer; -using google::protobuf::io::StringOutputStream; namespace { TEST(RustProtoNaming, RustInternalModuleName) { @@ -25,17 +18,7 @@ foo_file.set_name("strong_bad/lol.proto"); google::protobuf::DescriptorPool pool; const google::protobuf::FileDescriptor* fd = pool.BuildFile(foo_file); - - const Options opts = {Kernel::kUpb}; - std::vector<const google::protobuf::FileDescriptor*> files{fd}; - absl::flat_hash_map<std::string, std::string> mapping; - const RustGeneratorContext rust_generator_context(&files, &mapping); - std::string output; - StringOutputStream stream{&output}; - Printer printer(&stream); - Context c = Context(&opts, &rust_generator_context, &printer); - - EXPECT_EQ(RustInternalModuleName(c, *fd), "strong__bad_slol"); + EXPECT_EQ(RustInternalModuleName(*fd), "strong__bad_slol"); } TEST(RustProtoNaming, CamelToSnakeCase) {
diff --git a/src/google/protobuf/compiler/rust/oneof.cc b/src/google/protobuf/compiler/rust/oneof.cc index eb2afe9..5702c57 100644 --- a/src/google/protobuf/compiler/rust/oneof.cc +++ b/src/google/protobuf/compiler/rust/oneof.cc
@@ -232,9 +232,7 @@ {{"oneof_name", RsSafeName(oneof.name())}, {"view_lifetime", ViewLifetime(accessor_case)}, {"self", ViewReceiver(accessor_case)}, - {"oneof_enum_module", - absl::StrCat("crate::", RustModuleForContainingType( - ctx, oneof.containing_type()))}, + {"oneof_enum_module", RustModule(ctx, oneof)}, {"view_enum_name", OneofViewEnumRsName(oneof)}, {"case_enum_name", OneofCaseEnumRsName(oneof)}, {"view_cases", @@ -301,9 +299,7 @@ ctx.Emit( { - {"oneof_enum_module", - absl::StrCat("crate::", RustModuleForContainingType( - ctx, oneof.containing_type()))}, + {"oneof_enum_module", RustModule(ctx, oneof)}, {"case_enum_rs_name", OneofCaseEnumRsName(oneof)}, {"case_thunk", ThunkName(ctx, oneof, "case")}, },