Test also stream input in fuzz test.
diff --git a/tests/fuzztest/SConscript b/tests/fuzztest/SConscript
index f07b5b2..3280f75 100644
--- a/tests/fuzztest/SConscript
+++ b/tests/fuzztest/SConscript
@@ -33,6 +33,7 @@
fuzz = malloc_env.Program(["fuzztest.c",
"random_data.c",
"validation.c",
+ "flakystream.c",
"alltypes_pointer.pb.c",
"alltypes_static.pb.c",
"alltypes_proto3_pointer.pb.c",
diff --git a/tests/fuzztest/flakystream.c b/tests/fuzztest/flakystream.c
new file mode 100644
index 0000000..3b10833
--- /dev/null
+++ b/tests/fuzztest/flakystream.c
@@ -0,0 +1,33 @@
+#include "flakystream.h"
+#include <string.h>
+
+bool flakystream_callback(pb_istream_t *stream, pb_byte_t *buf, size_t count)
+{
+ flakystream_t *state = stream->state;
+
+ if (state->position + count > state->msglen)
+ {
+ stream->bytes_left = 0;
+ return false;
+ }
+ else if (state->position + count > state->fail_after)
+ {
+ PB_RETURN_ERROR(stream, "flaky error");
+ }
+
+ memcpy(buf, state->buffer + state->position, count);
+ state->position += count;
+ return true;
+}
+
+void flakystream_init(flakystream_t *stream, const uint8_t *buffer, size_t msglen, size_t fail_after)
+{
+ memset(stream, 0, sizeof(*stream));
+ stream->stream.callback = flakystream_callback;
+ stream->stream.bytes_left = SIZE_MAX;
+ stream->stream.state = stream;
+ stream->buffer = buffer;
+ stream->position = 0;
+ stream->msglen = msglen;
+ stream->fail_after = fail_after;
+}
diff --git a/tests/fuzztest/flakystream.h b/tests/fuzztest/flakystream.h
new file mode 100644
index 0000000..6056ac4
--- /dev/null
+++ b/tests/fuzztest/flakystream.h
@@ -0,0 +1,22 @@
+/* This module implements a custom input stream that can be set to give IO error
+ * at specific point. */
+
+#ifndef FLAKYSTREAM_H
+#define FLAKYSTREAM_H
+
+#include <stdlib.h>
+#include <stdint.h>
+#include <stdbool.h>
+#include <pb_decode.h>
+
+typedef struct {
+ pb_istream_t stream;
+ const uint8_t *buffer;
+ size_t position;
+ size_t msglen;
+ size_t fail_after;
+} flakystream_t;
+
+void flakystream_init(flakystream_t *stream, const uint8_t *buffer, size_t msglen, size_t fail_after);
+
+#endif
diff --git a/tests/fuzztest/fuzztest.c b/tests/fuzztest/fuzztest.c
index 00fd76c..9d86d2c 100644
--- a/tests/fuzztest/fuzztest.c
+++ b/tests/fuzztest/fuzztest.c
@@ -16,6 +16,7 @@
#include <malloc_wrappers.h>
#include "random_data.h"
#include "validation.h"
+#include "flakystream.h"
#include "test_helpers.h"
#include "alltypes_static.pb.h"
#include "alltypes_pointer.pb.h"
@@ -60,6 +61,35 @@
return status;
}
+static bool do_stream_decode(const uint8_t *buffer, size_t msglen, size_t fail_after, size_t structsize, const pb_msgdesc_t *msgtype, bool assert_success)
+{
+ bool status;
+ flakystream_t stream;
+ size_t initial_alloc_count = get_alloc_count();
+ void *msg = malloc_with_check(structsize);
+
+ memset(msg, 0, structsize);
+ flakystream_init(&stream, buffer, msglen, fail_after);
+ status = pb_decode(&stream.stream, msgtype, msg);
+
+ if (status)
+ {
+ validate_message(msg, structsize, msgtype);
+ }
+
+ if (assert_success)
+ {
+ if (!status) fprintf(stderr, "pb_decode: %s\n", PB_GET_ERROR(&stream.stream));
+ assert(status);
+ }
+
+ pb_release(msgtype, msg);
+ free_with_check(msg);
+ assert(get_alloc_count() == initial_alloc_count);
+
+ return status;
+}
+
/* Do a decode -> encode -> decode -> encode roundtrip */
static void do_roundtrip(const uint8_t *buffer, size_t msglen, size_t structsize, const pb_msgdesc_t *msgtype)
{
@@ -149,6 +179,10 @@
/* Any message that is decodeable as proto2 should be decodeable as proto3 */
do_roundtrip(data, size, sizeof(alltypes_static_AllTypes), alltypes_static_AllTypes_fields);
do_roundtrip(data, size, sizeof(alltypes_proto3_static_AllTypes), alltypes_proto3_static_AllTypes_fields);
+
+ /* Test successful decoding also when using a stream */
+ do_stream_decode(data, size, SIZE_MAX, sizeof(alltypes_static_AllTypes), alltypes_static_AllTypes_fields, true);
+ do_stream_decode(data, size, SIZE_MAX, sizeof(alltypes_proto3_static_AllTypes), alltypes_proto3_static_AllTypes_fields, true);
}
else if (do_decode(data, size, sizeof(alltypes_proto3_static_AllTypes), alltypes_proto3_static_AllTypes_fields, expect_valid))
{
@@ -160,12 +194,22 @@
{
do_roundtrip(data, size, sizeof(alltypes_pointer_AllTypes), alltypes_pointer_AllTypes_fields);
do_roundtrip(data, size, sizeof(alltypes_proto3_pointer_AllTypes), alltypes_proto3_pointer_AllTypes_fields);
+
+ /* Test successful decoding also when using a stream */
+ do_stream_decode(data, size, SIZE_MAX, sizeof(alltypes_pointer_AllTypes), alltypes_pointer_AllTypes_fields, true);
+ do_stream_decode(data, size, SIZE_MAX, sizeof(alltypes_proto3_pointer_AllTypes), alltypes_proto3_pointer_AllTypes_fields, true);
}
else if (do_decode(data, size, sizeof(alltypes_proto3_pointer_AllTypes), alltypes_proto3_pointer_AllTypes_fields, expect_valid))
{
do_roundtrip(data, size, sizeof(alltypes_proto3_pointer_AllTypes), alltypes_proto3_pointer_AllTypes_fields);
}
+ /* Test decoding on a failing stream */
+ do_stream_decode(data, size, size - 16, sizeof(alltypes_static_AllTypes), alltypes_static_AllTypes_fields, false);
+ do_stream_decode(data, size, size - 16, sizeof(alltypes_proto3_static_AllTypes), alltypes_proto3_static_AllTypes_fields, false);
+ do_stream_decode(data, size, size - 16, sizeof(alltypes_pointer_AllTypes), alltypes_pointer_AllTypes_fields, false);
+ do_stream_decode(data, size, size - 16, sizeof(alltypes_proto3_pointer_AllTypes), alltypes_proto3_pointer_AllTypes_fields, false);
+
assert(get_alloc_count() == initial_alloc_count);
}