Rust: remove references to `crate::` from generated code
This is necessary for the generated code to work correctly even when it's
placed in a module instead of directly in the crate root. Now we can finally
delete the last of the gencode post-processing from the protobuf_codegen crate.
To get this to work properly, I had to update the `Context` object to keep track
of the current module nesting (via the new `PushModule()`/`PopModule()` methods).
This way we know how many `super::` prefixes we need to prepend to an identifier
to get back to the top level.
I had to do some refactoring of our naming helper functions to get everything
working properly:
- Deleted `GetCrateRelativeQualifiedPath()`, since it seems simpler if we just
always provide unambiguous paths. I added some new `RsTypePath()` overloads
as a replacement.
- Made `RustModuleForContainingType()` a private implementation detail of
naming.cc, since the `RustModule()` functions are more user-friendly and
accomplish the same thing.
- Moved the logic for prepending `super::` or the foreign crate name into
`RustModuleForContainingType()`. This way, all the helpers that call that
function automatically pick up the behavior we want.
- Dropped the unused `Context` parameter from `RustInternalModuleName()`.
PiperOrigin-RevId: 703555727
diff --git a/rust/release_crates/protobuf_codegen/src/lib.rs b/rust/release_crates/protobuf_codegen/src/lib.rs
index 2997dc3..307a9c3 100644
--- a/rust/release_crates/protobuf_codegen/src/lib.rs
+++ b/rust/release_crates/protobuf_codegen/src/lib.rs
@@ -132,21 +132,11 @@
if file_name.ends_with(".upb_minitable.c") {
cc_build.file(path);
}
- if file_name.ends_with(".pb.rs") {
- Self::fix_rs_gencode(&path);
- }
}
}
cc_build.compile(&format!("{}_upb_gen_code", std::env::var("CARGO_PKG_NAME").unwrap()));
Ok(())
}
-
- // Adjust the generated Rust code to work with the crate structure.
- fn fix_rs_gencode(path: &Path) {
- let contents = fs::read_to_string(path).unwrap().replace("crate::", "");
- let mut file = OpenOptions::new().write(true).truncate(true).open(path).unwrap();
- file.write(contents.as_bytes()).unwrap();
- }
}
fn get_path_for_arch() -> Option<PathBuf> {
diff --git a/src/google/protobuf/compiler/rust/context.h b/src/google/protobuf/compiler/rust/context.h
index 23a1e62..df5148d 100644
--- a/src/google/protobuf/compiler/rust/context.h
+++ b/src/google/protobuf/compiler/rust/context.h
@@ -83,10 +83,11 @@
public:
Context(const Options* opts,
const RustGeneratorContext* rust_generator_context,
- io::Printer* printer)
+ io::Printer* printer, std::vector<std::string> modules)
: opts_(opts),
rust_generator_context_(rust_generator_context),
- printer_(printer) {}
+ printer_(printer),
+ modules_(std::move(modules)) {}
Context(const Context&) = delete;
Context& operator=(const Context&) = delete;
@@ -105,7 +106,7 @@
io::Printer& printer() const { return *printer_; }
Context WithPrinter(io::Printer* printer) const {
- return Context(opts_, rust_generator_context_, printer);
+ return Context(opts_, rust_generator_context_, printer, modules_);
}
// Forwards to Emit(), which will likely be called all the time.
@@ -136,10 +137,29 @@
return it->second;
}
+ // Opening and closing modules should always be done with PushModule() and
+ // PopModule(). Knowing what module we are in is important, because it allows
+ // us to unambiguously reference other identifiers in the same crate. We
+ // cannot just use crate::, because when we are building with Cargo, the
+ // generated code does not necessarily live in the crate root.
+ void PushModule(absl::string_view name) {
+ Emit({{"mod_name", name}}, "pub mod $mod_name$ {");
+ modules_.emplace_back(name);
+ }
+
+ void PopModule() {
+ Emit({{"mod_name", modules_.back()}}, "} // pub mod $mod_name$");
+ modules_.pop_back();
+ }
+
+ // Returns the current depth of module nesting.
+ size_t GetModuleDepth() const { return modules_.size(); }
+
private:
const Options* opts_;
const RustGeneratorContext* rust_generator_context_;
io::Printer* printer_;
+ std::vector<std::string> modules_;
};
bool IsInCurrentlyGeneratingCrate(Context& ctx, const FileDescriptor& file);
diff --git a/src/google/protobuf/compiler/rust/generator.cc b/src/google/protobuf/compiler/rust/generator.cc
index 21c608b..6451337 100644
--- a/src/google/protobuf/compiler/rust/generator.cc
+++ b/src/google/protobuf/compiler/rust/generator.cc
@@ -9,6 +9,7 @@
#include <memory>
#include <string>
+#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
@@ -44,20 +45,20 @@
std::string crate_name = GetCrateName(ctx, *dep);
for (int i = 0; i < dep->message_type_count(); ++i) {
auto* msg = dep->message_type(i);
- auto path = GetCrateRelativeQualifiedPath(ctx, *msg);
- ctx.Emit({{"crate", crate_name}, {"pkg::Msg", path}},
+ auto path = RsTypePath(ctx, *msg);
+ ctx.Emit({{"pkg::Msg", path}},
R"rs(
- pub use $crate$::$pkg::Msg$;
- pub use $crate$::$pkg::Msg$View;
- pub use $crate$::$pkg::Msg$Mut;
+ pub use $pkg::Msg$;
+ pub use $pkg::Msg$View;
+ pub use $pkg::Msg$Mut;
)rs");
}
for (int i = 0; i < dep->enum_type_count(); ++i) {
auto* enum_ = dep->enum_type(i);
- auto path = GetCrateRelativeQualifiedPath(ctx, *enum_);
- ctx.Emit({{"crate", crate_name}, {"pkg::Enum", path}},
+ auto path = RsTypePath(ctx, *enum_);
+ ctx.Emit({{"pkg::Enum", path}},
R"rs(
- pub use $crate$::$pkg::Enum$;
+ pub use $pkg::Enum$;
)rs");
}
}
@@ -101,14 +102,14 @@
std::string relative_mod_path =
primary_relpath.Relative(RelativePath(non_primary_file_path));
ctx.Emit({{"file_path", relative_mod_path},
- {"mod_name", RustInternalModuleName(ctx, *non_primary_src)}},
+ {"mod_name", RustInternalModuleName(*non_primary_src)}},
R"rs(
#[path="$file_path$"]
#[allow(non_snake_case)]
mod $mod_name$;
#[allow(unused_imports)]
- pub use crate::$mod_name$::*;
+ pub use $mod_name$::*;
)rs");
}
}
@@ -138,7 +139,13 @@
RustGeneratorContext rust_generator_context(&files_in_current_crate,
&*import_path_to_crate_name);
- Context ctx_without_printer(&*opts, &rust_generator_context, nullptr);
+ std::vector<std::string> modules;
+ if (file != files_in_current_crate[0]) {
+ // This is not the primary file, so its generated code will be in a module.
+ modules.emplace_back(RustInternalModuleName(*file));
+ }
+ Context ctx_without_printer(&*opts, &rust_generator_context, nullptr,
+ std::move(modules));
auto outfile = absl::WrapUnique(
generator_context->Open(GetRsFile(ctx_without_printer, *file)));
diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc
index 4962288..7225da3 100644
--- a/src/google/protobuf/compiler/rust/message.cc
+++ b/src/google/protobuf/compiler/rust/message.cc
@@ -649,8 +649,8 @@
msg.real_oneof_decl_count() == 0) {
return;
}
- ctx.Emit({{"mod_name", RsSafeName(CamelToSnakeCase(msg.name()))},
- {"nested_msgs",
+ ctx.PushModule(RsSafeName(CamelToSnakeCase(msg.name())));
+ ctx.Emit({{"nested_msgs",
[&] {
for (int i = 0; i < msg.nested_type_count(); ++i) {
GenerateRs(ctx, *msg.nested_type(i));
@@ -669,13 +669,12 @@
}
}}},
R"rs(
- pub mod $mod_name$ {
$nested_msgs$
$nested_enums$
$oneofs$
- } // mod $mod_name$
)rs");
+ ctx.PopModule();
}},
{"raw_arena_getter_for_message",
[&] {
diff --git a/src/google/protobuf/compiler/rust/naming.cc b/src/google/protobuf/compiler/rust/naming.cc
index aac5067..97d3840 100644
--- a/src/google/protobuf/compiler/rust/naming.cc
+++ b/src/google/protobuf/compiler/rust/naming.cc
@@ -8,6 +8,7 @@
#include "google/protobuf/compiler/rust/naming.h"
#include <algorithm>
+#include <cstddef>
#include <string>
#include <vector>
@@ -103,15 +104,6 @@
}
template <typename Desc>
-std::string GetFullyQualifiedPath(Context& ctx, const Desc& desc) {
- auto rel_path = GetCrateRelativeQualifiedPath(ctx, desc);
- if (IsInCurrentlyGeneratingCrate(ctx, desc)) {
- return absl::StrCat("crate::", rel_path);
- }
- return absl::StrCat(GetCrateName(ctx, *desc.file()), "::", rel_path);
-}
-
-template <typename Desc>
std::string GetUnderscoreDelimitedFullName(Context& ctx, const Desc& desc) {
return UnderscoreDelimitFullName(ctx, desc.full_name());
}
@@ -144,14 +136,22 @@
case RustFieldType::STRING:
return "::protobuf::ProtoString";
case RustFieldType::MESSAGE:
- return GetFullyQualifiedPath(ctx, *field.message_type());
+ return RsTypePath(ctx, *field.message_type());
case RustFieldType::ENUM:
- return GetFullyQualifiedPath(ctx, *field.enum_type());
+ return RsTypePath(ctx, *field.enum_type());
}
ABSL_LOG(ERROR) << "Unknown field type: " << field.type_name();
internal::Unreachable();
}
+std::string RsTypePath(Context& ctx, const Descriptor& message) {
+ return absl::StrCat(RustModule(ctx, message), RsSafeName(message.name()));
+}
+
+std::string RsTypePath(Context& ctx, const EnumDescriptor& descriptor) {
+ return absl::StrCat(RustModule(ctx, descriptor), EnumRsName(descriptor));
+}
+
std::string RsViewType(Context& ctx, const FieldDescriptor& field,
absl::string_view lifetime) {
switch (GetRustFieldType(field)) {
@@ -172,20 +172,20 @@
return absl::StrFormat("&%s ::protobuf::ProtoStr", lifetime);
case RustFieldType::MESSAGE:
if (lifetime.empty()) {
- return absl::StrFormat(
- "%sView", GetFullyQualifiedPath(ctx, *field.message_type()));
+ return absl::StrFormat("%sView",
+ RsTypePath(ctx, *field.message_type()));
} else {
return absl::StrFormat(
- "%sView<%s>", GetFullyQualifiedPath(ctx, *field.message_type()),
- lifetime);
+ "%sView<%s>", RsTypePath(ctx, *field.message_type()), lifetime);
}
}
ABSL_LOG(FATAL) << "Unsupported field type: " << field.type_name();
internal::Unreachable();
}
-std::string RustModuleForContainingType(Context& ctx,
- const Descriptor* containing_type) {
+static std::string RustModuleForContainingType(
+ Context& ctx, const Descriptor* containing_type,
+ const FileDescriptor& file) {
std::vector<std::string> modules;
// Innermost to outermost order.
@@ -204,35 +204,37 @@
modules.push_back("");
}
- return absl::StrJoin(modules, "::");
+ std::string crate_relative = absl::StrJoin(modules, "::");
+
+ if (IsInCurrentlyGeneratingCrate(ctx, file)) {
+ std::string prefix;
+ for (size_t i = 0; i < ctx.GetModuleDepth(); ++i) {
+ prefix += "super::";
+ }
+ return absl::StrCat(prefix, crate_relative);
+ }
+ return absl::StrCat(GetCrateName(ctx, file), "::", crate_relative);
}
std::string RustModule(Context& ctx, const Descriptor& msg) {
- return RustModuleForContainingType(ctx, msg.containing_type());
+ return RustModuleForContainingType(ctx, msg.containing_type(), *msg.file());
}
std::string RustModule(Context& ctx, const EnumDescriptor& enum_) {
- return RustModuleForContainingType(ctx, enum_.containing_type());
+ return RustModuleForContainingType(ctx, enum_.containing_type(),
+ *enum_.file());
}
std::string RustModule(Context& ctx, const OneofDescriptor& oneof) {
- return RustModuleForContainingType(ctx, oneof.containing_type());
+ return RustModuleForContainingType(ctx, oneof.containing_type(),
+ *oneof.file());
}
-std::string RustInternalModuleName(Context& ctx, const FileDescriptor& file) {
+std::string RustInternalModuleName(const FileDescriptor& file) {
return RsSafeName(
absl::StrReplaceAll(StripProto(file.name()), {{"_", "__"}, {"/", "_s"}}));
}
-std::string GetCrateRelativeQualifiedPath(Context& ctx, const Descriptor& msg) {
- return absl::StrCat(RustModule(ctx, msg), RsSafeName(msg.name()));
-}
-
-std::string GetCrateRelativeQualifiedPath(Context& ctx,
- const EnumDescriptor& enum_) {
- return absl::StrCat(RustModule(ctx, enum_), EnumRsName(enum_));
-}
-
std::string FieldInfoComment(Context& ctx, const FieldDescriptor& field) {
absl::string_view label = field.is_repeated() ? "repeated" : "optional";
std::string comment = absl::StrCat(field.name(), ": ", label, " ",
diff --git a/src/google/protobuf/compiler/rust/naming.h b/src/google/protobuf/compiler/rust/naming.h
index 2707d85..95a2d14 100644
--- a/src/google/protobuf/compiler/rust/naming.h
+++ b/src/google/protobuf/compiler/rust/naming.h
@@ -38,10 +38,12 @@
std::string RawMapThunk(Context& ctx, const EnumDescriptor& desc,
absl::string_view key_t, absl::string_view op);
-// Returns an absolute path to the Proxied Rust type of the given field.
-// The absolute path is guaranteed to work in the crate that defines the field.
-// It may be crate-relative, or directly reference the owning crate of the type.
+// Returns a path to the Proxied Rust type of the given field. The path will be
+// relative if the type is in the same crate, or absolute if it is in a
+// different crate.
std::string RsTypePath(Context& ctx, const FieldDescriptor& field);
+std::string RsTypePath(Context& ctx, const Descriptor& message);
+std::string RsTypePath(Context& ctx, const EnumDescriptor& descriptor);
// Returns the 'simple spelling' of the Rust View type for the provided field.
// For example, `i32` for int32 fields and `SomeMsgView<'$lifetime$>` for
@@ -84,27 +86,18 @@
// verbatim unless it is a Rust keyword that isn't a legal symbol name.
std::string RsSafeName(absl::string_view name);
-// Constructs a string of the Rust modules which will contain the message.
+// Constructs a string of the Rust modules which will contain the entity.
//
// Example: Given a message 'NestedMessage' which is defined in package 'x.y'
// which is inside 'ParentMessage', the message will be placed in the
-// x::y::ParentMessage_ Rust module, so this function will return the string
-// "x::y::ParentMessage_::".
-//
-// If the message has no package and no containing messages then this returns
-// empty string.
-std::string RustModuleForContainingType(Context& ctx,
- const Descriptor* containing_type);
+// x::y::parent_message Rust module, so this function will return
+// "x::y::parent_message::", with the necessary prefix to make it relative to
+// the current scope, or absolute if the entity is in a different crate.
std::string RustModule(Context& ctx, const Descriptor& msg);
std::string RustModule(Context& ctx, const EnumDescriptor& enum_);
std::string RustModule(Context& ctx, const OneofDescriptor& oneof);
-std::string RustInternalModuleName(Context& ctx, const FileDescriptor& file);
-std::string GetCrateRelativeQualifiedPath(Context& ctx, const Descriptor& msg);
-std::string GetCrateRelativeQualifiedPath(Context& ctx,
- const EnumDescriptor& enum_);
-std::string GetCrateRelativeQualifiedPath(Context& ctx,
- const OneofDescriptor& oneof);
+std::string RustInternalModuleName(const FileDescriptor& file);
template <typename Desc>
std::string GetUnderscoreDelimitedFullName(Context& ctx, const Desc& desc);
diff --git a/src/google/protobuf/compiler/rust/naming_test.cc b/src/google/protobuf/compiler/rust/naming_test.cc
index b1acd7b..f256411 100644
--- a/src/google/protobuf/compiler/rust/naming_test.cc
+++ b/src/google/protobuf/compiler/rust/naming_test.cc
@@ -5,19 +5,12 @@
#include <gtest/gtest.h>
#include "absl/container/flat_hash_map.h"
-#include "google/protobuf/compiler/rust/context.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
using google::protobuf::compiler::rust::CamelToSnakeCase;
-using google::protobuf::compiler::rust::Context;
-using google::protobuf::compiler::rust::Kernel;
-using google::protobuf::compiler::rust::Options;
-using google::protobuf::compiler::rust::RustGeneratorContext;
using google::protobuf::compiler::rust::RustInternalModuleName;
using google::protobuf::compiler::rust::ScreamingSnakeToUpperCamelCase;
-using google::protobuf::io::Printer;
-using google::protobuf::io::StringOutputStream;
namespace {
TEST(RustProtoNaming, RustInternalModuleName) {
@@ -25,17 +18,7 @@
foo_file.set_name("strong_bad/lol.proto");
google::protobuf::DescriptorPool pool;
const google::protobuf::FileDescriptor* fd = pool.BuildFile(foo_file);
-
- const Options opts = {Kernel::kUpb};
- std::vector<const google::protobuf::FileDescriptor*> files{fd};
- absl::flat_hash_map<std::string, std::string> mapping;
- const RustGeneratorContext rust_generator_context(&files, &mapping);
- std::string output;
- StringOutputStream stream{&output};
- Printer printer(&stream);
- Context c = Context(&opts, &rust_generator_context, &printer);
-
- EXPECT_EQ(RustInternalModuleName(c, *fd), "strong__bad_slol");
+ EXPECT_EQ(RustInternalModuleName(*fd), "strong__bad_slol");
}
TEST(RustProtoNaming, CamelToSnakeCase) {
diff --git a/src/google/protobuf/compiler/rust/oneof.cc b/src/google/protobuf/compiler/rust/oneof.cc
index eb2afe9..5702c57 100644
--- a/src/google/protobuf/compiler/rust/oneof.cc
+++ b/src/google/protobuf/compiler/rust/oneof.cc
@@ -232,9 +232,7 @@
{{"oneof_name", RsSafeName(oneof.name())},
{"view_lifetime", ViewLifetime(accessor_case)},
{"self", ViewReceiver(accessor_case)},
- {"oneof_enum_module",
- absl::StrCat("crate::", RustModuleForContainingType(
- ctx, oneof.containing_type()))},
+ {"oneof_enum_module", RustModule(ctx, oneof)},
{"view_enum_name", OneofViewEnumRsName(oneof)},
{"case_enum_name", OneofCaseEnumRsName(oneof)},
{"view_cases",
@@ -301,9 +299,7 @@
ctx.Emit(
{
- {"oneof_enum_module",
- absl::StrCat("crate::", RustModuleForContainingType(
- ctx, oneof.containing_type()))},
+ {"oneof_enum_module", RustModule(ctx, oneof)},
{"case_enum_rs_name", OneofCaseEnumRsName(oneof)},
{"case_thunk", ThunkName(ctx, oneof, "case")},
},