Implement Maps for scalar types for the C++ kernel
This CL implements Maps for scalar types for the C++ runtime. It's orthogonal to cl/580453646. This CL is constrained by having to force template instantiation of proto2::Map<K, V>. Put differently, a Rust protobuf::Map<K, V> implementation needs to call 'extern "C"' functions with both key and value type in the function name (e.g. __pb_rust_Map_i32_f64_get()). We use macros to generate a Map implementation for every (K,V)-pair. An alternative would have been to use vtables.
Luckily a key in a protobuf map can only be integer types, bool and string. So the number of key types is bounded by the specification, while the number of value types is not i.e. any protobuf message can be a value in a map. Given these constraints we introduce one 'MapKeyOps' trait per key type e.g. MapKeyBOOLOps or MapKeyI32Ops. These traits need to be implemented for every value type e.g. 'impl MapKeyBOOLOps for i32' will implement 'Map::<bool, i32>'. In particular the MapKeyOps traits can also be implemented for generated messages without violating the orphan rule.
This CL also contains significant changes to the UPB runtime so that both upb.rs and cpp.rs export a similar interface to simplify the implementation in map.rs and the generated code.
This CL does not yet implement the Proxied trait.
PiperOrigin-RevId: 582951914
diff --git a/rust/BUILD b/rust/BUILD
index 2bd06ae..bc6a5cc 100644
--- a/rust/BUILD
+++ b/rust/BUILD
@@ -56,6 +56,7 @@
"shared.rs",
"string.rs",
"vtable.rs",
+ "map.rs",
]
# The Rust Protobuf runtime using the upb kernel.
@@ -65,11 +66,11 @@
# setting.
rust_library(
name = "protobuf_upb",
- srcs = PROTOBUF_SHARED + [
- "map.rs",
- "upb.rs",
- ],
+ srcs = PROTOBUF_SHARED + ["upb.rs"],
crate_root = "shared.rs",
+ proc_macro_deps = [
+ "@crate_index//:paste",
+ ],
rustc_flags = ["--cfg=upb_kernel"],
visibility = [
"//src/google/protobuf:__subpackages__",
diff --git a/rust/cpp.rs b/rust/cpp.rs
index 4346d41..11fd634 100644
--- a/rust/cpp.rs
+++ b/rust/cpp.rs
@@ -7,7 +7,7 @@
// Rust Protobuf runtime using the C++ kernel.
-use crate::__internal::{Private, RawArena, RawMessage, RawRepeatedField};
+use crate::__internal::{Private, RawArena, RawMap, RawMessage, RawRepeatedField};
use paste::paste;
use std::alloc::Layout;
use std::cell::UnsafeCell;
@@ -319,6 +319,143 @@
}
}
+#[derive(Debug)]
+pub struct Map<'msg, K: ?Sized, V: ?Sized> {
+ inner: MapInner<'msg>,
+ _phantom_key: PhantomData<&'msg mut K>,
+ _phantom_value: PhantomData<&'msg mut V>,
+}
+
+#[derive(Clone, Copy, Debug)]
+pub struct MapInner<'msg> {
+ pub raw: RawMap,
+ pub _phantom: PhantomData<&'msg ()>,
+}
+
+// These use manual impls instead of derives to avoid unnecessary bounds on `K`
+// and `V`. This problem is referred to as "perfect derive".
+// https://smallcultfollowing.com/babysteps/blog/2022/04/12/implied-bounds-and-perfect-derive/
+impl<'msg, K: ?Sized, V: ?Sized> Copy for Map<'msg, K, V> {}
+impl<'msg, K: ?Sized, V: ?Sized> Clone for Map<'msg, K, V> {
+ fn clone(&self) -> Map<'msg, K, V> {
+ *self
+ }
+}
+
+impl<'msg, K: ?Sized, V: ?Sized> Map<'msg, K, V> {
+ pub fn from_inner(_private: Private, inner: MapInner<'msg>) -> Self {
+ Map { inner, _phantom_key: PhantomData, _phantom_value: PhantomData }
+ }
+}
+
+macro_rules! impl_scalar_map_values {
+ ($kt:ty, $trait:ident for $($t:ty),*) => {
+ paste! { $(
+ extern "C" {
+ fn [< __pb_rust_Map_ $kt _ $t _new >]() -> RawMap;
+ fn [< __pb_rust_Map_ $kt _ $t _clear >](m: RawMap);
+ fn [< __pb_rust_Map_ $kt _ $t _size >](m: RawMap) -> usize;
+ fn [< __pb_rust_Map_ $kt _ $t _insert >](m: RawMap, key: $kt, value: $t);
+ fn [< __pb_rust_Map_ $kt _ $t _get >](m: RawMap, key: $kt, value: *mut $t) -> bool;
+ fn [< __pb_rust_Map_ $kt _ $t _remove >](m: RawMap, key: $kt, value: *mut $t) -> bool;
+ }
+ impl $trait for $t {
+ fn new_map() -> RawMap {
+ unsafe { [< __pb_rust_Map_ $kt _ $t _new >]() }
+ }
+
+ fn clear(m: RawMap) {
+ unsafe { [< __pb_rust_Map_ $kt _ $t _clear >](m) }
+ }
+
+ fn size(m: RawMap) -> usize {
+ unsafe { [< __pb_rust_Map_ $kt _ $t _size >](m) }
+ }
+
+ fn insert(m: RawMap, key: $kt, value: $t) {
+ unsafe { [< __pb_rust_Map_ $kt _ $t _insert >](m, key, value) }
+ }
+
+ fn get(m: RawMap, key: $kt) -> Option<$t> {
+ let mut val: $t = Default::default();
+ let found = unsafe { [< __pb_rust_Map_ $kt _ $t _get >](m, key, &mut val) };
+ if !found {
+ return None;
+ }
+ Some(val)
+ }
+
+ fn remove(m: RawMap, key: $kt) -> Option<$t> {
+ let mut val: $t = Default::default();
+ let removed =
+ unsafe { [< __pb_rust_Map_ $kt _ $t _remove >](m, key, &mut val) };
+ if !removed {
+ return None;
+ }
+ Some(val)
+ }
+ }
+ )* }
+ }
+}
+
+macro_rules! impl_scalar_maps {
+ ($($t:ty),*) => {
+ paste! { $(
+ pub trait [< MapWith $t:camel KeyOps >] {
+ fn new_map() -> RawMap;
+ fn clear(m: RawMap);
+ fn size(m: RawMap) -> usize;
+ fn insert(m: RawMap, key: $t, value: Self);
+ fn get(m: RawMap, key: $t) -> Option<Self>
+ where
+ Self: Sized;
+ fn remove(m: RawMap, key: $t) -> Option<Self>
+ where
+ Self: Sized;
+ }
+
+ impl_scalar_map_values!(
+ $t, [< MapWith $t:camel KeyOps >] for i32, u32, f32, f64, bool, u64, i64
+ );
+
+ impl<'msg, V: [< MapWith $t:camel KeyOps >]> Map<'msg, $t, V> {
+ pub fn new() -> Self {
+ let inner = MapInner { raw: V::new_map(), _phantom: PhantomData };
+ Map {
+ inner,
+ _phantom_key: PhantomData,
+ _phantom_value: PhantomData
+ }
+ }
+
+ pub fn size(&self) -> usize {
+ V::size(self.inner.raw)
+ }
+
+ pub fn clear(&mut self) {
+ V::clear(self.inner.raw)
+ }
+
+ pub fn get(&self, key: $t) -> Option<V> {
+ V::get(self.inner.raw, key)
+ }
+
+ pub fn remove(&mut self, key: $t) -> Option<V> {
+ V::remove(self.inner.raw, key)
+ }
+
+ pub fn insert(&mut self, key: $t, value: V) -> bool {
+ V::insert(self.inner.raw, key, value);
+ true
+ }
+ }
+ )* }
+ }
+}
+
+impl_scalar_maps!(i32, u32, bool, u64, i64);
+
#[cfg(test)]
mod tests {
use super::*;
@@ -362,4 +499,44 @@
r.push(true);
assert_that!(r.get(0), eq(Some(true)));
}
+
+ #[test]
+ fn i32_i32_map() {
+ let mut map = Map::<'_, i32, i32>::new();
+ assert_that!(map.size(), eq(0));
+
+ assert_that!(map.insert(1, 2), eq(true));
+ assert_that!(map.get(1), eq(Some(2)));
+ assert_that!(map.get(3), eq(None));
+ assert_that!(map.size(), eq(1));
+
+ assert_that!(map.remove(1), eq(Some(2)));
+ assert_that!(map.size(), eq(0));
+ assert_that!(map.remove(1), eq(None));
+
+ assert_that!(map.insert(4, 5), eq(true));
+ assert_that!(map.insert(6, 7), eq(true));
+ map.clear();
+ assert_that!(map.size(), eq(0));
+ }
+
+ #[test]
+ fn i64_f64_map() {
+ let mut map = Map::<'_, i64, f64>::new();
+ assert_that!(map.size(), eq(0));
+
+ assert_that!(map.insert(1, 2.5), eq(true));
+ assert_that!(map.get(1), eq(Some(2.5)));
+ assert_that!(map.get(3), eq(None));
+ assert_that!(map.size(), eq(1));
+
+ assert_that!(map.remove(1), eq(Some(2.5)));
+ assert_that!(map.size(), eq(0));
+ assert_that!(map.remove(1), eq(None));
+
+ assert_that!(map.insert(4, 5.1), eq(true));
+ assert_that!(map.insert(6, 7.2), eq(true));
+ map.clear();
+ assert_that!(map.size(), eq(0));
+ }
}
diff --git a/rust/cpp_kernel/cpp_api.cc b/rust/cpp_kernel/cpp_api.cc
index 1381611..97f2c3c 100644
--- a/rust/cpp_kernel/cpp_api.cc
+++ b/rust/cpp_kernel/cpp_api.cc
@@ -1,3 +1,4 @@
+#include "google/protobuf/map.h"
#include "google/protobuf/repeated_field.h"
extern "C" {
@@ -36,4 +37,61 @@
expose_repeated_field_methods(int64_t, i64);
#undef expose_repeated_field_methods
+
+#define expose_scalar_map_methods(key_ty, rust_key_ty, value_ty, \
+ rust_value_ty) \
+ google::protobuf::Map<key_ty, value_ty>* \
+ __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_new() { \
+ return new google::protobuf::Map<key_ty, value_ty>(); \
+ } \
+ void __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_clear( \
+ google::protobuf::Map<key_ty, value_ty>* m) { \
+ m->clear(); \
+ } \
+ size_t __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_size( \
+ google::protobuf::Map<key_ty, value_ty>* m) { \
+ return m->size(); \
+ } \
+ void __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_insert( \
+ google::protobuf::Map<key_ty, value_ty>* m, key_ty key, value_ty val) { \
+ (*m)[key] = val; \
+ } \
+ bool __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_get( \
+ google::protobuf::Map<key_ty, value_ty>* m, key_ty key, value_ty* value) { \
+ auto it = m->find(key); \
+ if (it == m->end()) { \
+ return false; \
+ } \
+ *value = it->second; \
+ return true; \
+ } \
+ bool __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_remove( \
+ google::protobuf::Map<key_ty, value_ty>* m, key_ty key, value_ty* value) { \
+ auto it = m->find(key); \
+ if (it == m->end()) { \
+ return false; \
+ } else { \
+ *value = it->second; \
+ m->erase(it); \
+ return true; \
+ } \
+ }
+
+#define expose_scalar_map_methods_for_key_type(key_ty, rust_key_ty) \
+ expose_scalar_map_methods(key_ty, rust_key_ty, int32_t, i32); \
+ expose_scalar_map_methods(key_ty, rust_key_ty, uint32_t, u32); \
+ expose_scalar_map_methods(key_ty, rust_key_ty, float, f32); \
+ expose_scalar_map_methods(key_ty, rust_key_ty, double, f64); \
+ expose_scalar_map_methods(key_ty, rust_key_ty, bool, bool); \
+ expose_scalar_map_methods(key_ty, rust_key_ty, uint64_t, u64); \
+ expose_scalar_map_methods(key_ty, rust_key_ty, int64_t, i64);
+
+expose_scalar_map_methods_for_key_type(int32_t, i32);
+expose_scalar_map_methods_for_key_type(uint32_t, u32);
+expose_scalar_map_methods_for_key_type(bool, bool);
+expose_scalar_map_methods_for_key_type(uint64_t, u64);
+expose_scalar_map_methods_for_key_type(int64_t, i64);
+
+#undef expose_scalar_map_methods
+#undef expose_map_methods
}
diff --git a/rust/map.rs b/rust/map.rs
index 450b854..ab54698 100644
--- a/rust/map.rs
+++ b/rust/map.rs
@@ -7,8 +7,12 @@
use crate::{
__internal::Private,
- __runtime::{Map, MapInner, MapValueType},
+ __runtime::{
+ Map, MapInner, MapWithBoolKeyOps, MapWithI32KeyOps, MapWithI64KeyOps, MapWithU32KeyOps,
+ MapWithU64KeyOps,
+ },
};
+use paste::paste;
#[derive(Clone, Copy)]
#[repr(transparent)]
@@ -26,14 +30,6 @@
pub fn from_inner(_private: Private, inner: MapInner<'a>) -> Self {
Self { inner: Map::<'a, K, V>::from_inner(_private, inner) }
}
-
- pub fn len(&self) -> usize {
- self.inner.len()
- }
-
- pub fn is_empty(&self) -> bool {
- self.len() == 0
- }
}
impl<'a, K: ?Sized, V: ?Sized> MapMut<'a, K, V> {
@@ -44,14 +40,22 @@
macro_rules! impl_scalar_map_keys {
($(key_type $type:ty;)*) => {
- $(
- impl<'a, V: MapValueType> MapView<'a, $type, V> {
+ paste! { $(
+ impl<'a, V: [< MapWith $type:camel KeyOps >]> MapView<'a, $type, V> {
pub fn get(&self, key: $type) -> Option<V> {
self.inner.get(key)
}
+
+ pub fn len(&self) -> usize {
+ self.inner.size()
+ }
+
+ pub fn is_empty(&self) -> bool {
+ self.len() == 0
+ }
}
- impl<'a, V: MapValueType> MapMut<'a, $type, V> {
+ impl<'a, V: [< MapWith $type:camel KeyOps >]> MapMut<'a, $type, V> {
pub fn insert(&mut self, key: $type, value: V) -> bool {
self.inner.insert(key, value)
}
@@ -64,7 +68,7 @@
self.inner.clear()
}
}
- )*
+ )* }
};
}
diff --git a/rust/shared.rs b/rust/shared.rs
index 5aa054c..26f0629 100644
--- a/rust/shared.rs
+++ b/rust/shared.rs
@@ -17,7 +17,6 @@
/// These are the items protobuf users can access directly.
#[doc(hidden)]
pub mod __public {
- #[cfg(upb_kernel)]
pub use crate::map::{MapMut, MapView};
pub use crate::optional::{AbsentField, FieldEntry, Optional, PresentField};
pub use crate::primitive::{PrimitiveMut, SingularPrimitiveMut};
@@ -46,7 +45,6 @@
pub mod __runtime;
mod macros;
-#[cfg(upb_kernel)]
mod map;
mod optional;
mod primitive;
diff --git a/rust/test/BUILD b/rust/test/BUILD
index 291e0df..03b16fb 100644
--- a/rust/test/BUILD
+++ b/rust/test/BUILD
@@ -313,6 +313,21 @@
deps = [":nested_proto"],
)
+cc_proto_library(
+ name = "map_unittest_cc_proto",
+ testonly = True,
+ deps = ["//src/google/protobuf:map_unittest_proto"],
+)
+
+rust_cc_proto_library(
+ name = "map_unittest_cc_rust_proto",
+ testonly = True,
+ visibility = [
+ "//rust/test/shared:__subpackages__",
+ ],
+ deps = [":map_unittest_cc_proto"],
+)
+
rust_upb_proto_library(
name = "map_unittest_upb_rust_proto",
testonly = True,
diff --git a/rust/test/shared/BUILD b/rust/test/shared/BUILD
index 06f5633..04988de 100644
--- a/rust/test/shared/BUILD
+++ b/rust/test/shared/BUILD
@@ -281,11 +281,31 @@
)
rust_test(
+ name = "accessors_map_cpp_test",
+ srcs = ["accessors_map_test.rs"],
+ proc_macro_deps = [
+ "@crate_index//:paste",
+ ],
+ tags = [
+ # TODO: Enable testing on arm once we support sanitizers for Rust on Arm.
+ "not_build:arm",
+ ],
+ deps = [
+ "@crate_index//:googletest",
+ "//rust/test:map_unittest_cc_rust_proto",
+ ],
+)
+
+rust_test(
name = "accessors_map_upb_test",
srcs = ["accessors_map_test.rs"],
proc_macro_deps = [
"@crate_index//:paste",
],
+ tags = [
+ # TODO: Enable testing on arm once we support sanitizers for Rust on Arm.
+ "not_build:arm",
+ ],
deps = [
"@crate_index//:googletest",
"//rust/test:map_unittest_upb_rust_proto",
diff --git a/rust/upb.rs b/rust/upb.rs
index 136f4ed..8f713da 100644
--- a/rust/upb.rs
+++ b/rust/upb.rs
@@ -8,6 +8,7 @@
//! UPB FFI wrapper code for use by Rust Protobuf.
use crate::__internal::{Private, PtrAndLen, RawArena, RawMap, RawMessage, RawRepeatedField};
+use paste::paste;
use std::alloc;
use std::alloc::Layout;
use std::cell::UnsafeCell;
@@ -376,7 +377,7 @@
}
macro_rules! impl_repeated_primitives {
- ($(($rs_type:ty, $union_field:ident, $upb_tag:expr)),*) => {
+ ($(($rs_type:ty, $ufield:ident, $upb_tag:expr)),*) => {
$(
impl<'msg> RepeatedField<'msg, $rs_type> {
#[allow(dead_code)]
@@ -392,7 +393,7 @@
pub fn push(&mut self, val: $rs_type) {
unsafe { upb_Array_Append(
self.inner.raw,
- upb_MessageValue { $union_field: val },
+ upb_MessageValue { $ufield: val },
self.inner.arena.raw(),
) }
}
@@ -400,7 +401,7 @@
if i >= self.len() {
None
} else {
- unsafe { Some(upb_Array_Get(self.inner.raw, i).$union_field) }
+ unsafe { Some(upb_Array_Get(self.inner.raw, i).$ufield) }
}
}
pub fn set(&self, i: usize, val: $rs_type) {
@@ -410,7 +411,7 @@
unsafe { upb_Array_Set(
self.inner.raw,
i,
- upb_MessageValue { $union_field: val },
+ upb_MessageValue { $ufield: val },
) }
}
pub fn copy_from(&mut self, src: &RepeatedField<'_, $rs_type>) {
@@ -511,137 +512,132 @@
}
impl<'msg, K: ?Sized, V: ?Sized> Map<'msg, K, V> {
- pub fn len(&self) -> usize {
- unsafe { upb_Map_Size(self.inner.raw) }
- }
-
- pub fn is_empty(&self) -> bool {
- self.len() == 0
- }
-
pub fn from_inner(_private: Private, inner: MapInner<'msg>) -> Self {
Map { inner, _phantom_key: PhantomData, _phantom_value: PhantomData }
}
-
- pub fn clear(&mut self) {
- unsafe { upb_Map_Clear(self.inner.raw) }
- }
}
-/// # Safety
-/// Implementers of this trait must ensure that `pack_message_value` returns
-/// a `upb_MessageValue` with the active variant indicated by `Self`.
-pub unsafe trait MapType {
- /// # Safety
- /// The active variant of `outer` must be the `type PrimitiveValue`
- unsafe fn unpack_message_value(_private: Private, outer: upb_MessageValue) -> Self;
-
- fn pack_message_value(_private: Private, inner: Self) -> upb_MessageValue;
-
- fn upb_ctype(_private: Private) -> UpbCType;
-
- fn zero_value(_private: Private) -> Self;
-}
-
-/// Types implementing this trait can be used as map keys.
-pub trait MapKeyType: MapType {}
-
-/// Types implementing this trait can be used as map values.
-pub trait MapValueType: MapType {}
-
-macro_rules! impl_scalar_map_value_types {
- ($($type:ty, $union_field:ident, $upb_tag:expr, $zero_val:literal;)*) => {
- $(
- unsafe impl MapType for $type {
- unsafe fn unpack_message_value(_private: Private, outer: upb_MessageValue) -> Self {
- unsafe { outer.$union_field }
+macro_rules! impl_scalar_map_for_key_type {
+ ($key_t:ty, $key_ufield:ident, $key_upb_tag:expr, $trait:ident for $($t:ty, $ufield:ident, $upb_tag:expr, $zero_val:literal;)*) => {
+ paste! { $(
+ impl $trait for $t {
+ fn new_map(a: RawArena) -> RawMap {
+ unsafe { upb_Map_New(a, $key_upb_tag, $upb_tag) }
}
- fn pack_message_value(_private: Private, inner: Self) -> upb_MessageValue {
- upb_MessageValue { $union_field: inner }
+ fn clear(m: RawMap) {
+ unsafe { upb_Map_Clear(m) }
}
- fn upb_ctype(_private: Private) -> UpbCType {
- $upb_tag
+ fn size(m: RawMap) -> usize {
+ unsafe { upb_Map_Size(m) }
}
- fn zero_value(_private: Private) -> Self {
- $zero_val
+ fn insert(m: RawMap, a: RawArena, key: $key_t, value: $t) -> bool {
+ unsafe {
+ upb_Map_Set(
+ m,
+ upb_MessageValue { $key_ufield: key },
+ upb_MessageValue { $ufield: value},
+ a
+ )
+ }
+ }
+
+ fn get(m: RawMap, key: $key_t) -> Option<$t> {
+ let mut val = upb_MessageValue { $ufield: $zero_val };
+ let found = unsafe {
+ upb_Map_Get(m, upb_MessageValue { $key_ufield: key }, &mut val)
+ };
+ if !found {
+ return None;
+ }
+ Some(unsafe { val.$ufield })
+ }
+
+ fn remove(m: RawMap, key: $key_t) -> Option<$t> {
+ let mut val = upb_MessageValue { $ufield: $zero_val };
+ let removed = unsafe {
+ upb_Map_Delete(m, upb_MessageValue { $key_ufield: key }, &mut val)
+ };
+ if !removed {
+ return None;
+ }
+ Some(unsafe { val.$ufield })
}
}
-
- impl MapValueType for $type {}
- )*
- };
+ )* }
+ }
}
-impl_scalar_map_value_types!(
- f32, float_val, UpbCType::Float, 0f32;
- f64, double_val, UpbCType::Double, 0f64;
- i32, int32_val, UpbCType::Int32, 0i32;
- u32, uint32_val, UpbCType::UInt32, 0u32;
- i64, int64_val, UpbCType::Int64, 0i64;
- u64, uint64_val, UpbCType::UInt64, 0u64;
- bool, bool_val, UpbCType::Bool, false;
+macro_rules! impl_scalar_map_for_key_types {
+ ($($t:ty, $ufield:ident, $upb_tag:expr;)*) => {
+ paste! { $(
+ pub trait [< MapWith $t:camel KeyOps >] {
+ fn new_map(a: RawArena) -> RawMap;
+ fn clear(m: RawMap);
+ fn size(m: RawMap) -> usize;
+ fn insert(m: RawMap, a: RawArena, key: $t, value: Self) -> bool;
+ fn get(m: RawMap, key: $t) -> Option<Self>
+ where
+ Self: Sized;
+ fn remove(m: RawMap, key: $t) -> Option<Self>
+ where
+ Self: Sized;
+ }
+
+ impl_scalar_map_for_key_type!($t, $ufield, $upb_tag, [< MapWith $t:camel KeyOps >] for
+ f32, float_val, UpbCType::Float, 0f32;
+ f64, double_val, UpbCType::Double, 0f64;
+ i32, int32_val, UpbCType::Int32, 0i32;
+ u32, uint32_val, UpbCType::UInt32, 0u32;
+ i64, int64_val, UpbCType::Int64, 0i64;
+ u64, uint64_val, UpbCType::UInt64, 0u64;
+ bool, bool_val, UpbCType::Bool, false;
+ );
+
+ impl<'msg, V: [< MapWith $t:camel KeyOps >]> Map<'msg, $t, V> {
+ pub fn new(arena: &'msg mut Arena) -> Self {
+ let inner = MapInner { raw: V::new_map(arena.raw()), arena };
+ Map {
+ inner,
+ _phantom_key: PhantomData,
+ _phantom_value: PhantomData
+ }
+ }
+
+ pub fn size(&self) -> usize {
+ V::size(self.inner.raw)
+ }
+
+ pub fn clear(&mut self) {
+ V::clear(self.inner.raw)
+ }
+
+ pub fn get(&self, key: $t) -> Option<V> {
+ V::get(self.inner.raw, key)
+ }
+
+ pub fn remove(&mut self, key: $t) -> Option<V> {
+ V::remove(self.inner.raw, key)
+ }
+
+ pub fn insert(&mut self, key: $t, value: V) -> bool {
+ V::insert(self.inner.raw, self.inner.arena.raw(), key, value)
+ }
+ }
+ )* }
+ }
+}
+
+impl_scalar_map_for_key_types!(
+ i32, int32_val, UpbCType::Int32;
+ u32, uint32_val, UpbCType::UInt32;
+ i64, int64_val, UpbCType::Int64;
+ u64, uint64_val, UpbCType::UInt64;
+ bool, bool_val, UpbCType::Bool;
);
-macro_rules! impl_scalar_map_key_types {
- ($($type:ty;)*) => {
- $(
- impl MapKeyType for $type {}
- )*
- };
-}
-
-impl_scalar_map_key_types!(
- i32; u32; i64; u64; bool;
-);
-
-impl<'msg, K: MapKeyType, V: MapValueType> Map<'msg, K, V> {
- pub fn new(arena: &'msg Arena) -> Self {
- unsafe {
- let raw_map = upb_Map_New(arena.raw(), K::upb_ctype(Private), V::upb_ctype(Private));
- Map {
- inner: MapInner { raw: raw_map, arena },
- _phantom_key: PhantomData,
- _phantom_value: PhantomData,
- }
- }
- }
-
- pub fn get(&self, key: K) -> Option<V> {
- let mut val = V::pack_message_value(Private, V::zero_value(Private));
- let found =
- unsafe { upb_Map_Get(self.inner.raw, K::pack_message_value(Private, key), &mut val) };
- if !found {
- return None;
- }
- Some(unsafe { V::unpack_message_value(Private, val) })
- }
-
- pub fn insert(&mut self, key: K, value: V) -> bool {
- unsafe {
- upb_Map_Set(
- self.inner.raw,
- K::pack_message_value(Private, key),
- V::pack_message_value(Private, value),
- self.inner.arena.raw(),
- )
- }
- }
-
- pub fn remove(&mut self, key: K) -> Option<V> {
- let mut val = V::pack_message_value(Private, V::zero_value(Private));
- let removed = unsafe {
- upb_Map_Delete(self.inner.raw, K::pack_message_value(Private, key), &mut val)
- };
- if !removed {
- return None;
- }
- Some(unsafe { V::unpack_message_value(Private, val) })
- }
-}
-
extern "C" {
fn upb_Map_New(arena: RawArena, key_type: UpbCType, value_type: UpbCType) -> RawMap;
fn upb_Map_Size(map: RawMap) -> usize;
@@ -720,43 +716,43 @@
#[test]
fn i32_i32_map() {
- let arena = Arena::new();
- let mut map = Map::<'_, i32, i32>::new(&arena);
- assert_that!(map.len(), eq(0));
+ let mut arena = Arena::new();
+ let mut map = Map::<'_, i32, i32>::new(&mut arena);
+ assert_that!(map.size(), eq(0));
assert_that!(map.insert(1, 2), eq(true));
assert_that!(map.get(1), eq(Some(2)));
assert_that!(map.get(3), eq(None));
- assert_that!(map.len(), eq(1));
+ assert_that!(map.size(), eq(1));
assert_that!(map.remove(1), eq(Some(2)));
- assert_that!(map.len(), eq(0));
+ assert_that!(map.size(), eq(0));
assert_that!(map.remove(1), eq(None));
assert_that!(map.insert(4, 5), eq(true));
assert_that!(map.insert(6, 7), eq(true));
map.clear();
- assert_that!(map.len(), eq(0));
+ assert_that!(map.size(), eq(0));
}
#[test]
fn i64_f64_map() {
- let arena = Arena::new();
- let mut map = Map::<'_, i64, f64>::new(&arena);
- assert_that!(map.len(), eq(0));
+ let mut arena = Arena::new();
+ let mut map = Map::<'_, i64, f64>::new(&mut arena);
+ assert_that!(map.size(), eq(0));
assert_that!(map.insert(1, 2.5), eq(true));
assert_that!(map.get(1), eq(Some(2.5)));
assert_that!(map.get(3), eq(None));
- assert_that!(map.len(), eq(1));
+ assert_that!(map.size(), eq(1));
assert_that!(map.remove(1), eq(Some(2.5)));
- assert_that!(map.len(), eq(0));
+ assert_that!(map.size(), eq(0));
assert_that!(map.remove(1), eq(None));
assert_that!(map.insert(4, 5.1), eq(true));
assert_that!(map.insert(6, 7.2), eq(true));
map.clear();
- assert_that!(map.len(), eq(0));
+ assert_that!(map.size(), eq(0));
}
}
diff --git a/src/google/protobuf/compiler/rust/accessors/accessor_generator.h b/src/google/protobuf/compiler/rust/accessors/accessor_generator.h
index dcefbe8..442de3c 100644
--- a/src/google/protobuf/compiler/rust/accessors/accessor_generator.h
+++ b/src/google/protobuf/compiler/rust/accessors/accessor_generator.h
@@ -111,6 +111,7 @@
~Map() override = default;
void InMsgImpl(Context<FieldDescriptor> field) const override;
void InExternC(Context<FieldDescriptor> field) const override;
+ void InThunkCc(Context<FieldDescriptor> field) const override;
};
} // namespace rust
diff --git a/src/google/protobuf/compiler/rust/accessors/map.cc b/src/google/protobuf/compiler/rust/accessors/map.cc
index f7eda84..1df6b23 100644
--- a/src/google/protobuf/compiler/rust/accessors/map.cc
+++ b/src/google/protobuf/compiler/rust/accessors/map.cc
@@ -5,6 +5,7 @@
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd
+#include "google/protobuf/compiler/cpp/helpers.h"
#include "google/protobuf/compiler/rust/accessors/accessor_generator.h"
#include "google/protobuf/compiler/rust/context.h"
#include "google/protobuf/compiler/rust/naming.h"
@@ -24,33 +25,55 @@
{"Key", PrimitiveRsTypeName(key_type)},
{"Value", PrimitiveRsTypeName(value_type)},
{"getter_thunk", Thunk(field, "get")},
- {"getter_thunk_mut", Thunk(field, "get_mut")},
+ {"getter_mut_thunk", Thunk(field, "get_mut")},
{"getter",
[&] {
if (field.is_upb()) {
field.Emit({}, R"rs(
- pub fn r#$field$(&self) -> $pb$::MapView<'_, $Key$, $Value$> {
- let inner = unsafe {
- $getter_thunk$(self.inner.msg)
- }.map_or_else(|| unsafe {$pbr$::empty_map()}, |raw| {
- $pbr$::MapInner{ raw, arena: &self.inner.arena }
- });
- $pb$::MapView::from_inner($pbi$::Private, inner)
- }
-
- pub fn r#$field$_mut(&mut self)
- -> $pb$::MapMut<'_, $Key$, $Value$> {
- let raw = unsafe {
- $getter_thunk_mut$(self.inner.msg, self.inner.arena.raw())
- };
- let inner = $pbr$::MapInner{ raw, arena: &self.inner.arena };
- $pb$::MapMut::from_inner($pbi$::Private, inner)
- }
- )rs");
+ pub fn r#$field$(&self) -> $pb$::MapView<'_, $Key$, $Value$> {
+ let inner = unsafe {
+ $getter_thunk$(self.inner.msg)
+ }.map_or_else(|| unsafe {$pbr$::empty_map()}, |raw| {
+ $pbr$::MapInner{ raw, arena: &self.inner.arena }
+ });
+ $pb$::MapView::from_inner($pbi$::Private, inner)
+ })rs");
+ } else {
+ field.Emit({}, R"rs(
+ pub fn r#$field$(&self) -> $pb$::MapView<'_, $Key$, $Value$> {
+ let inner = $pbr$::MapInner {
+ raw: unsafe { $getter_thunk$(self.inner.msg) },
+ _phantom: std::marker::PhantomData
+ };
+ $pb$::MapView::from_inner($pbi$::Private, inner)
+ })rs");
+ }
+ }},
+ {"getter_mut",
+ [&] {
+ if (field.is_upb()) {
+ field.Emit({}, R"rs(
+ pub fn r#$field$_mut(&mut self) -> $pb$::MapMut<'_, $Key$, $Value$> {
+ let raw = unsafe {
+ $getter_mut_thunk$(self.inner.msg, self.inner.arena.raw())
+ };
+ let inner = $pbr$::MapInner{ raw, arena: &self.inner.arena };
+ $pb$::MapMut::from_inner($pbi$::Private, inner)
+ })rs");
+ } else {
+ field.Emit({}, R"rs(
+ pub fn r#$field$_mut(&mut self) -> $pb$::MapMut<'_, $Key$, $Value$> {
+ let inner = $pbr$::MapInner {
+ raw: unsafe { $getter_mut_thunk$(self.inner.msg) },
+ _phantom: std::marker::PhantomData
+ };
+ $pb$::MapMut::from_inner($pbi$::Private, inner)
+ })rs");
}
}}},
R"rs(
$getter$
+ $getter_mut$
)rs");
}
@@ -58,16 +81,21 @@
field.Emit(
{
{"getter_thunk", Thunk(field, "get")},
- {"getter_thunk_mut", Thunk(field, "get_mut")},
+ {"getter_mut_thunk", Thunk(field, "get_mut")},
{"getter",
[&] {
if (field.is_upb()) {
field.Emit({}, R"rs(
- fn $getter_thunk$(raw_msg: $pbi$::RawMessage)
- -> Option<$pbi$::RawMap>;
- fn $getter_thunk_mut$(raw_msg: $pbi$::RawMessage,
- arena: $pbi$::RawArena) -> $pbi$::RawMap;
- )rs");
+ fn $getter_thunk$(raw_msg: $pbi$::RawMessage)
+ -> Option<$pbi$::RawMap>;
+ fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage,
+ arena: $pbi$::RawArena) -> $pbi$::RawMap;
+ )rs");
+ } else {
+ field.Emit({}, R"rs(
+ fn $getter_thunk$(msg: $pbi$::RawMessage) -> $pbi$::RawMap;
+ fn $getter_mut_thunk$(msg: $pbi$::RawMessage,) -> $pbi$::RawMap;
+ )rs");
}
}},
},
@@ -76,6 +104,30 @@
)rs");
}
+void Map::InThunkCc(Context<FieldDescriptor> field) const {
+ field.Emit(
+ {{"field", cpp::FieldName(&field.desc())},
+ {"Key", cpp::PrimitiveTypeName(
+ field.desc().message_type()->map_key()->cpp_type())},
+ {"Value", cpp::PrimitiveTypeName(
+ field.desc().message_type()->map_value()->cpp_type())},
+ {"QualifiedMsg",
+ cpp::QualifiedClassName(field.desc().containing_type())},
+ {"getter_thunk", Thunk(field, "get")},
+ {"getter_mut_thunk", Thunk(field, "get_mut")},
+ {"impls",
+ [&] {
+ field.Emit(
+ R"cc(
+ const void* $getter_thunk$($QualifiedMsg$& msg) {
+ return &msg.$field$();
+ }
+ void* $getter_mut_thunk$($QualifiedMsg$* msg) { return msg->mutable_$field$(); }
+ )cc");
+ }}},
+ "$impls$");
+}
+
} // namespace rust
} // namespace compiler
} // namespace protobuf