pb_decode: performance optimizations for extension field handling
diff --git a/pb_common.c b/pb_common.c
index ae3e9b6..c718eb1 100644
--- a/pb_common.c
+++ b/pb_common.c
@@ -239,6 +239,37 @@
}
}
+bool pb_field_iter_find_extension(pb_field_iter_t *iter)
+{
+ if (PB_LTYPE(iter->type) == PB_LTYPE_EXTENSION)
+ {
+ return true;
+ }
+ else
+ {
+ pb_size_t start = iter->index;
+ uint32_t fieldinfo;
+
+ do
+ {
+ /* Advance iterator but don't load values yet */
+ advance_iterator(iter);
+
+ /* Do fast check for field type */
+ fieldinfo = PB_PROGMEM_READU32(iter->descriptor->field_info[iter->field_info_index]);
+
+ if (PB_LTYPE((fieldinfo >> 8) & 0xFF) == PB_LTYPE_EXTENSION)
+ {
+ return load_descriptor_values(iter);
+ }
+ } while (iter->index != start);
+
+ /* Searched all the way back to start, and found nothing. */
+ (void)load_descriptor_values(iter);
+ return false;
+ }
+}
+
static void *pb_const_cast(const void *p)
{
/* Note: this casts away const, in order to use the common field iterator
diff --git a/pb_common.h b/pb_common.h
index 47fa2c9..58aa90f 100644
--- a/pb_common.h
+++ b/pb_common.h
@@ -32,6 +32,10 @@
* Returns false if no such field exists. */
bool pb_field_iter_find(pb_field_iter_t *iter, uint32_t tag);
+/* Find a field with type PB_LTYPE_EXTENSION, or return false if not found.
+ * There can be only one extension range field per message. */
+bool pb_field_iter_find_extension(pb_field_iter_t *iter);
+
#ifdef PB_VALIDATE_UTF8
/* Validate UTF-8 text string */
bool pb_validate_utf8(const char *s);
diff --git a/pb_decode.c b/pb_decode.c
index 4244fe0..12d9da0 100644
--- a/pb_decode.c
+++ b/pb_decode.c
@@ -31,8 +31,7 @@
static bool checkreturn decode_callback_field(pb_istream_t *stream, pb_wire_type_t wire_type, pb_field_iter_t *field);
static bool checkreturn decode_field(pb_istream_t *stream, pb_wire_type_t wire_type, pb_field_iter_t *field);
static bool checkreturn default_extension_decoder(pb_istream_t *stream, pb_extension_t *extension, uint32_t tag, pb_wire_type_t wire_type);
-static bool checkreturn decode_extension(pb_istream_t *stream, uint32_t tag, pb_wire_type_t wire_type, pb_field_iter_t *iter);
-static bool checkreturn find_extension_field(pb_field_iter_t *iter);
+static bool checkreturn decode_extension(pb_istream_t *stream, uint32_t tag, pb_wire_type_t wire_type, pb_extension_t *extension);
static bool pb_message_set_to_defaults(pb_field_iter_t *iter);
static bool checkreturn pb_dec_bool(pb_istream_t *stream, const pb_field_iter_t *field);
static bool checkreturn pb_dec_varint(pb_istream_t *stream, const pb_field_iter_t *field);
@@ -831,9 +830,8 @@
/* Try to decode an unknown field as an extension field. Tries each extension
* decoder in turn, until one of them handles the field or loop ends. */
static bool checkreturn decode_extension(pb_istream_t *stream,
- uint32_t tag, pb_wire_type_t wire_type, pb_field_iter_t *iter)
+ uint32_t tag, pb_wire_type_t wire_type, pb_extension_t *extension)
{
- pb_extension_t *extension = *(pb_extension_t* const *)iter->pData;
size_t pos = stream->bytes_left;
while (extension != NULL && pos == stream->bytes_left)
@@ -853,22 +851,6 @@
return true;
}
-/* Step through the iterator until an extension field is found or until all
- * entries have been checked. There can be only one extension field per
- * message. Returns false if no extension field is found. */
-static bool checkreturn find_extension_field(pb_field_iter_t *iter)
-{
- pb_size_t start = iter->index;
-
- do {
- if (PB_LTYPE(iter->type) == PB_LTYPE_EXTENSION)
- return true;
- (void)pb_field_iter_next(iter);
- } while (iter->index != start);
-
- return false;
-}
-
/* Initialize message fields to default values, recursively */
static bool pb_field_set_to_default(pb_field_iter_t *field)
{
@@ -989,6 +971,7 @@
static bool checkreturn pb_decode_inner(pb_istream_t *stream, const pb_msgdesc_t *fields, void *dest_struct, unsigned int flags)
{
uint32_t extension_range_start = 0;
+ pb_extension_t *extensions = NULL;
/* 'fixed_count_field' and 'fixed_count_size' track position of a repeated fixed
* count field. This can only handle _one_ repeated fixed count field that
@@ -1040,25 +1023,31 @@
if (!pb_field_iter_find(&iter, tag) || PB_LTYPE(iter.type) == PB_LTYPE_EXTENSION)
{
/* No match found, check if it matches an extension. */
+ if (extension_range_start == 0)
+ {
+ if (pb_field_iter_find_extension(&iter))
+ {
+ extensions = *(pb_extension_t* const *)iter.pData;
+ extension_range_start = iter.tag;
+ }
+
+ if (!extensions)
+ {
+ extension_range_start = (uint32_t)-1;
+ }
+ }
+
if (tag >= extension_range_start)
{
- if (!find_extension_field(&iter))
- extension_range_start = (uint32_t)-1;
- else
- extension_range_start = iter.tag;
+ size_t pos = stream->bytes_left;
- if (tag >= extension_range_start)
+ if (!decode_extension(stream, tag, wire_type, extensions))
+ return false;
+
+ if (pos != stream->bytes_left)
{
- size_t pos = stream->bytes_left;
-
- if (!decode_extension(stream, tag, wire_type, &iter))
- return false;
-
- if (pos != stream->bytes_left)
- {
- /* The field was handled */
- continue;
- }
+ /* The field was handled */
+ continue;
}
}