Move casting functions to MessageLite and use ClassData as the uniqueness
instead of Reflection. This allows using these functions instead of
`dynamic_cast` for all generated types including LITE.
PiperOrigin-RevId: 632135009
diff --git a/src/google/protobuf/lite_unittest.cc b/src/google/protobuf/lite_unittest.cc
index 9783075..f9f0135 100644
--- a/src/google/protobuf/lite_unittest.cc
+++ b/src/google/protobuf/lite_unittest.cc
@@ -27,6 +27,7 @@
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/map_lite_test_util.h"
#include "google/protobuf/map_lite_unittest.pb.h"
+#include "google/protobuf/message_lite.h"
#include "google/protobuf/parse_context.h"
#include "google/protobuf/test_util_lite.h"
#include "google/protobuf/unittest_lite.pb.h"
@@ -1325,6 +1326,92 @@
}
}
+// Two arbitary types
+using CastType1 = protobuf_unittest::TestAllTypesLite;
+using CastType2 = protobuf_unittest::TestPackedTypesLite;
+
+TEST(LiteTest, DynamicCastToGenerated) {
+ CastType1 test_type_1;
+
+ MessageLite* test_type_1_pointer = &test_type_1;
+ EXPECT_EQ(&test_type_1,
+ DynamicCastToGenerated<CastType1>(test_type_1_pointer));
+ EXPECT_EQ(nullptr, DynamicCastToGenerated<CastType2>(test_type_1_pointer));
+
+ const MessageLite* test_type_1_pointer_const = &test_type_1;
+ EXPECT_EQ(&test_type_1,
+ DynamicCastToGenerated<const CastType1>(test_type_1_pointer_const));
+ EXPECT_EQ(nullptr,
+ DynamicCastToGenerated<const CastType2>(test_type_1_pointer_const));
+
+ MessageLite* test_type_1_pointer_nullptr = nullptr;
+ EXPECT_EQ(nullptr,
+ DynamicCastToGenerated<CastType1>(test_type_1_pointer_nullptr));
+
+ MessageLite& test_type_1_pointer_ref = test_type_1;
+ EXPECT_EQ(&test_type_1,
+ &DynamicCastToGenerated<CastType1>(test_type_1_pointer_ref));
+
+ const MessageLite& test_type_1_pointer_const_ref = test_type_1;
+ EXPECT_EQ(&test_type_1,
+ &DynamicCastToGenerated<CastType1>(test_type_1_pointer_const_ref));
+}
+
+#if GTEST_HAS_DEATH_TEST
+TEST(LiteTest, DynamicCastToGeneratedInvalidReferenceType) {
+ CastType1 test_type_1;
+ const MessageLite& test_type_1_pointer_const_ref = test_type_1;
+ ASSERT_DEATH(DynamicCastToGenerated<CastType2>(test_type_1_pointer_const_ref),
+ "Cannot downcast " + test_type_1.GetTypeName() + " to " +
+ CastType2::default_instance().GetTypeName());
+}
+#endif // GTEST_HAS_DEATH_TEST
+
+TEST(LiteTest, DownCastToGeneratedValidType) {
+ CastType1 test_type_1;
+
+ MessageLite* test_type_1_pointer = &test_type_1;
+ EXPECT_EQ(&test_type_1, DownCastToGenerated<CastType1>(test_type_1_pointer));
+
+ const MessageLite* test_type_1_pointer_const = &test_type_1;
+ EXPECT_EQ(&test_type_1,
+ DownCastToGenerated<const CastType1>(test_type_1_pointer_const));
+
+ MessageLite* test_type_1_pointer_nullptr = nullptr;
+ EXPECT_EQ(nullptr,
+ DownCastToGenerated<CastType1>(test_type_1_pointer_nullptr));
+
+ MessageLite& test_type_1_pointer_ref = test_type_1;
+ EXPECT_EQ(&test_type_1,
+ &DownCastToGenerated<CastType1>(test_type_1_pointer_ref));
+
+ const MessageLite& test_type_1_pointer_const_ref = test_type_1;
+ EXPECT_EQ(&test_type_1,
+ &DownCastToGenerated<CastType1>(test_type_1_pointer_const_ref));
+}
+
+#if GTEST_HAS_DEATH_TEST
+TEST(LiteTest, DownCastToGeneratedInvalidPointerType) {
+ CastType1 test_type_1;
+
+ MessageLite* test_type_1_pointer = &test_type_1;
+
+ ASSERT_DEBUG_DEATH(DownCastToGenerated<CastType2>(test_type_1_pointer),
+ "Cannot downcast " + test_type_1.GetTypeName() + " to " +
+ CastType2::default_instance().GetTypeName());
+}
+
+TEST(LiteTest, DownCastToGeneratedInvalidReferenceType) {
+ CastType1 test_type_1;
+
+ MessageLite& test_type_1_pointer = test_type_1;
+
+ ASSERT_DEBUG_DEATH(DownCastToGenerated<CastType2>(test_type_1_pointer),
+ "Cannot downcast " + test_type_1.GetTypeName() + " to " +
+ CastType2::default_instance().GetTypeName());
+}
+#endif // GTEST_HAS_DEATH_TEST
+
} // namespace
} // namespace protobuf
} // namespace google
diff --git a/src/google/protobuf/message.h b/src/google/protobuf/message.h
index 0354c9e..88666e6 100644
--- a/src/google/protobuf/message.h
+++ b/src/google/protobuf/message.h
@@ -1424,90 +1424,6 @@
#undef DECLARE_GET_REPEATED_FIELD
-// Tries to downcast this message to a generated message type. Returns nullptr
-// if this class is not an instance of T. This works even if RTTI is disabled.
-//
-// This also has the effect of creating a strong reference to T that will
-// prevent the linker from stripping it out at link time. This can be important
-// if you are using a DynamicMessageFactory that delegates to the generated
-// factory.
-template <typename T>
-const T* DynamicCastToGenerated(const Message* from) {
- // Compile-time assert that T is a generated type that has a
- // default_instance() accessor, but avoid actually calling it.
- const T& (*get_default_instance)() = &T::default_instance;
- (void)get_default_instance;
-
- // Compile-time assert that T is a subclass of google::protobuf::Message.
- const Message* unused = static_cast<T*>(nullptr);
- (void)unused;
-
-#if PROTOBUF_RTTI
- internal::StrongReferenceToType<T>();
- return dynamic_cast<const T*>(from);
-#else
- bool ok = from != nullptr &&
- T::default_instance().GetReflection() == from->GetReflection();
- return ok ? internal::DownCast<const T*>(from) : nullptr;
-#endif
-}
-
-template <typename T>
-T* DynamicCastToGenerated(Message* from) {
- const Message* message_const = from;
- return const_cast<T*>(DynamicCastToGenerated<T>(message_const));
-}
-
-// An overloaded version of DynamicCastToGenerated for downcasting references to
-// base Message class. If the destination type T if the argument is not an
-// instance of T and dynamic_cast returns nullptr, it terminates with an error.
-template <typename T>
-const T& DynamicCastToGenerated(const Message& from) {
- const T* destination_message = DynamicCastToGenerated<T>(&from);
- ABSL_CHECK(destination_message != nullptr)
- << "Cannot downcast " << from.GetTypeName() << " to "
- << T::default_instance().GetTypeName();
- return *destination_message;
-}
-
-template <typename T>
-T& DynamicCastToGenerated(Message& from) {
- const Message& message_const = from;
- const T& destination_message = DynamicCastToGenerated<T>(message_const);
- return const_cast<T&>(destination_message);
-}
-
-// A lightweight function for downcasting base Message pointer to derived type.
-// It should only be used when the caller is certain that the argument is of
-// instance T and T is a type derived from base Message class.
-template <typename T>
-const T* DownCastToGenerated(const Message* from) {
- internal::StrongReferenceToType<T>();
- ABSL_DCHECK(DynamicCastToGenerated<T>(from) == from)
- << "Cannot downcast " << from->GetTypeName() << " to "
- << T::default_instance().GetTypeName();
-
- return static_cast<const T*>(from);
-}
-
-template <typename T>
-T* DownCastToGenerated(Message* from) {
- const Message* message_const = from;
- return const_cast<T*>(DownCastToGenerated<T>(message_const));
-}
-
-template <typename T>
-const T& DownCastToGenerated(const Message& from) {
- return *DownCastToGenerated<T>(&from);
-}
-
-template <typename T>
-T& DownCastToGenerated(Message& from) {
- const Message& message_const = from;
- const T& destination_message = DownCastToGenerated<T>(message_const);
- return const_cast<T&>(destination_message);
-}
-
// Call this function to ensure that this message's reflection is linked into
// the binary:
//
diff --git a/src/google/protobuf/message_lite.h b/src/google/protobuf/message_lite.h
index 1b85aa0..abb869e 100644
--- a/src/google/protobuf/message_lite.h
+++ b/src/google/protobuf/message_lite.h
@@ -21,6 +21,7 @@
#include <cstdint>
#include <iosfwd>
#include <string>
+#include <type_traits>
#include "absl/base/attributes.h"
#include "absl/log/absl_check.h"
@@ -53,6 +54,7 @@
class Reflection;
class Descriptor;
class AssignDescriptorsHelper;
+class MessageLite;
namespace io {
@@ -120,6 +122,9 @@
#endif
};
+// For MessageLite to friend.
+class TypeId;
+
class SwapFieldHelper;
// See parse_context.h for explanation
@@ -638,6 +643,15 @@
// return a default table instead of a unique one.
virtual const ClassData* GetClassData() const = 0;
+ template <typename T>
+ static auto GetClassDataGenerated() {
+ // We could speed this up if needed by avoiding the function call.
+ // In LTO this is likely inlined, so it might not matter.
+ static_assert(
+ std::is_same<const T&, decltype(T::default_instance())>::value, "");
+ return T::default_instance().T::GetClassData();
+ }
+
internal::InternalMetadata _internal_metadata_;
// Return the cached size object as described by
@@ -682,6 +696,7 @@
friend class internal::LazyField;
friend class internal::SwapFieldHelper;
friend class internal::TcParser;
+ friend class internal::TypeId;
friend class internal::WeakFieldMap;
friend class internal::WireFormatLite;
@@ -717,6 +732,32 @@
namespace internal {
+// A typeinfo equivalent for protobuf message types. Used for
+// DynamicCastToGenerated.
+// We might make this class public later on to have an alternative to
+// `std::type_info` that works when RTTI is disabled.
+class TypeId {
+ public:
+ constexpr explicit TypeId(const MessageLite::ClassData* data) : data_(data) {}
+
+ friend constexpr bool operator==(TypeId a, TypeId b) {
+ return a.data_ == b.data_;
+ }
+ friend constexpr bool operator!=(TypeId a, TypeId b) { return !(a == b); }
+
+ static TypeId Get(const MessageLite& msg) {
+ return TypeId(msg.GetClassData());
+ }
+
+ template <typename T>
+ static TypeId Get() {
+ return TypeId(MessageLite::GetClassDataGenerated<T>());
+ }
+
+ private:
+ const MessageLite::ClassData* data_;
+};
+
template <bool alias>
bool MergeFromImpl(absl::string_view input, MessageLite* msg,
const internal::TcParseTableBase* tc_table,
@@ -820,6 +861,83 @@
std::string ShortFormat(const MessageLite& message_lite);
std::string Utf8Format(const MessageLite& message_lite);
+// Tries to downcast this message to a generated message type. Returns nullptr
+// if this class is not an instance of T. This works even if RTTI is disabled.
+//
+// This also has the effect of creating a strong reference to T that will
+// prevent the linker from stripping it out at link time. This can be important
+// if you are using a DynamicMessageFactory that delegates to the generated
+// factory.
+template <typename T>
+const T* DynamicCastToGenerated(const MessageLite* from) {
+ static_assert(std::is_base_of<MessageLite, T>::value, "");
+
+ internal::StrongReferenceToType<T>();
+ // We might avoid the call to T::GetClassData() altogether if T were to
+ // expose the class data pointer.
+ if (from == nullptr ||
+ internal::TypeId::Get<T>() != internal::TypeId::Get(*from)) {
+ return nullptr;
+ }
+
+ return static_cast<const T*>(from);
+}
+
+template <typename T>
+const T* DynamicCastToGenerated(const MessageLite* from);
+
+template <typename T>
+T* DynamicCastToGenerated(MessageLite* from) {
+ return const_cast<T*>(
+ DynamicCastToGenerated<T>(static_cast<const MessageLite*>(from)));
+}
+
+// An overloaded version of DynamicCastToGenerated for downcasting references to
+// base Message class. If the argument is not an instance of T, it terminates
+// with an error.
+template <typename T>
+const T& DynamicCastToGenerated(const MessageLite& from) {
+ const T* destination_message = DynamicCastToGenerated<T>(&from);
+ ABSL_CHECK(destination_message != nullptr)
+ << "Cannot downcast " << from.GetTypeName() << " to "
+ << T::default_instance().GetTypeName();
+ return *destination_message;
+}
+
+template <typename T>
+T& DynamicCastToGenerated(MessageLite& from) {
+ return const_cast<T&>(
+ DynamicCastToGenerated<T>(static_cast<const MessageLite&>(from)));
+}
+
+// A lightweight function for downcasting base MessageLite pointer to derived
+// type. It should only be used when the caller is certain that the argument is
+// of instance T and T is a generated message type.
+template <typename T>
+const T* DownCastToGenerated(const MessageLite* from) {
+ internal::StrongReferenceToType<T>();
+ ABSL_DCHECK(DynamicCastToGenerated<T>(from) == from)
+ << "Cannot downcast " << from->GetTypeName() << " to "
+ << T::default_instance().GetTypeName();
+ return static_cast<const T*>(from);
+}
+
+template <typename T>
+T* DownCastToGenerated(MessageLite* from) {
+ return const_cast<T*>(
+ DownCastToGenerated<T>(static_cast<const MessageLite*>(from)));
+}
+
+template <typename T>
+const T& DownCastToGenerated(const MessageLite& from) {
+ return *DownCastToGenerated<T>(&from);
+}
+
+template <typename T>
+T& DownCastToGenerated(MessageLite& from) {
+ return *DownCastToGenerated<T>(&from);
+}
+
} // namespace protobuf
} // namespace google
diff --git a/src/google/protobuf/message_unittest.inc b/src/google/protobuf/message_unittest.inc
index 08b00a4..2f647fb 100644
--- a/src/google/protobuf/message_unittest.inc
+++ b/src/google/protobuf/message_unittest.inc
@@ -753,35 +753,35 @@
TEST(MESSAGE_TEST_NAME, DynamicCastToGenerated) {
UNITTEST::TestAllTypes test_all_types;
- Message* test_all_types_pointer = &test_all_types;
+ MessageLite* test_all_types_pointer = &test_all_types;
EXPECT_EQ(&test_all_types, DynamicCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer));
EXPECT_EQ(nullptr, DynamicCastToGenerated<UNITTEST::TestRequired>(
test_all_types_pointer));
- const Message* test_all_types_pointer_const = &test_all_types;
+ const MessageLite* test_all_types_pointer_const = &test_all_types;
EXPECT_EQ(&test_all_types,
DynamicCastToGenerated<const UNITTEST::TestAllTypes>(
test_all_types_pointer_const));
EXPECT_EQ(nullptr, DynamicCastToGenerated<const UNITTEST::TestRequired>(
test_all_types_pointer_const));
- Message* test_all_types_pointer_nullptr = nullptr;
+ MessageLite* test_all_types_pointer_nullptr = nullptr;
EXPECT_EQ(nullptr, DynamicCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer_nullptr));
- Message& test_all_types_pointer_ref = test_all_types;
+ MessageLite& test_all_types_pointer_ref = test_all_types;
EXPECT_EQ(&test_all_types, &DynamicCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer_ref));
- const Message& test_all_types_pointer_const_ref = test_all_types;
+ const MessageLite& test_all_types_pointer_const_ref = test_all_types;
EXPECT_EQ(&test_all_types, &DynamicCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer_const_ref));
}
TEST(MESSAGE_TEST_NAME, DynamicCastToGeneratedInvalidReferenceType) {
UNITTEST::TestAllTypes test_all_types;
- const Message& test_all_types_pointer_const_ref = test_all_types;
+ const MessageLite& test_all_types_pointer_const_ref = test_all_types;
ASSERT_DEATH(DynamicCastToGenerated<UNITTEST::TestRequired>(
test_all_types_pointer_const_ref),
"Cannot downcast " + test_all_types.GetTypeName() + " to " +
@@ -791,23 +791,23 @@
TEST(MESSAGE_TEST_NAME, DownCastToGeneratedValidType) {
UNITTEST::TestAllTypes test_all_types;
- Message* test_all_types_pointer = &test_all_types;
+ MessageLite* test_all_types_pointer = &test_all_types;
EXPECT_EQ(&test_all_types, DownCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer));
- const Message* test_all_types_pointer_const = &test_all_types;
+ const MessageLite* test_all_types_pointer_const = &test_all_types;
EXPECT_EQ(&test_all_types, DownCastToGenerated<const UNITTEST::TestAllTypes>(
test_all_types_pointer_const));
- Message* test_all_types_pointer_nullptr = nullptr;
+ MessageLite* test_all_types_pointer_nullptr = nullptr;
EXPECT_EQ(nullptr, DownCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer_nullptr));
- Message& test_all_types_pointer_ref = test_all_types;
+ MessageLite& test_all_types_pointer_ref = test_all_types;
EXPECT_EQ(&test_all_types, &DownCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer_ref));
- const Message& test_all_types_pointer_const_ref = test_all_types;
+ const MessageLite& test_all_types_pointer_const_ref = test_all_types;
EXPECT_EQ(&test_all_types, &DownCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer_const_ref));
}
@@ -815,7 +815,7 @@
TEST(MESSAGE_TEST_NAME, DownCastToGeneratedInvalidPointerType) {
UNITTEST::TestAllTypes test_all_types;
- Message* test_all_types_pointer = &test_all_types;
+ MessageLite* test_all_types_pointer = &test_all_types;
ASSERT_DEBUG_DEATH(
DownCastToGenerated<UNITTEST::TestRequired>(test_all_types_pointer),
@@ -826,7 +826,7 @@
TEST(MESSAGE_TEST_NAME, DownCastToGeneratedInvalidReferenceType) {
UNITTEST::TestAllTypes test_all_types;
- Message& test_all_types_pointer = test_all_types;
+ MessageLite& test_all_types_pointer = test_all_types;
ASSERT_DEBUG_DEATH(
DownCastToGenerated<UNITTEST::TestRequired>(test_all_types_pointer),