pw_protobuf: FieldKey class

This replaces the MakeKey function with a FieldKey class, which can be
used to both create and parse a protobuf field key.

Change-Id: Ib37265fd539feabd923e558a395a6f1a1856d915
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/59641
Commit-Queue: Yecheng Zhao <zyecheng@google.com>
Reviewed-by: David Rogers <davidrogers@google.com>
diff --git a/pw_protobuf/decoder.cc b/pw_protobuf/decoder.cc
index 813046c..7f8c8cf 100644
--- a/pw_protobuf/decoder.cc
+++ b/pw_protobuf/decoder.cc
@@ -50,7 +50,10 @@
 uint32_t Decoder::FieldNumber() const {
   uint64_t key;
   varint::Decode(proto_, &key);
-  return key >> kFieldNumberShift;
+  if (!FieldKey::IsValidKey(key)) {
+    return 0;
+  }
+  return FieldKey(key).field_number();
 }
 
 Status Decoder::ReadUint32(uint32_t* out) {
@@ -113,16 +116,15 @@
 size_t Decoder::FieldSize() const {
   uint64_t key;
   size_t key_size = varint::Decode(proto_, &key);
-  if (key_size == 0) {
+  if (key_size == 0 || !FieldKey::IsValidKey(key)) {
     return 0;
   }
 
   std::span<const std::byte> remainder = proto_.subspan(key_size);
-  WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
   uint64_t value = 0;
   size_t expected_size = 0;
 
-  switch (wire_type) {
+  switch (FieldKey(key).wire_type()) {
     case WireType::kVarint:
       expected_size = varint::Decode(remainder, &value);
       if (expected_size == 0) {
@@ -162,8 +164,11 @@
     return Status::FailedPrecondition();
   }
 
-  WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
-  if (wire_type != expected_type) {
+  if (!FieldKey::IsValidKey(key)) {
+    return Status::DataLoss();
+  }
+
+  if (FieldKey(key).wire_type() != expected_type) {
     return Status::FailedPrecondition();
   }
 
diff --git a/pw_protobuf/decoder_test.cc b/pw_protobuf/decoder_test.cc
index 90d644a..d006312 100644
--- a/pw_protobuf/decoder_test.cc
+++ b/pw_protobuf/decoder_test.cc
@@ -162,6 +162,32 @@
   EXPECT_EQ(decoder.Next(), Status::OutOfRange());
 }
 
+TEST(Decoder, Decode_BadFieldNumber) {
+  // clang-format off
+  constexpr uint8_t encoded_proto[] = {
+    // type=int32, k=1, v=42
+    0x08, 0x2a,
+    // type=int32, k=19001, v=42 (invalid field number)
+    0xc8, 0xa3, 0x09, 0x2a,
+    // type=bool, k=3, v=false
+    0x18, 0x00,
+  };
+  // clang-format on
+
+  Decoder decoder(std::as_bytes(std::span(encoded_proto)));
+  int32_t value;
+
+  EXPECT_EQ(decoder.Next(), OkStatus());
+  EXPECT_EQ(decoder.FieldNumber(), 1u);
+  ASSERT_EQ(decoder.ReadInt32(&value), OkStatus());
+  EXPECT_EQ(value, 42);
+
+  // Bad field.
+  EXPECT_EQ(decoder.Next(), Status::DataLoss());
+  EXPECT_EQ(decoder.FieldNumber(), 0u);
+  EXPECT_EQ(decoder.ReadInt32(&value), Status::DataLoss());
+}
+
 TEST(CallbackDecoder, Decode) {
   CallbackDecoder decoder;
   TestDecodeHandler handler;
diff --git a/pw_protobuf/encoder.cc b/pw_protobuf/encoder.cc
index b427e92..93cc777 100644
--- a/pw_protobuf/encoder.cc
+++ b/pw_protobuf/encoder.cc
@@ -31,12 +31,14 @@
 
 StreamEncoder StreamEncoder::GetNestedEncoder(uint32_t field_number) {
   PW_CHECK(!nested_encoder_open());
+  PW_CHECK(ValidFieldNumber(field_number));
+
   nested_field_number_ = field_number;
 
   // Pass the unused space of the scratch buffer to the nested encoder to use
   // as their scratch buffer.
   size_t key_size =
-      varint::EncodedSize(MakeKey(field_number, WireType::kDelimited));
+      varint::EncodedSize(FieldKey(field_number, WireType::kDelimited));
   size_t reserved_size = key_size + config::kMaxVarintSize;
   size_t max_size = std::min(memory_writer_.ConservativeWriteLimit(),
                              writer_.ConservativeWriteLimit());
@@ -109,7 +111,7 @@
   PW_TRY(UpdateStatusForWrite(
       field_number, WireType::kVarint, varint::EncodedSize(value)));
 
-  WriteVarint(MakeKey(field_number, WireType::kVarint))
+  WriteVarint(FieldKey(field_number, WireType::kVarint))
       .IgnoreError();  // TODO(pwbug/387): Handle Status properly
   return WriteVarint(value);
 }
@@ -117,7 +119,7 @@
 Status StreamEncoder::WriteLengthDelimitedField(uint32_t field_number,
                                                 ConstByteSpan data) {
   PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, data.size()));
-  WriteVarint(MakeKey(field_number, WireType::kDelimited))
+  WriteVarint(FieldKey(field_number, WireType::kDelimited))
       .IgnoreError();  // TODO(pwbug/387): Handle Status properly
   WriteVarint(data.size_bytes())
       .IgnoreError();  // TODO(pwbug/387): Handle Status properly
@@ -133,15 +135,15 @@
     size_t num_bytes,
     ByteSpan stream_pipe_buffer) {
   PW_CHECK_UINT_GT(
-      stream_pipe_buffer.size(), 0, "transfer buffer cannot be 0 size");
+      stream_pipe_buffer.size(), 0, "Transfer buffer cannot be 0 size");
   PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, num_bytes));
   // Ignore the error until we explicitly check status_ below to minimize
   // the number of branches.
-  WriteVarint(MakeKey(field_number, WireType::kDelimited)).IgnoreError();
+  WriteVarint(FieldKey(field_number, WireType::kDelimited)).IgnoreError();
   WriteVarint(num_bytes).IgnoreError();
   PW_TRY(status_);
 
-  // Stream data from `bytes_reader` to `writer_`
+  // Stream data from `bytes_reader` to `writer_`.
   // TODO(pwbug/468): move the following logic to pw_stream/copy.h at a later
   // time.
   for (size_t bytes_written = 0; bytes_written < num_bytes;) {
@@ -167,7 +169,7 @@
 
   PW_TRY(UpdateStatusForWrite(field_number, type, data.size()));
 
-  WriteVarint(MakeKey(field_number, type))
+  WriteVarint(FieldKey(field_number, type))
       .IgnoreError();  // TODO(pwbug/387): Handle Status properly
   if (Status status = writer_.Write(data); !status.ok()) {
     status_ = status;
@@ -209,7 +211,7 @@
 
   PW_TRY(UpdateStatusForWrite(
       field_number, WireType::kDelimited, values.size_bytes()));
-  WriteVarint(MakeKey(field_number, WireType::kDelimited))
+  WriteVarint(FieldKey(field_number, WireType::kDelimited))
       .IgnoreError();  // TODO(pwbug/387): Handle Status properly
   WriteVarint(values.size_bytes())
       .IgnoreError();  // TODO(pwbug/387): Handle Status properly
@@ -242,7 +244,7 @@
     return status_;
   }
 
-  size_t size = varint::EncodedSize(MakeKey(field_number, type));
+  size_t size = varint::EncodedSize(FieldKey(field_number, type));
   if (type == WireType::kDelimited) {
     size += varint::EncodedSize(data_size);
   }
diff --git a/pw_protobuf/encoder_test.cc b/pw_protobuf/encoder_test.cc
index be93bb3..1ab005a 100644
--- a/pw_protobuf/encoder_test.cc
+++ b/pw_protobuf/encoder_test.cc
@@ -392,7 +392,7 @@
   ASSERT_EQ(parent.status(), OkStatus());
   const size_t kExpectedSize =
       varint::EncodedSize(
-          MakeKey(kTestProtoNestedField, WireType::kDelimited)) +
+          FieldKey(kTestProtoNestedField, WireType::kDelimited)) +
       varint::EncodedSize(0);
   ASSERT_EQ(parent.size(), kExpectedSize);
 }
diff --git a/pw_protobuf/public/pw_protobuf/decoder.h b/pw_protobuf/public/pw_protobuf/decoder.h
index baa076c..750ba94 100644
--- a/pw_protobuf/public/pw_protobuf/decoder.h
+++ b/pw_protobuf/public/pw_protobuf/decoder.h
@@ -64,6 +64,12 @@
   Status Next();
 
   // Returns the field number of the field at the current cursor position.
+  //
+  // A return value of 0 indicates that the field number is invalid. An invalid
+  // field number terminates the decode operation; any subsequent calls to
+  // Next() or Read*() will return DATA_LOSS.
+  //
+  // TODO(frolv): This should be refactored to return a Result<uint32_t>.
   uint32_t FieldNumber() const;
 
   // Reads a proto int32 value from the current cursor.
diff --git a/pw_protobuf/public/pw_protobuf/encoder.h b/pw_protobuf/public/pw_protobuf/encoder.h
index 5611773..72b2147 100644
--- a/pw_protobuf/public/pw_protobuf/encoder.h
+++ b/pw_protobuf/public/pw_protobuf/encoder.h
@@ -507,7 +507,7 @@
       return status_;
     }
 
