| #include "google/protobuf/map.h" |
| |
| #include <algorithm> |
| #include <cstddef> |
| #include <cstdint> |
| #include <string> |
| #include <type_traits> |
| #include <utility> |
| |
| #include "absl/log/absl_log.h" |
| #include "google/protobuf/message.h" |
| #include "google/protobuf/message_lite.h" |
| #include "rust/cpp_kernel/strings.h" |
| |
| namespace google { |
| namespace protobuf { |
| namespace rust { |
| namespace { |
| |
| using MapValueTag = internal::UntypedMapBase::TypeKind; |
| |
| // LINT.IfChange(map_ffi) |
| struct MapValue { |
| MapValueTag tag; |
| union { |
| bool b; |
| uint32_t u32; |
| uint64_t u64; |
| float f32; |
| double f64; |
| std::string* s; |
| google::protobuf::MessageLite* message; |
| }; |
| }; |
| // LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp.rs:map_ffi) |
| |
| template <typename T> |
| struct FromViewType { |
| using type = T; |
| }; |
| |
| template <> |
| struct FromViewType<PtrAndLen> { |
| using type = std::string; |
| }; |
| |
| template <typename Key> |
| using KeyMap = internal::KeyMapBase< |
| internal::KeyForBase<typename FromViewType<Key>::type>>; |
| |
| void GetSizeAndAlignment(MapValue value, uint16_t* size, uint8_t* alignment) { |
| switch (value.tag) { |
| case MapValueTag::kBool: |
| *size = sizeof(bool); |
| *alignment = alignof(bool); |
| break; |
| case MapValueTag::kU32: |
| *size = sizeof(uint32_t); |
| *alignment = alignof(uint32_t); |
| break; |
| case MapValueTag::kU64: |
| *size = sizeof(uint64_t); |
| *alignment = alignof(uint64_t); |
| break; |
| case MapValueTag::kFloat: |
| *size = sizeof(float); |
| *alignment = alignof(float); |
| break; |
| case MapValueTag::kDouble: |
| *size = sizeof(double); |
| *alignment = alignof(double); |
| break; |
| case MapValueTag::kString: |
| *size = sizeof(std::string); |
| *alignment = alignof(std::string); |
| break; |
| case MapValueTag::kMessage: |
| internal::RustMapHelper::GetSizeAndAlignment(value.message, size, |
| alignment); |
| break; |
| default: |
| ABSL_DLOG(FATAL) << "Unexpected value of MapValue"; |
| } |
| } |
| |
| internal::MapNodeSizeInfoT GetSizeInfo(size_t key_size, MapValue value) { |
| // Each map node consists of a NodeBase followed by a std::pair<Key, Value>. |
| // We need to compute the offset of the value and the total size of the node. |
| size_t node_and_key_size = sizeof(internal::NodeBase) + key_size; |
| uint16_t value_size; |
| uint8_t value_alignment; |
| GetSizeAndAlignment(value, &value_size, &value_alignment); |
| // Round node_and_key_size up to the nearest multiple of value_alignment. |
| uint16_t offset = |
| (((node_and_key_size - 1) / value_alignment) + 1) * value_alignment; |
| |
| size_t overall_alignment = std::max(alignof(internal::NodeBase), |
| static_cast<size_t>(value_alignment)); |
| // Round up size to nearest multiple of overall_alignment. |
| size_t overall_size = |
| (((offset + value_size - 1) / overall_alignment) + 1) * overall_alignment; |
| |
| return internal::RustMapHelper::MakeSizeInfo(overall_size, offset); |
| } |
| |
| template <typename Key> |
| void DestroyMapNode(internal::UntypedMapBase* m, internal::NodeBase* node, |
| internal::MapNodeSizeInfoT size_info, |
| bool destroy_message) { |
| if constexpr (std::is_same<Key, PtrAndLen>::value) { |
| static_cast<std::string*>(node->GetVoidKey())->~basic_string(); |
| } |
| if (destroy_message) { |
| internal::RustMapHelper::DestroyMessage( |
| static_cast<MessageLite*>(node->GetVoidValue(size_info))); |
| } |
| internal::RustMapHelper::DeallocNode(m, node, size_info); |
| } |
| |
| void InitializeMessageValue(void* raw_ptr, MessageLite* msg) { |
| MessageLite* new_msg = internal::RustMapHelper::PlacementNew(msg, raw_ptr); |
| auto* full_msg = DynamicCastMessage<Message>(new_msg); |
| |
| // If we are working with a full (non-lite) proto, we reflectively swap the |
| // value into place. Otherwise, we have to perform a copy. |
| if (full_msg != nullptr) { |
| full_msg->GetReflection()->Swap(full_msg, DynamicCastMessage<Message>(msg)); |
| } else { |
| new_msg->CheckTypeAndMergeFrom(*msg); |
| } |
| delete msg; |
| } |
| |
| template <typename Key> |
| bool Insert(internal::UntypedMapBase* m, Key key, MapValue value) { |
| internal::MapNodeSizeInfoT size_info = |
| GetSizeInfo(sizeof(typename FromViewType<Key>::type), value); |
| internal::NodeBase* node = internal::RustMapHelper::AllocNode(m, size_info); |
| if constexpr (std::is_same<Key, PtrAndLen>::value) { |
| new (node->GetVoidKey()) std::string(key.ptr, key.len); |
| } else { |
| *static_cast<Key*>(node->GetVoidKey()) = key; |
| } |
| |
| void* value_ptr = node->GetVoidValue(size_info); |
| switch (value.tag) { |
| case MapValueTag::kBool: |
| *static_cast<bool*>(value_ptr) = value.b; |
| break; |
| case MapValueTag::kU32: |
| *static_cast<uint32_t*>(value_ptr) = value.u32; |
| break; |
| case MapValueTag::kU64: |
| *static_cast<uint64_t*>(value_ptr) = value.u64; |
| break; |
| case MapValueTag::kFloat: |
| *static_cast<float*>(value_ptr) = value.f32; |
| break; |
| case MapValueTag::kDouble: |
| *static_cast<double*>(value_ptr) = value.f64; |
| break; |
| case MapValueTag::kString: |
| new (value_ptr) std::string(std::move(*value.s)); |
| delete value.s; |
| break; |
| case MapValueTag::kMessage: |
| InitializeMessageValue(value_ptr, value.message); |
| break; |
| default: |
| ABSL_DLOG(FATAL) << "Unexpected value of MapValue"; |
| } |
| |
| node = internal::RustMapHelper::InsertOrReplaceNode( |
| static_cast<KeyMap<Key>*>(m), node); |
| if (node == nullptr) { |
| return true; |
| } |
| DestroyMapNode<Key>(m, node, size_info, value.tag == MapValueTag::kMessage); |
| return false; |
| } |
| |
| template <typename Map, typename Key, |
| typename = typename std::enable_if< |
| !std::is_same<Key, google::protobuf::rust::PtrAndLen>::value>::type> |
| internal::RustMapHelper::NodeAndBucket FindHelper(Map* m, Key key) { |
| return internal::RustMapHelper::FindHelper( |
| m, static_cast<internal::KeyForBase<Key>>(key)); |
| } |
| |
| template <typename Map> |
| internal::RustMapHelper::NodeAndBucket FindHelper(Map* m, |
| google::protobuf::rust::PtrAndLen key) { |
| return internal::RustMapHelper::FindHelper( |
| m, absl::string_view(key.ptr, key.len)); |
| } |
| |
| void PopulateMapValue(MapValueTag tag, void* data, MapValue& output) { |
| output.tag = tag; |
| switch (tag) { |
| case MapValueTag::kBool: |
| output.b = *static_cast<const bool*>(data); |
| break; |
| case MapValueTag::kU32: |
| output.u32 = *static_cast<const uint32_t*>(data); |
| break; |
| case MapValueTag::kU64: |
| output.u64 = *static_cast<const uint64_t*>(data); |
| break; |
| case MapValueTag::kFloat: |
| output.f32 = *static_cast<const float*>(data); |
| break; |
| case MapValueTag::kDouble: |
| output.f64 = *static_cast<const double*>(data); |
| break; |
| case MapValueTag::kString: |
| output.s = static_cast<std::string*>(data); |
| break; |
| case MapValueTag::kMessage: |
| output.message = static_cast<MessageLite*>(data); |
| break; |
| default: |
| ABSL_DLOG(FATAL) << "Unexpected MapValueTag"; |
| } |
| } |
| |
| template <typename Key> |
| bool Get(internal::UntypedMapBase* m, MapValue prototype, Key key, |
| MapValue* value) { |
| internal::MapNodeSizeInfoT size_info = |
| GetSizeInfo(sizeof(typename FromViewType<Key>::type), prototype); |
| auto* map_base = static_cast<KeyMap<Key>*>(m); |
| internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key); |
| if (result.node == nullptr) { |
| return false; |
| } |
| PopulateMapValue(prototype.tag, result.node->GetVoidValue(size_info), *value); |
| return true; |
| } |
| |
| template <typename Key> |
| bool Remove(internal::UntypedMapBase* m, MapValue prototype, Key key) { |
| internal::MapNodeSizeInfoT size_info = |
| GetSizeInfo(sizeof(typename FromViewType<Key>::type), prototype); |
| auto* map_base = static_cast<KeyMap<Key>*>(m); |
| internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key); |
| if (result.node == nullptr) { |
| return false; |
| } |
| internal::RustMapHelper::EraseNoDestroy(map_base, result.bucket, result.node); |
| DestroyMapNode<Key>(m, result.node, size_info, |
| prototype.tag == MapValueTag::kMessage); |
| return true; |
| } |
| |
| template <typename Key> |
| void IterGet(const internal::UntypedMapIterator* iter, MapValue prototype, |
| Key* key, MapValue* value) { |
| internal::MapNodeSizeInfoT size_info = |
| GetSizeInfo(sizeof(typename FromViewType<Key>::type), prototype); |
| internal::NodeBase* node = iter->node_; |
| if constexpr (std::is_same<Key, PtrAndLen>::value) { |
| const std::string* s = static_cast<const std::string*>(node->GetVoidKey()); |
| *key = PtrAndLen{s->data(), s->size()}; |
| } else { |
| *key = *static_cast<const Key*>(node->GetVoidKey()); |
| } |
| PopulateMapValue(prototype.tag, node->GetVoidValue(size_info), *value); |
| } |
| |
| // Returns the size of the key in the map entry, given the key used for FFI. |
| // The map entry key and FFI key are always the same, except in the case of |
| // string and bytes. |
| template <typename Key> |
| size_t KeySize() { |
| if constexpr (std::is_same<Key, google::protobuf::rust::PtrAndLen>::value) { |
| return sizeof(std::string); |
| } else { |
| return sizeof(Key); |
| } |
| } |
| |
| } // namespace |
| } // namespace rust |
| } // namespace protobuf |
| } // namespace google |
| |
| extern "C" { |
| |
| void proto2_rust_thunk_UntypedMapIterator_increment( |
| google::protobuf::internal::UntypedMapIterator* iter) { |
| iter->PlusPlus(); |
| } |
| |
| google::protobuf::internal::UntypedMapBase* proto2_rust_map_new( |
| google::protobuf::rust::MapValue key_prototype, |
| google::protobuf::rust::MapValue value_prototype) { |
| return new google::protobuf::internal::UntypedMapBase( |
| /* arena = */ nullptr, |
| google::protobuf::internal::UntypedMapBase::GetTypeInfoDynamic( |
| key_prototype.tag, value_prototype.tag, |
| value_prototype.tag == google::protobuf::rust::MapValueTag::kMessage |
| ? value_prototype.message |
| : nullptr)); |
| } |
| |
| size_t proto2_rust_map_size(google::protobuf::internal::UntypedMapBase* m) { |
| return m->size(); |
| } |
| |
| google::protobuf::internal::UntypedMapIterator proto2_rust_map_iter( |
| google::protobuf::internal::UntypedMapBase* m) { |
| return m->begin(); |
| } |
| |
| void proto2_rust_map_free(google::protobuf::internal::UntypedMapBase* m) { |
| m->ClearTable(false, nullptr); |
| delete m; |
| } |
| |
| void proto2_rust_map_clear(google::protobuf::internal::UntypedMapBase* m) { |
| m->ClearTable(true, nullptr); |
| } |
| |
| #define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \ |
| bool proto2_rust_map_insert_##suffix(google::protobuf::internal::UntypedMapBase* m, \ |
| cpp_type key, \ |
| google::protobuf::rust::MapValue value) { \ |
| return google::protobuf::rust::Insert(m, key, value); \ |
| } \ |
| \ |
| bool proto2_rust_map_get_##suffix( \ |
| google::protobuf::internal::UntypedMapBase* m, google::protobuf::rust::MapValue prototype, \ |
| cpp_type key, google::protobuf::rust::MapValue* value) { \ |
| return google::protobuf::rust::Get(m, prototype, key, value); \ |
| } \ |
| \ |
| bool proto2_rust_map_remove_##suffix(google::protobuf::internal::UntypedMapBase* m, \ |
| google::protobuf::rust::MapValue prototype, \ |
| cpp_type key) { \ |
| return google::protobuf::rust::Remove(m, prototype, key); \ |
| } \ |
| \ |
| void proto2_rust_map_iter_get_##suffix( \ |
| const google::protobuf::internal::UntypedMapIterator* iter, \ |
| google::protobuf::rust::MapValue prototype, cpp_type* key, \ |
| google::protobuf::rust::MapValue* value) { \ |
| return google::protobuf::rust::IterGet(iter, prototype, key, value); \ |
| } |
| |
| DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(int32_t, i32) |
| DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(uint32_t, u32) |
| DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(int64_t, i64) |
| DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(uint64_t, u64) |
| DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(bool, bool) |
| DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(google::protobuf::rust::PtrAndLen, ProtoString) |
| |
| } // extern "C" |