blob: 34e1f9bd291a354577b29c3a0d2bfbfb09cfb002 [file] [log] [blame]
Adam Cozzette501ecec2023-09-26 14:36:20 -07001// Protocol Buffers - Google's data interchange format
2// Copyright 2023 Google LLC. All rights reserved.
Adam Cozzette501ecec2023-09-26 14:36:20 -07003//
Protobuf Team Bot0fab7732023-11-20 13:38:15 -08004// Use of this source code is governed by a BSD-style
5// license that can be found in the LICENSE file or at
6// https://developers.google.com/open-source/licenses/bsd
Adam Cozzette501ecec2023-09-26 14:36:20 -07007
8#include "python/extension_dict.h"
9
10#include "python/message.h"
11#include "python/protobuf.h"
12#include "upb/reflection/def.h"
13
14// -----------------------------------------------------------------------------
15// ExtensionDict
16// -----------------------------------------------------------------------------
17
18typedef struct {
19 PyObject_HEAD;
20 PyObject* msg; // Owning ref to our parent pessage.
21} PyUpb_ExtensionDict;
22
23PyObject* PyUpb_ExtensionDict_New(PyObject* msg) {
24 PyUpb_ModuleState* state = PyUpb_ModuleState_Get();
25 PyUpb_ExtensionDict* ext_dict =
26 (void*)PyType_GenericAlloc(state->extension_dict_type, 0);
27 ext_dict->msg = msg;
28 Py_INCREF(ext_dict->msg);
29 return &ext_dict->ob_base;
30}
31
32static PyObject* PyUpb_ExtensionDict_FindExtensionByName(PyObject* _self,
33 PyObject* key) {
34 PyUpb_ExtensionDict* self = (PyUpb_ExtensionDict*)_self;
35 const char* name = PyUpb_GetStrData(key);
36 if (!name) {
37 PyErr_Format(PyExc_TypeError, "_FindExtensionByName expect a str");
38 return NULL;
39 }
40 const upb_MessageDef* m = PyUpb_Message_GetMsgdef(self->msg);
41 const upb_FileDef* file = upb_MessageDef_File(m);
42 const upb_DefPool* symtab = upb_FileDef_Pool(file);
43 const upb_FieldDef* ext = upb_DefPool_FindExtensionByName(symtab, name);
44 if (ext) {
45 return PyUpb_FieldDescriptor_Get(ext);
46 } else {
47 Py_RETURN_NONE;
48 }
49}
50
51static PyObject* PyUpb_ExtensionDict_FindExtensionByNumber(PyObject* _self,
52 PyObject* arg) {
53 PyUpb_ExtensionDict* self = (PyUpb_ExtensionDict*)_self;
54 const upb_MessageDef* m = PyUpb_Message_GetMsgdef(self->msg);
55 const upb_MiniTable* l = upb_MessageDef_MiniTable(m);
56 const upb_FileDef* file = upb_MessageDef_File(m);
57 const upb_DefPool* symtab = upb_FileDef_Pool(file);
58 const upb_ExtensionRegistry* reg = upb_DefPool_ExtensionRegistry(symtab);
59 int64_t number = PyLong_AsLong(arg);
60 if (number == -1 && PyErr_Occurred()) return NULL;
61 const upb_MiniTableExtension* ext =
62 (upb_MiniTableExtension*)upb_ExtensionRegistry_Lookup(reg, l, number);
63 if (ext) {
64 const upb_FieldDef* f = upb_DefPool_FindExtensionByMiniTable(symtab, ext);
65 return PyUpb_FieldDescriptor_Get(f);
66 } else {
67 Py_RETURN_NONE;
68 }
69}
70
71static void PyUpb_ExtensionDict_Dealloc(PyUpb_ExtensionDict* self) {
72 PyUpb_Message_ClearExtensionDict(self->msg);
73 Py_DECREF(self->msg);
74 PyUpb_Dealloc(self);
75}
76
77static PyObject* PyUpb_ExtensionDict_RichCompare(PyObject* _self,
78 PyObject* _other, int opid) {
79 // Only equality comparisons are implemented.
80 if (opid != Py_EQ && opid != Py_NE) {
81 Py_INCREF(Py_NotImplemented);
82 return Py_NotImplemented;
83 }
84 PyUpb_ExtensionDict* self = (PyUpb_ExtensionDict*)_self;
85 bool equals = false;
86 if (PyObject_TypeCheck(_other, Py_TYPE(_self))) {
87 PyUpb_ExtensionDict* other = (PyUpb_ExtensionDict*)_other;
88 equals = self->msg == other->msg;
89 }
90 bool ret = opid == Py_EQ ? equals : !equals;
91 return PyBool_FromLong(ret);
92}
93
94static int PyUpb_ExtensionDict_Contains(PyObject* _self, PyObject* key) {
95 PyUpb_ExtensionDict* self = (PyUpb_ExtensionDict*)_self;
96 const upb_FieldDef* f = PyUpb_Message_GetExtensionDef(self->msg, key);
97 if (!f) return -1;
98 upb_Message* msg = PyUpb_Message_GetIfReified(self->msg);
99 if (!msg) return 0;
100 if (upb_FieldDef_IsRepeated(f)) {
101 upb_MessageValue val = upb_Message_GetFieldByDef(msg, f);
102 return upb_Array_Size(val.array_val) > 0;
103 } else {
104 return upb_Message_HasFieldByDef(msg, f);
105 }
106}
107
108static Py_ssize_t PyUpb_ExtensionDict_Length(PyObject* _self) {
109 PyUpb_ExtensionDict* self = (PyUpb_ExtensionDict*)_self;
110 upb_Message* msg = PyUpb_Message_GetIfReified(self->msg);
111 return msg ? upb_Message_ExtensionCount(msg) : 0;
112}
113
114static PyObject* PyUpb_ExtensionDict_Subscript(PyObject* _self, PyObject* key) {
115 PyUpb_ExtensionDict* self = (PyUpb_ExtensionDict*)_self;
116 const upb_FieldDef* f = PyUpb_Message_GetExtensionDef(self->msg, key);
117 if (!f) return NULL;
118 return PyUpb_Message_GetFieldValue(self->msg, f);
119}
120
121static int PyUpb_ExtensionDict_AssignSubscript(PyObject* _self, PyObject* key,
122 PyObject* val) {
123 PyUpb_ExtensionDict* self = (PyUpb_ExtensionDict*)_self;
124 const upb_FieldDef* f = PyUpb_Message_GetExtensionDef(self->msg, key);
125 if (!f) return -1;
126 if (val) {
127 return PyUpb_Message_SetFieldValue(self->msg, f, val, PyExc_TypeError);
128 } else {
129 PyUpb_Message_DoClearField(self->msg, f);
130 return 0;
131 }
132}
133
// Forward declaration: the iterator constructor is defined later in this file
// but is referenced by the ExtensionDict slot table below.
static PyObject* PyUpb_ExtensionIterator_New(PyObject* _ext_dict);
135
136static PyMethodDef PyUpb_ExtensionDict_Methods[] = {
137 {"_FindExtensionByName", PyUpb_ExtensionDict_FindExtensionByName, METH_O,
138 "Finds an extension by name."},
139 {"_FindExtensionByNumber", PyUpb_ExtensionDict_FindExtensionByNumber,
140 METH_O, "Finds an extension by number."},
141 {NULL, NULL},
142};
143
144static PyType_Slot PyUpb_ExtensionDict_Slots[] = {
145 {Py_tp_dealloc, PyUpb_ExtensionDict_Dealloc},
146 {Py_tp_methods, PyUpb_ExtensionDict_Methods},
147 //{Py_tp_getset, PyUpb_ExtensionDict_Getters},
148 //{Py_tp_hash, PyObject_HashNotImplemented},
149 {Py_tp_richcompare, PyUpb_ExtensionDict_RichCompare},
150 {Py_tp_iter, PyUpb_ExtensionIterator_New},
151 {Py_sq_contains, PyUpb_ExtensionDict_Contains},
152 {Py_sq_length, PyUpb_ExtensionDict_Length},
153 {Py_mp_length, PyUpb_ExtensionDict_Length},
154 {Py_mp_subscript, PyUpb_ExtensionDict_Subscript},
155 {Py_mp_ass_subscript, PyUpb_ExtensionDict_AssignSubscript},
156 {0, NULL}};
157
158static PyType_Spec PyUpb_ExtensionDict_Spec = {
159 PYUPB_MODULE_NAME ".ExtensionDict", // tp_name
160 sizeof(PyUpb_ExtensionDict), // tp_basicsize
161 0, // tp_itemsize
162 Py_TPFLAGS_DEFAULT, // tp_flags
163 PyUpb_ExtensionDict_Slots,
164};
165
166// -----------------------------------------------------------------------------
167// ExtensionIterator
168// -----------------------------------------------------------------------------
169
170typedef struct {
171 PyObject_HEAD;
172 PyObject* msg;
173 size_t iter;
174} PyUpb_ExtensionIterator;
175
176static PyObject* PyUpb_ExtensionIterator_New(PyObject* _ext_dict) {
177 PyUpb_ExtensionDict* ext_dict = (PyUpb_ExtensionDict*)_ext_dict;
178 PyUpb_ModuleState* state = PyUpb_ModuleState_Get();
179 PyUpb_ExtensionIterator* iter =
180 (void*)PyType_GenericAlloc(state->extension_iterator_type, 0);
181 if (!iter) return NULL;
182 iter->msg = ext_dict->msg;
183 iter->iter = kUpb_Message_Begin;
184 Py_INCREF(iter->msg);
185 return &iter->ob_base;
186}
187
188static void PyUpb_ExtensionIterator_Dealloc(void* _self) {
189 PyUpb_ExtensionIterator* self = (PyUpb_ExtensionIterator*)_self;
190 Py_DECREF(self->msg);
191 PyUpb_Dealloc(_self);
192}
193
194PyObject* PyUpb_ExtensionIterator_IterNext(PyObject* _self) {
195 PyUpb_ExtensionIterator* self = (PyUpb_ExtensionIterator*)_self;
196 upb_Message* msg = PyUpb_Message_GetIfReified(self->msg);
197 if (!msg) return NULL;
198 const upb_MessageDef* m = PyUpb_Message_GetMsgdef(self->msg);
199 const upb_DefPool* symtab = upb_FileDef_Pool(upb_MessageDef_File(m));
200 while (true) {
201 const upb_FieldDef* f;
202 upb_MessageValue val;
203 if (!upb_Message_Next(msg, m, symtab, &f, &val, &self->iter)) return NULL;
204 if (upb_FieldDef_IsExtension(f)) return PyUpb_FieldDescriptor_Get(f);
205 }
206}
207
208static PyType_Slot PyUpb_ExtensionIterator_Slots[] = {
209 {Py_tp_dealloc, PyUpb_ExtensionIterator_Dealloc},
210 {Py_tp_iter, PyObject_SelfIter},
211 {Py_tp_iternext, PyUpb_ExtensionIterator_IterNext},
212 {0, NULL}};
213
214static PyType_Spec PyUpb_ExtensionIterator_Spec = {
215 PYUPB_MODULE_NAME ".ExtensionIterator", // tp_name
216 sizeof(PyUpb_ExtensionIterator), // tp_basicsize
217 0, // tp_itemsize
218 Py_TPFLAGS_DEFAULT, // tp_flags
219 PyUpb_ExtensionIterator_Slots,
220};
221
222// -----------------------------------------------------------------------------
223// Top Level
224// -----------------------------------------------------------------------------
225
226bool PyUpb_InitExtensionDict(PyObject* m) {
227 PyUpb_ModuleState* s = PyUpb_ModuleState_GetFromModule(m);
228
229 s->extension_dict_type = PyUpb_AddClass(m, &PyUpb_ExtensionDict_Spec);
230 s->extension_iterator_type = PyUpb_AddClass(m, &PyUpb_ExtensionIterator_Spec);
231
232 return s->extension_dict_type && s->extension_iterator_type;
233}