-    WriteVarint(MakeKey(field_number, WireType::kDelimited))
+    WriteVarint(FieldKey(field_number, WireType::kDelimited))
         .IgnoreError();  // TODO(pwbug/387): Handle Status properly
     WriteVarint(payload_size)
         .IgnoreError();  // TODO(pwbug/387): Handle Status properly
diff --git a/pw_protobuf/public/pw_protobuf/serialized_size.h b/pw_protobuf/public/pw_protobuf/serialized_size.h
index d671683..dc51491 100644
--- a/pw_protobuf/public/pw_protobuf/serialized_size.h
+++ b/pw_protobuf/public/pw_protobuf/serialized_size.h
@@ -45,8 +45,9 @@
 inline constexpr size_t kMaxSizeOfLength = varint::kMaxVarint32SizeBytes;
 
 constexpr size_t SizeOfFieldKey(uint32_t field_number) {
-  // The wiretype is ignored as this does not impact the serialized size.
-  return varint::EncodedSize(field_number << kFieldNumberShift);
+  // The wiretype does not impact the serialized size, so use kVarint (0), which
+  // will be optimized out by the compiler.
+  return varint::EncodedSize(FieldKey(field_number, WireType::kVarint));
 }
 
 }  // namespace pw::protobuf
diff --git a/pw_protobuf/public/pw_protobuf/wire_format.h b/pw_protobuf/public/pw_protobuf/wire_format.h
index fa6071c..a222d7f 100644
--- a/pw_protobuf/public/pw_protobuf/wire_format.h
+++ b/pw_protobuf/public/pw_protobuf/wire_format.h
@@ -14,6 +14,9 @@
 #pragma once
 
 #include <cstdint>
