pw_protobuf: FieldKey class
This replaces the MakeKey function with a FieldKey class, which can be
used to both create and parse a protobuf field key.
Change-Id: Ib37265fd539feabd923e558a395a6f1a1856d915
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/59641
Commit-Queue: Yecheng Zhao <zyecheng@google.com>
Reviewed-by: David Rogers <davidrogers@google.com>
diff --git a/pw_protobuf/decoder.cc b/pw_protobuf/decoder.cc
index 813046c..7f8c8cf 100644
--- a/pw_protobuf/decoder.cc
+++ b/pw_protobuf/decoder.cc
@@ -50,7 +50,10 @@
uint32_t Decoder::FieldNumber() const {
uint64_t key;
varint::Decode(proto_, &key);
- return key >> kFieldNumberShift;
+ if (!FieldKey::IsValidKey(key)) {
+ return 0;
+ }
+ return FieldKey(key).field_number();
}
Status Decoder::ReadUint32(uint32_t* out) {
@@ -113,16 +116,15 @@
size_t Decoder::FieldSize() const {
uint64_t key;
size_t key_size = varint::Decode(proto_, &key);
- if (key_size == 0) {
+ if (key_size == 0 || !FieldKey::IsValidKey(key)) {
return 0;
}
std::span<const std::byte> remainder = proto_.subspan(key_size);
- WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
uint64_t value = 0;
size_t expected_size = 0;
- switch (wire_type) {
+ switch (FieldKey(key).wire_type()) {
case WireType::kVarint:
expected_size = varint::Decode(remainder, &value);
if (expected_size == 0) {
@@ -162,8 +164,11 @@
return Status::FailedPrecondition();
}
- WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
- if (wire_type != expected_type) {
+ if (!FieldKey::IsValidKey(key)) {
+ return Status::DataLoss();
+ }
+
+ if (FieldKey(key).wire_type() != expected_type) {
return Status::FailedPrecondition();
}
diff --git a/pw_protobuf/decoder_test.cc b/pw_protobuf/decoder_test.cc
index 90d644a..d006312 100644
--- a/pw_protobuf/decoder_test.cc
+++ b/pw_protobuf/decoder_test.cc
@@ -162,6 +162,32 @@
EXPECT_EQ(decoder.Next(), Status::OutOfRange());
}
+TEST(Decoder, Decode_BadFieldNumber) {
+ // clang-format off
+ constexpr uint8_t encoded_proto[] = {
+ // type=int32, k=1, v=42
+ 0x08, 0x2a,
+ // type=int32, k=19001, v=42 (invalid field number)
+ 0xc8, 0xa3, 0x09, 0x2a,
+ // type=bool, k=3, v=false
+ 0x18, 0x00,
+ };
+ // clang-format on
+
+ Decoder decoder(std::as_bytes(std::span(encoded_proto)));
+ int32_t value;
+
+ EXPECT_EQ(decoder.Next(), OkStatus());
+ EXPECT_EQ(decoder.FieldNumber(), 1u);
+ ASSERT_EQ(decoder.ReadInt32(&value), OkStatus());
+ EXPECT_EQ(value, 42);
+
+ // Bad field.
+ EXPECT_EQ(decoder.Next(), Status::DataLoss());
+ EXPECT_EQ(decoder.FieldNumber(), 0u);
+ EXPECT_EQ(decoder.ReadInt32(&value), Status::DataLoss());
+}
+
TEST(CallbackDecoder, Decode) {
CallbackDecoder decoder;
TestDecodeHandler handler;
diff --git a/pw_protobuf/encoder.cc b/pw_protobuf/encoder.cc
index b427e92..93cc777 100644
--- a/pw_protobuf/encoder.cc
+++ b/pw_protobuf/encoder.cc
@@ -31,12 +31,14 @@
StreamEncoder StreamEncoder::GetNestedEncoder(uint32_t field_number) {
PW_CHECK(!nested_encoder_open());
+ PW_CHECK(ValidFieldNumber(field_number));
+
nested_field_number_ = field_number;
// Pass the unused space of the scratch buffer to the nested encoder to use
// as their scratch buffer.
size_t key_size =
- varint::EncodedSize(MakeKey(field_number, WireType::kDelimited));
+ varint::EncodedSize(FieldKey(field_number, WireType::kDelimited));
size_t reserved_size = key_size + config::kMaxVarintSize;
size_t max_size = std::min(memory_writer_.ConservativeWriteLimit(),
writer_.ConservativeWriteLimit());
@@ -109,7 +111,7 @@
PW_TRY(UpdateStatusForWrite(
field_number, WireType::kVarint, varint::EncodedSize(value)));
- WriteVarint(MakeKey(field_number, WireType::kVarint))
+ WriteVarint(FieldKey(field_number, WireType::kVarint))
.IgnoreError(); // TODO(pwbug/387): Handle Status properly
return WriteVarint(value);
}
@@ -117,7 +119,7 @@
Status StreamEncoder::WriteLengthDelimitedField(uint32_t field_number,
ConstByteSpan data) {
PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, data.size()));
- WriteVarint(MakeKey(field_number, WireType::kDelimited))
+ WriteVarint(FieldKey(field_number, WireType::kDelimited))
.IgnoreError(); // TODO(pwbug/387): Handle Status properly
WriteVarint(data.size_bytes())
.IgnoreError(); // TODO(pwbug/387): Handle Status properly
@@ -133,15 +135,15 @@
size_t num_bytes,
ByteSpan stream_pipe_buffer) {
PW_CHECK_UINT_GT(
- stream_pipe_buffer.size(), 0, "transfer buffer cannot be 0 size");
+ stream_pipe_buffer.size(), 0, "Transfer buffer cannot be 0 size");
PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, num_bytes));
// Ignore the error until we explicitly check status_ below to minimize
// the number of branches.
- WriteVarint(MakeKey(field_number, WireType::kDelimited)).IgnoreError();
+ WriteVarint(FieldKey(field_number, WireType::kDelimited)).IgnoreError();
WriteVarint(num_bytes).IgnoreError();
PW_TRY(status_);
- // Stream data from `bytes_reader` to `writer_`
+ // Stream data from `bytes_reader` to `writer_`.
// TODO(pwbug/468): move the following logic to pw_stream/copy.h at a later
// time.
for (size_t bytes_written = 0; bytes_written < num_bytes;) {
@@ -167,7 +169,7 @@
PW_TRY(UpdateStatusForWrite(field_number, type, data.size()));
- WriteVarint(MakeKey(field_number, type))
+ WriteVarint(FieldKey(field_number, type))
.IgnoreError(); // TODO(pwbug/387): Handle Status properly
if (Status status = writer_.Write(data); !status.ok()) {
status_ = status;
@@ -209,7 +211,7 @@
PW_TRY(UpdateStatusForWrite(
field_number, WireType::kDelimited, values.size_bytes()));
- WriteVarint(MakeKey(field_number, WireType::kDelimited))
+ WriteVarint(FieldKey(field_number, WireType::kDelimited))
.IgnoreError(); // TODO(pwbug/387): Handle Status properly
WriteVarint(values.size_bytes())
.IgnoreError(); // TODO(pwbug/387): Handle Status properly
@@ -242,7 +244,7 @@
return status_;
}
- size_t size = varint::EncodedSize(MakeKey(field_number, type));
+ size_t size = varint::EncodedSize(FieldKey(field_number, type));
if (type == WireType::kDelimited) {
size += varint::EncodedSize(data_size);
}
diff --git a/pw_protobuf/encoder_test.cc b/pw_protobuf/encoder_test.cc
index be93bb3..1ab005a 100644
--- a/pw_protobuf/encoder_test.cc
+++ b/pw_protobuf/encoder_test.cc
@@ -392,7 +392,7 @@
ASSERT_EQ(parent.status(), OkStatus());
const size_t kExpectedSize =
varint::EncodedSize(
- MakeKey(kTestProtoNestedField, WireType::kDelimited)) +
+ FieldKey(kTestProtoNestedField, WireType::kDelimited)) +
varint::EncodedSize(0);
ASSERT_EQ(parent.size(), kExpectedSize);
}
diff --git a/pw_protobuf/public/pw_protobuf/decoder.h b/pw_protobuf/public/pw_protobuf/decoder.h
index baa076c..750ba94 100644
--- a/pw_protobuf/public/pw_protobuf/decoder.h
+++ b/pw_protobuf/public/pw_protobuf/decoder.h
@@ -64,6 +64,12 @@
Status Next();
// Returns the field number of the field at the current cursor position.
+ //
+ // A return value of 0 indicates that the field number is invalid. An invalid
+ // field number terminates the decode operation; any subsequent calls to
+ // Next() or Read*() will return DATA_LOSS.
+ //
+ // TODO(frolv): This should be refactored to return a Result<uint32_t>.
uint32_t FieldNumber() const;
// Reads a proto int32 value from the current cursor.
diff --git a/pw_protobuf/public/pw_protobuf/encoder.h b/pw_protobuf/public/pw_protobuf/encoder.h
index 5611773..72b2147 100644
--- a/pw_protobuf/public/pw_protobuf/encoder.h
+++ b/pw_protobuf/public/pw_protobuf/encoder.h
@@ -507,7 +507,7 @@
return status_;
}
- WriteVarint(MakeKey(field_number, WireType::kDelimited))
+ WriteVarint(FieldKey(field_number, WireType::kDelimited))
.IgnoreError(); // TODO(pwbug/387): Handle Status properly
WriteVarint(payload_size)
.IgnoreError(); // TODO(pwbug/387): Handle Status properly
diff --git a/pw_protobuf/public/pw_protobuf/serialized_size.h b/pw_protobuf/public/pw_protobuf/serialized_size.h
index d671683..dc51491 100644
--- a/pw_protobuf/public/pw_protobuf/serialized_size.h
+++ b/pw_protobuf/public/pw_protobuf/serialized_size.h
@@ -45,8 +45,9 @@
inline constexpr size_t kMaxSizeOfLength = varint::kMaxVarint32SizeBytes;
constexpr size_t SizeOfFieldKey(uint32_t field_number) {
- // The wiretype is ignored as this does not impact the serialized size.
- return varint::EncodedSize(field_number << kFieldNumberShift);
+ // The wiretype does not impact the serialized size, so use kVarint (0), which
+ // will be optimized out by the compiler.
+ return varint::EncodedSize(FieldKey(field_number, WireType::kVarint));
}
} // namespace pw::protobuf
diff --git a/pw_protobuf/public/pw_protobuf/wire_format.h b/pw_protobuf/public/pw_protobuf/wire_format.h
index fa6071c..a222d7f 100644
--- a/pw_protobuf/public/pw_protobuf/wire_format.h
+++ b/pw_protobuf/public/pw_protobuf/wire_format.h
@@ -14,6 +14,9 @@
#pragma once
#include <cstdint>
+#include <limits>
+
+#include "pw_assert/assert.h"
namespace pw::protobuf {
@@ -24,6 +27,19 @@
constexpr static uint32_t kFirstReservedNumber = 19000;
constexpr static uint32_t kLastReservedNumber = 19999;
+constexpr bool ValidFieldNumber(uint32_t field_number) {
+ return field_number != 0 && field_number <= kMaxFieldNumber &&
+ !(field_number >= kFirstReservedNumber &&
+ field_number <= kLastReservedNumber);
+}
+
+constexpr bool ValidFieldNumber(uint64_t field_number) {
+ if (field_number > std::numeric_limits<uint32_t>::max()) {
+ return false;
+ }
+ return ValidFieldNumber(static_cast<uint32_t>(field_number));
+}
+
enum class WireType {
kVarint = 0,
kFixed64 = 1,
@@ -32,17 +48,51 @@
kFixed32 = 5,
};
-inline constexpr unsigned int kFieldNumberShift = 3u;
-inline constexpr unsigned int kWireTypeMask = (1u << kFieldNumberShift) - 1u;
+// Represents a protobuf field key, storing a field number and wire type.
+class FieldKey {
+ public:
+ // Checks if the given encoded protobuf key is valid. Must be called before
+ // instantiating a FieldKey object with it.
+ static constexpr bool IsValidKey(uint64_t key) {
+ uint64_t field_number = key >> kFieldNumberShift;
+ uint32_t wire_type = key & kWireTypeMask;
-constexpr uint32_t MakeKey(uint32_t field_number, WireType wire_type) {
- return (field_number << kFieldNumberShift | static_cast<uint32_t>(wire_type));
-}
+ return ValidFieldNumber(field_number) && (wire_type <= 2 || wire_type == 5);
+ }
-constexpr bool ValidFieldNumber(uint32_t field_number) {
- return field_number != 0 && field_number <= kMaxFieldNumber &&
- !(field_number >= kFirstReservedNumber &&
- field_number <= kLastReservedNumber);
+ // Creates a field key with the given field number and type.
+ //
+ // Precondition: The field number is valid.
+ constexpr FieldKey(uint32_t field_number, WireType wire_type)
+ : key_(field_number << kFieldNumberShift |
+ static_cast<uint32_t>(wire_type)) {
+ PW_DASSERT(ValidFieldNumber(field_number));
+ }
+
+ // Parses a field key from its encoded representation.
+ //
+ // Precondition: The field number is valid. Call IsValidKey(key) first.
+ constexpr FieldKey(uint32_t key) : key_(key) {
+ PW_DASSERT(ValidFieldNumber(field_number()));
+ }
+
+ constexpr operator uint32_t() { return key_; }
+
+ constexpr uint32_t field_number() const { return key_ >> kFieldNumberShift; }
+ constexpr WireType wire_type() const {
+ return static_cast<WireType>(key_ & kWireTypeMask);
+ }
+
+ private:
+ static constexpr unsigned int kFieldNumberShift = 3u;
+ static constexpr unsigned int kWireTypeMask = (1u << kFieldNumberShift) - 1u;
+
+ uint32_t key_;
+};
+
+[[deprecated("Use the FieldKey class")]] constexpr uint32_t MakeKey(
+ uint32_t field_number, WireType wire_type) {
+ return FieldKey(field_number, wire_type);
}
} // namespace pw::protobuf
diff --git a/pw_rpc/packet_test.cc b/pw_rpc/packet_test.cc
index 6a79378..4a405b4 100644
--- a/pw_rpc/packet_test.cc
+++ b/pw_rpc/packet_test.cc
@@ -21,13 +21,14 @@
namespace pw::rpc::internal {
namespace {
+using protobuf::FieldKey;
using std::byte;
constexpr auto kPayload = bytes::Array<0x82, 0x02, 0xff, 0xff>();
constexpr auto kEncoded = bytes::Array<
// Payload
- MakeKey(5, protobuf::WireType::kDelimited),
+ uint32_t(FieldKey(5, protobuf::WireType::kDelimited)),
0x04,
0x82,
0x02,
@@ -35,29 +36,29 @@
0xff,
// Packet type
- MakeKey(1, protobuf::WireType::kVarint),
+ uint32_t(FieldKey(1, protobuf::WireType::kVarint)),
1, // RESPONSE
// Channel ID
- MakeKey(2, protobuf::WireType::kVarint),
+ uint32_t(FieldKey(2, protobuf::WireType::kVarint)),
1,
// Service ID
- MakeKey(3, protobuf::WireType::kFixed32),
+ uint32_t(FieldKey(3, protobuf::WireType::kFixed32)),
42,
0,
0,
0,
// Method ID
- MakeKey(4, protobuf::WireType::kFixed32),
+ uint32_t(FieldKey(4, protobuf::WireType::kFixed32)),
100,
0,
0,
0
// Status (not encoded if it is zero)
- // MakeKey(6, protobuf::WireType::kVarint),
+ // FieldKey(6, protobuf::WireType::kVarint),
// 0x00
>();