pb_decode: performance optimizations for check_wire_type()
diff --git a/pb_decode.c b/pb_decode.c
index 2684814..8055295 100644
--- a/pb_decode.c
+++ b/pb_decode.c
@@ -24,8 +24,7 @@
static bool checkreturn buf_read(pb_istream_t *stream, pb_byte_t *buf, size_t count);
static bool checkreturn pb_decode_varint32_eof(pb_istream_t *stream, uint32_t *dest, bool *eof);
static bool checkreturn read_raw_value(pb_istream_t *stream, pb_wire_type_t wire_type, pb_byte_t *buf, size_t *size);
-static bool checkreturn check_wire_type(pb_wire_type_t wire_type, pb_field_iter_t *field);
-static bool checkreturn decode_basic_field(pb_istream_t *stream, pb_field_iter_t *field);
+static bool checkreturn decode_basic_field(pb_istream_t *stream, pb_wire_type_t wire_type, pb_field_iter_t *field);
static bool checkreturn decode_static_field(pb_istream_t *stream, pb_wire_type_t wire_type, pb_field_iter_t *field);
static bool checkreturn decode_pointer_field(pb_istream_t *stream, pb_wire_type_t wire_type, pb_field_iter_t *field);
static bool checkreturn decode_callback_field(pb_istream_t *stream, pb_wire_type_t wire_type, pb_field_iter_t *field);
@@ -35,7 +34,6 @@
static bool pb_message_set_to_defaults(pb_field_iter_t *iter);
static bool checkreturn pb_dec_bool(pb_istream_t *stream, const pb_field_iter_t *field);
static bool checkreturn pb_dec_varint(pb_istream_t *stream, const pb_field_iter_t *field);
-static bool checkreturn pb_dec_fixed(pb_istream_t *stream, const pb_field_iter_t *field);
static bool checkreturn pb_dec_bytes(pb_istream_t *stream, const pb_field_iter_t *field);
static bool checkreturn pb_dec_string(pb_istream_t *stream, const pb_field_iter_t *field);
static bool checkreturn pb_dec_submessage(pb_istream_t *stream, const pb_field_iter_t *field);
@@ -58,6 +56,8 @@
#define pb_uint64_t uint64_t
#endif
+#define PB_WT_PACKED ((pb_wire_type_t)0xFF)
+
typedef struct {
uint32_t bitfield[(PB_MAX_REQUIRED_FIELDS + 31) / 32];
} pb_fields_seen_t;
@@ -387,61 +387,70 @@
* Decode a single field *
*************************/
-static bool checkreturn check_wire_type(pb_wire_type_t wire_type, pb_field_iter_t *field)
+static bool checkreturn decode_basic_field(pb_istream_t *stream, pb_wire_type_t wire_type, pb_field_iter_t *field)
{
switch (PB_LTYPE(field->type))
{
case PB_LTYPE_BOOL:
- case PB_LTYPE_VARINT:
- case PB_LTYPE_UVARINT:
- case PB_LTYPE_SVARINT:
- return wire_type == PB_WT_VARINT;
+ if (wire_type != PB_WT_VARINT && wire_type != PB_WT_PACKED)
+ PB_RETURN_ERROR(stream, "wrong wire type");
- case PB_LTYPE_FIXED32:
- return wire_type == PB_WT_32BIT;
-
- case PB_LTYPE_FIXED64:
- return wire_type == PB_WT_64BIT;
-
- case PB_LTYPE_BYTES:
- case PB_LTYPE_STRING:
- case PB_LTYPE_SUBMESSAGE:
- case PB_LTYPE_SUBMSG_W_CB:
- case PB_LTYPE_FIXED_LENGTH_BYTES:
- return wire_type == PB_WT_STRING;
-
- default:
- return false;
- }
-}
-
-static bool checkreturn decode_basic_field(pb_istream_t *stream, pb_field_iter_t *field)
-{
- switch (PB_LTYPE(field->type))
- {
- case PB_LTYPE_BOOL:
return pb_dec_bool(stream, field);
case PB_LTYPE_VARINT:
case PB_LTYPE_UVARINT:
case PB_LTYPE_SVARINT:
+ if (wire_type != PB_WT_VARINT && wire_type != PB_WT_PACKED)
+ PB_RETURN_ERROR(stream, "wrong wire type");
+
return pb_dec_varint(stream, field);
case PB_LTYPE_FIXED32:
+ if (wire_type != PB_WT_32BIT && wire_type != PB_WT_PACKED)
+ PB_RETURN_ERROR(stream, "wrong wire type");
+
+ return pb_decode_fixed32(stream, field->pData);
+
case PB_LTYPE_FIXED64:
- return pb_dec_fixed(stream, field);
+ if (wire_type != PB_WT_64BIT && wire_type != PB_WT_PACKED)
+ PB_RETURN_ERROR(stream, "wrong wire type");
+
+#ifdef PB_CONVERT_DOUBLE_FLOAT
+ if (field->data_size == sizeof(float))
+ {
+ return pb_decode_double_as_float(stream, (float*)field->pData);
+ }
+#endif
+
+#ifdef PB_WITHOUT_64BIT
+ PB_RETURN_ERROR(stream, "invalid data_size");
+#else
+ return pb_decode_fixed64(stream, field->pData);
+#endif
case PB_LTYPE_BYTES:
+ if (wire_type != PB_WT_STRING)
+ PB_RETURN_ERROR(stream, "wrong wire type");
+
return pb_dec_bytes(stream, field);
case PB_LTYPE_STRING:
+ if (wire_type != PB_WT_STRING)
+ PB_RETURN_ERROR(stream, "wrong wire type");
+
return pb_dec_string(stream, field);
case PB_LTYPE_SUBMESSAGE:
case PB_LTYPE_SUBMSG_W_CB:
+ if (wire_type != PB_WT_STRING)
+ PB_RETURN_ERROR(stream, "wrong wire type");
+
return pb_dec_submessage(stream, field);
case PB_LTYPE_FIXED_LENGTH_BYTES:
+ if (wire_type != PB_WT_STRING)
+ PB_RETURN_ERROR(stream, "wrong wire type");
+
return pb_dec_fixed_length_bytes(stream, field);
default:
@@ -454,18 +463,12 @@
switch (PB_HTYPE(field->type))
{
case PB_HTYPE_REQUIRED:
- if (!check_wire_type(wire_type, field))
- PB_RETURN_ERROR(stream, "wrong wire type");
-
- return decode_basic_field(stream, field);
+ return decode_basic_field(stream, wire_type, field);
case PB_HTYPE_OPTIONAL:
- if (!check_wire_type(wire_type, field))
- PB_RETURN_ERROR(stream, "wrong wire type");
-
if (field->pSize != NULL)
*(bool*)field->pSize = true;
- return decode_basic_field(stream, field);
+ return decode_basic_field(stream, wire_type, field);
case PB_HTYPE_REPEATED:
if (wire_type == PB_WT_STRING
@@ -482,7 +485,7 @@
while (substream.bytes_left > 0 && *size < field->array_size)
{
- if (!decode_basic_field(&substream, field))
+ if (!decode_basic_field(&substream, PB_WT_PACKED, field))
{
status = false;
break;
@@ -504,13 +507,10 @@
pb_size_t *size = (pb_size_t*)field->pSize;
field->pData = (char*)field->pField + field->data_size * (*size);
- if (!check_wire_type(wire_type, field))
- PB_RETURN_ERROR(stream, "wrong wire type");
-
if ((*size)++ >= field->array_size)
PB_RETURN_ERROR(stream, "array overflow");
- return decode_basic_field(stream, field);
+ return decode_basic_field(stream, wire_type, field);
}
case PB_HTYPE_ONEOF:
@@ -527,10 +527,7 @@
memset(field->pData, 0, (size_t)field->data_size);
}
- if (!check_wire_type(wire_type, field))
- PB_RETURN_ERROR(stream, "wrong wire type");
-
- return decode_basic_field(stream, field);
+ return decode_basic_field(stream, wire_type, field);
default:
PB_RETURN_ERROR(stream, "invalid field type");
@@ -616,9 +613,6 @@
case PB_HTYPE_REQUIRED:
case PB_HTYPE_OPTIONAL:
case PB_HTYPE_ONEOF:
- if (!check_wire_type(wire_type, field))
- PB_RETURN_ERROR(stream, "wrong wire type");
-
if (PB_LTYPE_IS_SUBMSG(field->type) && *(void**)field->pField != NULL)
{
/* Duplicate field, have to release the old allocation first. */
@@ -636,7 +630,7 @@
{
/* pb_dec_string and pb_dec_bytes handle allocation themselves */
field->pData = field->pField;
- return decode_basic_field(stream, field);
+ return decode_basic_field(stream, wire_type, field);
}
else
{
@@ -645,7 +639,7 @@
field->pData = *(void**)field->pField;
initialize_pointer_field(field->pData, field);
- return decode_basic_field(stream, field);
+ return decode_basic_field(stream, wire_type, field);
}
case PB_HTYPE_REPEATED:
@@ -693,7 +687,7 @@
/* Decode the array entry */
field->pData = *(char**)field->pField + field->data_size * (*size);
initialize_pointer_field(field->pData, field);
- if (!decode_basic_field(&substream, field))
+ if (!decode_basic_field(&substream, PB_WT_PACKED, field))
{
status = false;
break;
@@ -714,16 +708,13 @@
if (*size == PB_SIZE_MAX)
PB_RETURN_ERROR(stream, "too many array entries");
- if (!check_wire_type(wire_type, field))
- PB_RETURN_ERROR(stream, "wrong wire type");
-
if (!allocate_field(stream, field->pField, field->data_size, (size_t)(*size + 1)))
return false;
field->pData = *(char**)field->pField + field->data_size * (*size);
(*size)++;
initialize_pointer_field(field->pData, field);
- return decode_basic_field(stream, field);
+ return decode_basic_field(stream, wire_type, field);
}
default:
@@ -1442,7 +1433,7 @@
/* See issue 97: Google's C++ protobuf allows negative varint values to
* be cast as int32_t, instead of the int64_t that should be used when
- * encoding. Previous nanopb versions had a bug in encoding. In order to
+ * encoding. Nanopb versions before 0.2.5 had a bug in encoding. In order to
* not break decoding of such messages, we cast <=32 bit fields to
* int32_t first to get the sign correct.
*/
@@ -1471,31 +1462,6 @@
}
}
-static bool checkreturn pb_dec_fixed(pb_istream_t *stream, const pb_field_iter_t *field)
-{
-#ifdef PB_CONVERT_DOUBLE_FLOAT
- if (field->data_size == sizeof(float) && PB_LTYPE(field->type) == PB_LTYPE_FIXED64)
- {
- return pb_decode_double_as_float(stream, (float*)field->pData);
- }
-#endif
-
- if (field->data_size == sizeof(uint32_t))
- {
- return pb_decode_fixed32(stream, field->pData);
- }
-#ifndef PB_WITHOUT_64BIT
- else if (field->data_size == sizeof(uint64_t))
- {
- return pb_decode_fixed64(stream, field->pData);
- }
-#endif
- else
- {
- PB_RETURN_ERROR(stream, "invalid data_size");
- }
-}
-
static bool checkreturn pb_dec_bytes(pb_istream_t *stream, const pb_field_iter_t *field)
{
uint32_t size;
diff --git a/tests/decode_unittests/decode_unittests.c b/tests/decode_unittests/decode_unittests.c
index f177535..80b87b5 100644
--- a/tests/decode_unittests/decode_unittests.c
+++ b/tests/decode_unittests/decode_unittests.c
@@ -237,34 +237,25 @@
{
pb_istream_t s;
- pb_field_iter_t f;
float d;
- f.type = PB_LTYPE_FIXED32;
- f.data_size = sizeof(d);
- f.pData = &d;
-
COMMENT("Test pb_dec_fixed using float (failures here may be caused by imperfect rounding)")
- TEST((s = S("\x00\x00\x00\x00"), pb_dec_fixed(&s, &f) && d == 0.0f))
- TEST((s = S("\x00\x00\xc6\x42"), pb_dec_fixed(&s, &f) && d == 99.0f))
- TEST((s = S("\x4e\x61\x3c\xcb"), pb_dec_fixed(&s, &f) && d == -12345678.0f))
+ TEST((s = S("\x00\x00\x00\x00"), pb_decode_fixed32(&s, &d) && d == 0.0f))
+ TEST((s = S("\x00\x00\xc6\x42"), pb_decode_fixed32(&s, &d) && d == 99.0f))
+ TEST((s = S("\x4e\x61\x3c\xcb"), pb_decode_fixed32(&s, &d) && d == -12345678.0f))
d = -12345678.0f;
- TEST((s = S("\x00"), !pb_dec_fixed(&s, &f) && d == -12345678.0f))
+ TEST((s = S("\x00"), !pb_decode_fixed32(&s, &d) && d == -12345678.0f))
}
+ if (sizeof(double) == 8)
{
pb_istream_t s;
- pb_field_iter_t f;
double d;
- f.type = PB_LTYPE_FIXED64;
- f.data_size = sizeof(d);
- f.pData = &d;
-
COMMENT("Test pb_dec_fixed64 using double (failures here may be caused by imperfect rounding)")
- TEST((s = S("\x00\x00\x00\x00\x00\x00\x00\x00"), pb_dec_fixed(&s, &f) && d == 0.0))
- TEST((s = S("\x00\x00\x00\x00\x00\xc0\x58\x40"), pb_dec_fixed(&s, &f) && d == 99.0))
- TEST((s = S("\x00\x00\x00\xc0\x29\x8c\x67\xc1"), pb_dec_fixed(&s, &f) && d == -12345678.0f))
+ TEST((s = S("\x00\x00\x00\x00\x00\x00\x00\x00"), pb_decode_fixed64(&s, &d) && d == 0.0))
+ TEST((s = S("\x00\x00\x00\x00\x00\xc0\x58\x40"), pb_decode_fixed64(&s, &d) && d == 99.0))
+ TEST((s = S("\x00\x00\x00\xc0\x29\x8c\x67\xc1"), pb_decode_fixed64(&s, &d) && d == -12345678.0f))
}
{