+#include <limits>
+
+#include "pw_assert/assert.h"
 
 namespace pw::protobuf {
 
@@ -24,6 +27,19 @@
 constexpr static uint32_t kFirstReservedNumber = 19000;
 constexpr static uint32_t kLastReservedNumber = 19999;
 
+constexpr bool ValidFieldNumber(uint32_t field_number) {
+  return field_number != 0 && field_number <= kMaxFieldNumber &&
+         !(field_number >= kFirstReservedNumber &&
+           field_number <= kLastReservedNumber);
+}
+
+constexpr bool ValidFieldNumber(uint64_t field_number) {
+  if (field_number > std::numeric_limits<uint32_t>::max()) {
+    return false;
+  }
+  return ValidFieldNumber(static_cast<uint32_t>(field_number));
+}
+
 enum class WireType {
   kVarint = 0,
   kFixed64 = 1,
@@ -32,17 +48,51 @@
   kFixed32 = 5,
 };
 
-inline constexpr unsigned int kFieldNumberShift = 3u;
-inline constexpr unsigned int kWireTypeMask = (1u << kFieldNumberShift) - 1u;
+// Represents a protobuf field key, storing a field number and wire type.
+class FieldKey {
+ public:
+  // Checks if the given encoded protobuf key is valid. Must be called before
+  // instantiating a FieldKey object with it.
+  static constexpr bool IsValidKey(uint64_t key) {
+    uint64_t field_number = key >> kFieldNumberShift;
+    uint32_t wire_type = key & kWireTypeMask;
 
-constexpr uint32_t MakeKey(uint32_t field_number, WireType wire_type) {
-  return (field_number << kFieldNumberShift | static_cast<uint32_t>(wire_type));
-}
+    return ValidFieldNumber(field_number) && (wire_type <= 2 || wire_type == 5);
+  }
 
-constexpr bool ValidFieldNumber(uint32_t field_number) {
-  return field_number != 0 && field_number <= kMaxFieldNumber &&
-         !(field_number >= kFirstReservedNumber &&
-           field_number <= kLastReservedNumber);
+  // Creates a field key with the given field number and type.
+  //
+  // Precondition: The field number is valid.
+  constexpr FieldKey(uint32_t field_number, WireType wire_type)
+      : key_(field_number << kFieldNumberShift |
+             static_cast<uint32_t>(wire_type)) {
+    PW_DASSERT(ValidFieldNumber(field_number));
+  }
+
+  // Parses a field key from its encoded representation.
+  //
+  // Precondition: The field number is valid. Call IsValidKey(key) first.
+  constexpr FieldKey(uint32_t key) : key_(key) {
+    PW_DASSERT(ValidFieldNumber(field_number()));
+  }
+
+  constexpr operator uint32_t() { return key_; }
+
+  constexpr uint32_t field_number() const { return key_ >> kFieldNumberShift; }
+  constexpr WireType wire_type() const {
+    return static_cast<WireType>(key_ & kWireTypeMask);
+  }
+
+ private:
+  static constexpr unsigned int kFieldNumberShift = 3u;
+  static constexpr unsigned int kWireTypeMask = (1u << kFieldNumberShift) - 1u;
+
+  uint32_t key_;
+};
+
+[[deprecated("Use the FieldKey class")]] constexpr uint32_t MakeKey(
+    uint32_t field_number, WireType wire_type) {
+  return FieldKey(field_number, wire_type);
 }
 
 }  // namespace pw::protobuf
diff --git a/pw_rpc/packet_test.cc b/pw_rpc/packet_test.cc
index 6a79378..4a405b4 100644
--- a/pw_rpc/packet_test.cc
+++ b/pw_rpc/packet_test.cc
@@ -21,13 +21,14 @@
 namespace pw::rpc::internal {
 namespace {
 
+using protobuf::FieldKey;
 using std::byte;
 
 constexpr auto kPayload = bytes::Array<0x82, 0x02, 0xff, 0xff>();
 
 constexpr auto kEncoded = bytes::Array<
     // Payload
-    MakeKey(5, protobuf::WireType::kDelimited),
+    uint32_t(FieldKey(5, protobuf::WireType::kDelimited)),
     0x04,
     0x82,
     0x02,
@@ -35,29 +36,29 @@
     0xff,
 
     // Packet type
-    MakeKey(1, protobuf::WireType::kVarint),
+    uint32_t(FieldKey(1, protobuf::WireType::kVarint)),
     1,  // RESPONSE
 
     // Channel ID
-    MakeKey(2, protobuf::WireType::kVarint),
+    uint32_t(FieldKey(2, protobuf::WireType::kVarint)),
     1,
 
     // Service ID
-    MakeKey(3, protobuf::WireType::kFixed32),
+    uint32_t(FieldKey(3, protobuf::WireType::kFixed32)),
     42,
     0,
     0,
     0,
 
     // Method ID
-    MakeKey(4, protobuf::WireType::kFixed32),
+    uint32_t(FieldKey(4, protobuf::WireType::kFixed32)),
     100,
     0,
     0,
     0
 
     // Status (not encoded if it is zero)
-    // MakeKey(6, protobuf::WireType::kVarint),
+    // FieldKey(6, protobuf::WireType::kVarint),
     // 0x00
     >();