pb_decode: performance optimizations for extension field handling
diff --git a/pb_common.c b/pb_common.c
index ae3e9b6..c718eb1 100644
--- a/pb_common.c
+++ b/pb_common.c
@@ -239,6 +239,37 @@
     }
 }
 
+bool pb_field_iter_find_extension(pb_field_iter_t *iter)
+{
+    if (PB_LTYPE(iter->type) == PB_LTYPE_EXTENSION)
+    {
+        return true;
+    }
+    else
+    {
+        pb_size_t start = iter->index;
+        uint32_t fieldinfo;
+
+        do
+        {
+            /* Advance iterator but don't load values yet */
+            advance_iterator(iter);
+
+            /* Do fast check for field type */
+            fieldinfo = PB_PROGMEM_READU32(iter->descriptor->field_info[iter->field_info_index]);
+
+            if (PB_LTYPE((fieldinfo >> 8) & 0xFF) == PB_LTYPE_EXTENSION)
+            {
+                return load_descriptor_values(iter);
+            }
+        } while (iter->index != start);
+
+        /* Searched all the way back to start, and found nothing. */
+        (void)load_descriptor_values(iter);
+        return false;
+    }
+}
+
 static void *pb_const_cast(const void *p)
 {
     /* Note: this casts away const, in order to use the common field iterator
diff --git a/pb_common.h b/pb_common.h
index 47fa2c9..58aa90f 100644
--- a/pb_common.h
+++ b/pb_common.h
@@ -32,6 +32,10 @@
  * Returns false if no such field exists. */
 bool pb_field_iter_find(pb_field_iter_t *iter, uint32_t tag);
 
+/* Find a field with type PB_LTYPE_EXTENSION, or return false if not found.
+ * There can be only one extension range field per message. */
+bool pb_field_iter_find_extension(pb_field_iter_t *iter);
+
 #ifdef PB_VALIDATE_UTF8
 /* Validate UTF-8 text string */
 bool pb_validate_utf8(const char *s);
diff --git a/pb_decode.c b/pb_decode.c
index 4244fe0..12d9da0 100644
--- a/pb_decode.c
+++ b/pb_decode.c
@@ -31,8 +31,7 @@
 static bool checkreturn decode_callback_field(pb_istream_t *stream, pb_wire_type_t wire_type, pb_field_iter_t *field);
 static bool checkreturn decode_field(pb_istream_t *stream, pb_wire_type_t wire_type, pb_field_iter_t *field);
 static bool checkreturn default_extension_decoder(pb_istream_t *stream, pb_extension_t *extension, uint32_t tag, pb_wire_type_t wire_type);
-static bool checkreturn decode_extension(pb_istream_t *stream, uint32_t tag, pb_wire_type_t wire_type, pb_field_iter_t *iter);
-static bool checkreturn find_extension_field(pb_field_iter_t *iter);
+static bool checkreturn decode_extension(pb_istream_t *stream, uint32_t tag, pb_wire_type_t wire_type, pb_extension_t *extension);
 static bool pb_message_set_to_defaults(pb_field_iter_t *iter);
 static bool checkreturn pb_dec_bool(pb_istream_t *stream, const pb_field_iter_t *field);
 static bool checkreturn pb_dec_varint(pb_istream_t *stream, const pb_field_iter_t *field);
@@ -831,9 +830,8 @@
 /* Try to decode an unknown field as an extension field. Tries each extension
  * decoder in turn, until one of them handles the field or loop ends. */
 static bool checkreturn decode_extension(pb_istream_t *stream,
-    uint32_t tag, pb_wire_type_t wire_type, pb_field_iter_t *iter)
+    uint32_t tag, pb_wire_type_t wire_type, pb_extension_t *extension)
 {
-    pb_extension_t *extension = *(pb_extension_t* const *)iter->pData;
     size_t pos = stream->bytes_left;
     
     while (extension != NULL && pos == stream->bytes_left)
@@ -853,22 +851,6 @@
     return true;
 }
 
-/* Step through the iterator until an extension field is found or until all
- * entries have been checked. There can be only one extension field per
- * message. Returns false if no extension field is found. */
-static bool checkreturn find_extension_field(pb_field_iter_t *iter)
-{
-    pb_size_t start = iter->index;
-
-    do {
-        if (PB_LTYPE(iter->type) == PB_LTYPE_EXTENSION)
-            return true;
-        (void)pb_field_iter_next(iter);
-    } while (iter->index != start);
-    
-    return false;
-}
-
 /* Initialize message fields to default values, recursively */
 static bool pb_field_set_to_default(pb_field_iter_t *field)
 {
@@ -989,6 +971,7 @@
 static bool checkreturn pb_decode_inner(pb_istream_t *stream, const pb_msgdesc_t *fields, void *dest_struct, unsigned int flags)
 {
     uint32_t extension_range_start = 0;
+    pb_extension_t *extensions = NULL;
 
     /* 'fixed_count_field' and 'fixed_count_size' track position of a repeated fixed
      * count field. This can only handle _one_ repeated fixed count field that
@@ -1040,25 +1023,31 @@
         if (!pb_field_iter_find(&iter, tag) || PB_LTYPE(iter.type) == PB_LTYPE_EXTENSION)
         {
             /* No match found, check if it matches an extension. */
+            if (extension_range_start == 0)
+            {
+                if (pb_field_iter_find_extension(&iter))
+                {
+                    extensions = *(pb_extension_t* const *)iter.pData;
+                    extension_range_start = iter.tag;
+                }
+
+                if (!extensions)
+                {
+                    extension_range_start = (uint32_t)-1;
+                }
+            }
+
             if (tag >= extension_range_start)
             {
-                if (!find_extension_field(&iter))
-                    extension_range_start = (uint32_t)-1;
-                else
-                    extension_range_start = iter.tag;
+                size_t pos = stream->bytes_left;
 
-                if (tag >= extension_range_start)
+                if (!decode_extension(stream, tag, wire_type, extensions))
+                    return false;
+
+                if (pos != stream->bytes_left)
                 {
-                    size_t pos = stream->bytes_left;
-
-                    if (!decode_extension(stream, tag, wire_type, &iter))
-                        return false;
-
-                    if (pos != stream->bytes_left)
-                    {
-                        /* The field was handled */
-                        continue;
-                    }
+                    /* The field was handled */
+                    continue;
                 }
             }