Nextgen Proto Pythonic API: Add 'in' operator
(Second attempt. The first attempt missed ListValue.)
The “in” operator is consistent with HasField but behaves slightly differently from Proto Plus.
Detailed behavior of the “in” operator in Nextgen:
* For WKT Struct (consistent with the old Struct behavior):
- Raise TypeError if the key is not a string
- Check if the key is in struct.fields
* For WKT ListValue (consistent with the old behavior):
- Check if the value is in list_value.values
* For other messages:
- Raise ValueError if the argument is not a string
- Raise ValueError if the string is not a field name
- For a oneof: check whether any field in the oneof is set
- For a field with presence: check whether it is set
- For a field without presence (including repeated fields): raise ValueError
PiperOrigin-RevId: 631143378
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 7b4478e..e42538e 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -1336,6 +1336,27 @@
union.DESCRIPTOR, message_module.TestAllTypes.NestedEnum.DESCRIPTOR
)
+  def testIn(self, message_module):
+    """`in` reports field presence and raises ValueError otherwise."""
+    m = message_module.TestAllTypes()
+    self.assertNotIn('optional_nested_message', m)
+    self.assertNotIn('oneof_bytes', m)
+    self.assertNotIn('oneof_string', m)
+    # Repeated fields (no presence), non-str arguments and unknown field
+    # names all raise ValueError.
+    with self.assertRaises(ValueError):
+      'repeated_int32' in m
+    with self.assertRaises(ValueError):
+      'repeated_nested_message' in m
+    with self.assertRaises(ValueError):
+      1 in m
+    with self.assertRaises(ValueError):
+      'not_a_field' in m
+    test_util.SetAllFields(m)
+    self.assertIn('optional_nested_message', m)
+    self.assertIn('oneof_bytes', m)
+    self.assertNotIn('oneof_string', m)
+
# Class to test proto2-only features (required, extensions, etc.)
@testing_refleaks.TestCase
@@ -1367,6 +1385,9 @@
self.assertTrue(message.HasField('optional_int32'))
self.assertTrue(message.HasField('optional_bool'))
self.assertTrue(message.HasField('optional_nested_message'))
+ self.assertIn('optional_int32', message)
+ self.assertIn('optional_bool', message)
+ self.assertIn('optional_nested_message', message)
# Set the fields to non-default values.
message.optional_int32 = 5
@@ -1385,6 +1406,9 @@
self.assertFalse(message.HasField('optional_int32'))
self.assertFalse(message.HasField('optional_bool'))
self.assertFalse(message.HasField('optional_nested_message'))
+ self.assertNotIn('optional_int32', message)
+ self.assertNotIn('optional_bool', message)
+ self.assertNotIn('optional_nested_message', message)
self.assertEqual(0, message.optional_int32)
self.assertEqual(False, message.optional_bool)
self.assertEqual(0, message.optional_nested_message.bb)
@@ -1711,6 +1735,12 @@
with self.assertRaises(ValueError):
message.HasField('repeated_nested_message')
+ # Like HasField, the "in" operator raises ValueError for repeated fields.
+ with self.assertRaises(ValueError):
+ 'repeated_int32' in message
+ with self.assertRaises(ValueError):
+ 'repeated_nested_message' in message
+
# Fields should default to their type-specific default.
self.assertEqual(0, message.optional_int32)
self.assertEqual(0, message.optional_float)
@@ -1721,6 +1751,7 @@
# Setting a submessage should still return proper presence information.
message.optional_nested_message.bb = 0
self.assertTrue(message.HasField('optional_nested_message'))
+ self.assertIn('optional_nested_message', message)
# Set the fields to non-default values.
message.optional_int32 = 5
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 34efbfe..869f2aa 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -1000,6 +1000,33 @@
cls.__unicode__ = __unicode__
+def _AddContainsMethod(message_descriptor, cls):
+  """Installs `__contains__` (the `in` operator) on cls.
+
+  Struct checks the keys of its `fields` map, ListValue checks the values
+  returned by items(), and every other message delegates to HasField().
+
+  """
+  full_name = message_descriptor.full_name
+
+  if full_name == 'google.protobuf.Struct':
+    def __contains__(self, key):
+      return key in self.fields
+    cls.__contains__ = __contains__
+    return
+
+  if full_name == 'google.protobuf.ListValue':
+    def __contains__(self, value):
+      return value in self.items()
+    cls.__contains__ = __contains__
+    return
+
+  def __contains__(self, field):
+    return self.HasField(field)
+
+  cls.__contains__ = __contains__
+
+
def _BytesForNonRepeatedElement(value, field_number, field_type):
"""Returns the number of bytes needed to serialize a non-repeated element.
The returned byte count includes space for tag information and any
@@ -1394,6 +1409,7 @@
_AddStrMethod(message_descriptor, cls)
_AddReprMethod(message_descriptor, cls)
_AddUnicodeMethod(message_descriptor, cls)
+ _AddContainsMethod(message_descriptor, cls)
_AddByteSizeMethod(message_descriptor, cls)
_AddSerializeToStringMethod(message_descriptor, cls)
_AddSerializePartialToStringMethod(message_descriptor, cls)
diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py
index 8a0a3b1..ea90be3 100644
--- a/python/google/protobuf/internal/well_known_types.py
+++ b/python/google/protobuf/internal/well_known_types.py
@@ -497,9 +497,6 @@
def __getitem__(self, key):
return _GetStructValue(self.fields[key])
- def __contains__(self, item):
- return item in self.fields
-
def __setitem__(self, key, value):
_SetStructValue(self.fields[key], value)
diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py
index 498d3f2..5927766 100644
--- a/python/google/protobuf/internal/well_known_types_test.py
+++ b/python/google/protobuf/internal/well_known_types_test.py
@@ -515,7 +515,6 @@
self.assertEqual(False, struct_list[3])
self.assertEqual(None, struct_list[4])
self.assertEqual(inner_struct, struct_list[5])
- self.assertIn(6, struct_list)
struct_list[1] = 7
self.assertEqual(7, struct_list[1])
@@ -570,6 +569,38 @@
self.assertEqual([6, True, False, None, inner_struct],
list(struct['key5'].items()))
+  def testInOperator(self):
+    """`in` checks Struct keys and ListValue items."""
+    # in operator for Struct
+    struct = struct_pb2.Struct()
+    struct['key'] = 5
+
+    self.assertIn('key', struct)
+    self.assertNotIn('fields', struct)
+    # Struct keys must be str; any other type raises TypeError.
+    with self.assertRaises(TypeError):
+      1 in struct
+
+    # in operator for ListValue
+    struct_list = struct.get_or_create_list('key2')
+    self.assertIsInstance(struct_list, collections_abc.Sequence)
+    struct_list.extend([6, 'seven', True, False, None])
+    struct_list.add_struct()['subkey'] = 9
+    inner_struct = struct.__class__()
+    inner_struct['subkey'] = 9
+
+    self.assertIn(6, struct_list)
+    self.assertIn('seven', struct_list)
+    self.assertIn(True, struct_list)
+    self.assertIn(False, struct_list)
+    self.assertIn(None, struct_list)
+    self.assertIn(inner_struct, struct_list)
+    self.assertNotIn('values', struct_list)
+    self.assertNotIn(10, struct_list)
+
+    for item in struct_list:
+      self.assertIn(item, struct_list)
+
def testStructAssignment(self):
# Tests struct assignment from another struct
s1 = struct_pb2.Struct()
diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py
index 3226b6e..ae9cb14 100755
--- a/python/google/protobuf/message.py
+++ b/python/google/protobuf/message.py
@@ -75,6 +75,37 @@
"""Outputs a human-readable representation of the message."""
raise NotImplementedError
+  def __contains__(self, field_name_or_key):
+    """Checks if a certain field is set for the message.
+
+    Fields with presence return True if the field is set and False otherwise.
+    For a oneof name, returns True if any field in the oneof is set. Fields
+    without presence raise `ValueError` (this includes repeated fields, map
+    fields, and implicit presence fields).
+
+    For other messages, `ValueError` is also raised if `field_name_or_key` is
+    not a field or oneof defined in the message descriptor.
+
+    Note: WKT Struct checks if the key is contained in fields. ListValue checks
+    if the item is contained in the list.
+
+    Args:
+      field_name_or_key: For Struct, the key (str) of the fields map. For
+        ListValue, any type that may be contained in the list. For other
+        messages, name of the field (str) to check for presence.
+
+    Returns:
+      bool: For Struct, whether the key is contained in fields. For ListValue,
+        whether the item is contained in the list. For other messages,
+        whether a value has been set for the named field.
+
+    Raises:
+      TypeError: For Struct, if `field_name_or_key` is not a string.
+      ValueError: For other messages, if `field_name_or_key` is not a string,
+        is not a member of this message, or names a field without presence.
+    """
+    raise NotImplementedError
+
def MergeFrom(self, other_msg):
"""Merges the contents of the specified message into current message.
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index 06f9515..39fe35a 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -10,6 +10,7 @@
#include "google/protobuf/pyext/message.h"
+#include <Python.h>
#include <structmember.h> // A Python header file.
#include <cstdint>
@@ -36,6 +37,7 @@
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/strtod.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
+#include "google/protobuf/map_field.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/unknown_field_set.h"
@@ -85,6 +87,14 @@
return reflection->IsLazyField(field) ||
reflection->IsLazyExtension(message, field);
}
+  // Reports whether `map_key` is present in the map field `field` of
+  // `message`, forwarding to Reflection::ContainsMapKey().
+  static bool ContainsMapKey(const Reflection* reflection,
+                             const Message& message,
+                             const FieldDescriptor* field,
+                             const MapKey& map_key) {
+    return reflection->ContainsMapKey(message, field, map_key);
+  }
};
static PyObject* kDESCRIPTOR;
@@ -1293,11 +1301,16 @@
char* field_name;
Py_ssize_t size;
field_name = const_cast<char*>(PyUnicode_AsUTF8AndSize(arg, &size));
+ Message* message = self->message;
+
if (!field_name) {
+ PyErr_Format(PyExc_ValueError,
+ "The field name passed to message %s"
+ " is not a str.",
+ message->GetDescriptor()->name().c_str());
return nullptr;
}
- Message* message = self->message;
bool is_in_oneof;
const FieldDescriptor* field_descriptor = FindFieldWithOneofs(
message, absl::string_view(field_name, size), &is_in_oneof);
@@ -2290,6 +2303,62 @@
return decoded;
}
+// Implements `__contains__` (the Python `in` operator) for messages:
+//   * Struct: whether the str key is in the `fields` map; a key that is not
+//     a str raises TypeError.
+//   * ListValue: whether the value is one of the list's items().
+//   * Other messages: delegates to HasField().
+// Returns a new bool reference, or nullptr with a Python exception set.
+PyObject* Contains(CMessage* self, PyObject* arg) {
+  Message* message = self->message;
+  const Descriptor* descriptor = message->GetDescriptor();
+  switch (descriptor->well_known_type()) {
+    case Descriptor::WELLKNOWNTYPE_STRUCT: {
+      // For WKT Struct, check if the key is in the fields.
+      const Reflection* reflection = message->GetReflection();
+      const FieldDescriptor* map_field = descriptor->FindFieldByName("fields");
+      const FieldDescriptor* key_field = map_field->message_type()->map_key();
+      PyObject* py_string = CheckString(arg, key_field);
+      if (!py_string) {
+        // Surface a non-str key as TypeError, overriding any error set by
+        // CheckString.
+        PyErr_SetString(PyExc_TypeError,
+                        "The key passed to Struct message must be a str.");
+        return nullptr;
+      }
+      char* value;
+      Py_ssize_t value_len;
+      if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
+        // An exception is now pending and must be propagated: returning a
+        // value while an exception is set is an error in the C API.
+        Py_DECREF(py_string);
+        return nullptr;
+      }
+      // Copy the key out before releasing the bytes object that owns `value`.
+      std::string key_str(value, value_len);
+      Py_DECREF(py_string);
+
+      MapKey map_key;
+      map_key.SetStringValue(key_str);
+      return PyBool_FromLong(MessageReflectionFriend::ContainsMapKey(
+          reflection, *message, map_field, map_key));
+    }
+    case Descriptor::WELLKNOWNTYPE_LISTVALUE: {
+      // For WKT ListValue, check if the value is one of the items.
+      PyObject* items = PyObject_CallMethod(reinterpret_cast<PyObject*>(self),
+                                            "items", nullptr);
+      if (items == nullptr) return nullptr;
+      const int contains = PySequence_Contains(items, arg);
+      Py_DECREF(items);  // `items` is a new reference; don't leak it.
+      if (contains < 0) return nullptr;  // Propagate the pending exception.
+      return PyBool_FromLong(contains);
+    }
+    default:
+      // For other messages, check with HasField.
+      return HasField(self, arg);
+  }
+}
+
// CMessage static methods:
PyObject* _CheckCalledFromGeneratedFile(PyObject* unused,
PyObject* unused_arg) {
@@ -2338,6 +2393,8 @@
"Makes a deep copy of the class."},
{"__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
"Outputs a unicode representation of the message."},
+ {"__contains__", (PyCFunction)Contains, METH_O,
+ "Checks if a message field is set."},
{"ByteSize", (PyCFunction)ByteSize, METH_NOARGS,
"Returns the size of the message in bytes."},
{"Clear", (PyCFunction)Clear, METH_NOARGS, "Clears the message."},
diff --git a/python/message.c b/python/message.c
index c0c0882..d5213ab 100644
--- a/python/message.c
+++ b/python/message.c
@@ -1044,6 +1044,50 @@
NULL);
}
+// Implements `__contains__` (the Python `in` operator) for messages:
+//   * Struct: whether the key is in the `fields` map. The key is converted
+//     first, so an invalid key raises for stub messages too.
+//   * ListValue: whether the value is one of the list's items().
+//   * Other messages: delegates to PyUpb_Message_HasField().
+// Returns a new bool reference, or NULL with a Python exception set.
+static PyObject* PyUpb_Message_Contains(PyObject* _self, PyObject* arg) {
+  const upb_MessageDef* msgdef = PyUpb_Message_GetMsgdef(_self);
+  switch (upb_MessageDef_WellKnownType(msgdef)) {
+    case kUpb_WellKnown_Struct: {
+      // For WKT Struct, check if the key is in the fields.
+      const upb_FieldDef* f = upb_MessageDef_FindFieldByName(msgdef, "fields");
+      const upb_MessageDef* entry_m = upb_FieldDef_MessageSubDef(f);
+      const upb_FieldDef* key_f = upb_MessageDef_Field(entry_m, 0);
+      // Convert (and thereby type-check) the key before the stub shortcut so
+      // an invalid key raises even on a not-yet-materialized message.
+      upb_MessageValue u_key;
+      if (!PyUpb_PyToUpb(arg, key_f, &u_key, NULL)) return NULL;
+      PyUpb_Message* self = (void*)_self;
+      if (PyUpb_Message_IsStub(self)) Py_RETURN_FALSE;
+      upb_Message* msg = PyUpb_Message_GetMsg(self);
+      const upb_Map* map = upb_Message_GetFieldByDef(msg, f).map_val;
+      // NOTE(review): assumes an unpopulated map field may yield NULL here;
+      // an absent map contains no keys.
+      if (!map) Py_RETURN_FALSE;
+      return PyBool_FromLong(upb_Map_Get(map, u_key, NULL));
+    }
+    case kUpb_WellKnown_ListValue: {
+      // For WKT ListValue, check if the value is one of the items.
+      PyUpb_Message* self = (void*)_self;
+      if (PyUpb_Message_IsStub(self)) Py_RETURN_FALSE;
+      PyObject* items = PyObject_CallMethod(_self, "items", NULL);
+      if (!items) return NULL;
+      int contains = PySequence_Contains(items, arg);
+      Py_DECREF(items);  // `items` is a new reference; don't leak it.
+      if (contains < 0) return NULL;  // Propagate the pending exception.
+      return PyBool_FromLong(contains);
+    }
+    default:
+      // For other messages, check with HasField.
+      return PyUpb_Message_HasField(_self, arg);
+  }
+}
+
static PyObject* PyUpb_Message_FindInitializationErrors(PyObject* _self,
PyObject* arg);
@@ -1642,6 +1671,8 @@
// TODO
//{ "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
// "Outputs a unicode representation of the message." },
+ {"__contains__", PyUpb_Message_Contains, METH_O,
+ "Checks if a message field is set."},
{"ByteSize", (PyCFunction)PyUpb_Message_ByteSize, METH_NOARGS,
"Returns the size of the message in bytes."},
{"Clear", (PyCFunction)PyUpb_Message_Clear, METH_NOARGS,