Add CBOR in-place bstr/tstr writing functions

When the size of data is known, it can be written into place without
needing to use a temporary buffer. Add functions to support this for
binary and text strings.

Change-Id: I633eb9c1d960f0193de2bcd129e36e387849be73
Reviewed-on: https://pigweed-review.googlesource.com/c/open-dice/+/46141
Commit-Queue: Andrew Scull <ascull@google.com>
Pigweed-Auto-Submit: Andrew Scull <ascull@google.com>
Reviewed-by: Darren Krahn <dkrahn@google.com>
diff --git a/include/dice/cbor_writer.h b/include/dice/cbor_writer.h
index 700d57f..304f9f3 100644
--- a/include/dice/cbor_writer.h
+++ b/include/dice/cbor_writer.h
@@ -69,6 +69,12 @@
 void CborWriteTrue(struct CborOut* out);
 void CborWriteNull(struct CborOut* out);
 
+// These functions write the type header and reserve space for the caller to
+// populate. The reserved space is left uninitialized. Returns NULL if space
+// could not be reserved in the output buffer.
+uint8_t* CborAllocBstr(size_t data_size, struct CborOut* out);
+char* CborAllocTstr(size_t size, struct CborOut* out);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/src/cbor_writer.c b/src/cbor_writer.c
index 9b29f76..986d1a1 100644
--- a/src/cbor_writer.c
+++ b/src/cbor_writer.c
@@ -88,17 +88,22 @@
   out->cursor += size;
 }
 
-static void CborWriteStr(enum CborType type, size_t data_size,
-                         const uint8_t* data, struct CborOut* out) {
+static void* CborAllocStr(enum CborType type, size_t data_size,
+                          struct CborOut* out) {
   CborWriteType(type, data_size, out);
-  if (CborWriteWouldOverflowCursor(data_size, out)) {
-    out->cursor = SIZE_MAX;
-    return;
+  bool overflow = CborWriteWouldOverflowCursor(data_size, out);
+  bool fit = CborWriteFitsInBuffer(data_size, out);
+  void* ptr = (overflow || !fit) ? NULL : &out->buffer[out->cursor];
+  out->cursor = overflow ? SIZE_MAX : out->cursor + data_size;
+  return ptr;
+}
+
+static void CborWriteStr(enum CborType type, size_t data_size, const void* data,
+                         struct CborOut* out) {
+  uint8_t* ptr = CborAllocStr(type, data_size, out);
+  if (ptr && data_size) {
+    memcpy(ptr, data, data_size);
   }
-  if (CborWriteFitsInBuffer(data_size, out) && data_size) {
-    memcpy(&out->buffer[out->cursor], data, data_size);
-  }
-  out->cursor += data_size;
 }
 
 void CborWriteInt(int64_t val, struct CborOut* out) {
@@ -113,8 +118,16 @@
   CborWriteStr(CBOR_TYPE_BSTR, data_size, data, out);
 }
 
+uint8_t* CborAllocBstr(size_t data_size, struct CborOut* out) {
+  return CborAllocStr(CBOR_TYPE_BSTR, data_size, out);
+}
+
 void CborWriteTstr(const char* str, struct CborOut* out) {
-  CborWriteStr(CBOR_TYPE_TSTR, strlen(str), (const uint8_t*)str, out);
+  CborWriteStr(CBOR_TYPE_TSTR, strlen(str), str, out);
+}
+
+char* CborAllocTstr(size_t size, struct CborOut* out) {
+  return CborAllocStr(CBOR_TYPE_TSTR, size, out);
 }
 
 void CborWriteArray(size_t num_elements, struct CborOut* out) {
diff --git a/src/cbor_writer_fuzzer.cc b/src/cbor_writer_fuzzer.cc
index bb4b3bd..afc4270 100644
--- a/src/cbor_writer_fuzzer.cc
+++ b/src/cbor_writer_fuzzer.cc
@@ -20,7 +20,9 @@
 enum CborWriterFunction {
   WriteInt,
   WriteBstr,
+  AllocBstr,
   WriteTstr,
+  AllocTstr,
   WriteArray,
   WriteMap,
   WriteFalse,
@@ -32,7 +34,7 @@
 // Use data sizes that exceed the 16-bit range without being excessive.
 constexpr size_t kMaxDataSize = 0xffff + 0x5000;
 constexpr size_t kMaxBufferSize = kMaxDataSize * 3;
-constexpr size_t kIterations = 16;
+constexpr size_t kIterations = 20;
 
 }  // namespace
 
@@ -56,6 +58,15 @@
         CborWriteBstr(bstr_data.size(), bstr_data.data(), &out);
         break;
       }
+      case AllocBstr: {
+        auto bstr_data_size =
+            fdp.ConsumeIntegralInRange<size_t>(0, kMaxDataSize);
+        uint8_t* ptr = CborAllocBstr(bstr_data_size, &out);
+        if (ptr) {
+          memset(ptr, 0x5a, bstr_data_size);
+        }
+        break;
+      }
       case WriteTstr: {
         auto tstr_data_size =
             fdp.ConsumeIntegralInRange<size_t>(0, kMaxDataSize);
@@ -63,6 +74,15 @@
         CborWriteTstr(str.c_str(), &out);
         break;
       }
+      case AllocTstr: {
+        auto tstr_data_size =
+            fdp.ConsumeIntegralInRange<size_t>(0, kMaxDataSize);
+        char* str = CborAllocTstr(tstr_data_size, &out);
+        if (str) {
+          memset(str, 'q', tstr_data_size);
+        }
+        break;
+      }
       case WriteArray: {
         auto num_elements = fdp.ConsumeIntegral<size_t>();
         CborWriteArray(num_elements, &out);
diff --git a/src/cbor_writer_test.cc b/src/cbor_writer_test.cc
index 1bcfd59..a92b1fc 100644
--- a/src/cbor_writer_test.cc
+++ b/src/cbor_writer_test.cc
@@ -117,6 +117,20 @@
   EXPECT_EQ(0, memcmp(buffer, kExpectedEncoding, sizeof(kExpectedEncoding)));
 }
 
+TEST(CborWriterTest, BstrAllocEncoding) {
+  const uint8_t kExpectedEncoding[] = {0x45, 'a', 'l', 'l', 'o', 'c'};
+  const uint8_t kData[] = {'a', 'l', 'l', 'o', 'c'};
+  uint8_t buffer[64];
+  uint8_t* ptr;
+  CborOut out;
+  CborOutInit(buffer, sizeof(buffer), &out);
+  ptr = CborAllocBstr(sizeof(kData), &out);
+  EXPECT_NE(nullptr, ptr);
+  EXPECT_FALSE(CborOutOverflowed(&out));
+  memcpy(ptr, kData, sizeof(kData));
+  EXPECT_EQ(0, memcmp(buffer, kExpectedEncoding, sizeof(kExpectedEncoding)));
+}
+
 TEST(CborWriterTest, TstrEncoding) {
   const uint8_t kExpectedEncoding[] = {0x65, 'w', 'o', 'r', 'l', 'd'};
   uint8_t buffer[64];
@@ -127,6 +141,20 @@
   EXPECT_EQ(0, memcmp(buffer, kExpectedEncoding, sizeof(kExpectedEncoding)));
 }
 
+TEST(CborWriterTest, TstrAllocEncoding) {
+  const uint8_t kExpectedEncoding[] = {0x65, 's', 'p', 'a', 'c', 'e'};
+  const char kStr[] = "space";
+  char* ptr;
+  uint8_t buffer[64];
+  CborOut out;
+  CborOutInit(buffer, sizeof(buffer), &out);
+  ptr = CborAllocTstr(strlen(kStr), &out);
+  EXPECT_NE(nullptr, ptr);
+  EXPECT_FALSE(CborOutOverflowed(&out));
+  memcpy(ptr, kStr, sizeof(kStr));
+  EXPECT_EQ(0, memcmp(buffer, kExpectedEncoding, sizeof(kExpectedEncoding)));
+}
+
 TEST(CborWriterTest, ArrayEncoding) {
   const uint8_t kExpectedEncoding[] = {0x98, 29};
   uint8_t buffer[64];
@@ -184,7 +212,9 @@
   CborOutInit(buffer, sizeof(buffer), &out);
   CborWriteInt(0xab34, &out);
   CborWriteBstr(sizeof(kData), kData, &out);
+  EXPECT_NE(nullptr, CborAllocBstr(7, &out));
   CborWriteTstr("A string", &out);
+  EXPECT_NE(nullptr, CborAllocTstr(6, &out));
   CborWriteArray(/*num_elements=*/16, &out);
   CborWriteMap(/*num_pairs=*/35, &out);
   CborWriteFalse(&out);
@@ -192,7 +222,7 @@
   CborWriteNull(&out);
   EXPECT_FALSE(CborOutOverflowed(&out));
   // Offset is the cumulative size.
-  EXPECT_EQ(3 + 6 + 9 + 1 + 2 + 1 + 1 + 1u, CborOutSize(&out));
+  EXPECT_EQ(3 + 6 + 8 + 9 + 7 + 1 + 2 + 1 + 1 + 1u, CborOutSize(&out));
 }
 
 TEST(CborWriterTest, NullBufferForMeasurement) {
@@ -204,13 +234,15 @@
   CborWriteFalse(&out);
   CborWriteMap(/*num_pairs=*/623, &out);
   CborWriteArray(/*num_elements=*/70000, &out);
+  EXPECT_EQ(nullptr, CborAllocTstr(8, &out));
   CborWriteTstr("length", &out);
+  EXPECT_EQ(nullptr, CborAllocBstr(1, &out));
   CborWriteBstr(sizeof(kData), kData, &out);
   CborWriteInt(-10002000, &out);
   // Measurement has occurred, but output did not.
   EXPECT_TRUE(CborOutOverflowed(&out));
   // Offset is the cumulative size.
-  EXPECT_EQ(1 + 1 + 1 + 3 + 5 + 7 + 8 + 5u, CborOutSize(&out));
+  EXPECT_EQ(1 + 1 + 1 + 3 + 5 + 9 + 7 + 2 + 8 + 5u, CborOutSize(&out));
 }
 
 TEST(CborWriterTest, BufferTooSmall) {
@@ -226,9 +258,15 @@
   CborWriteBstr(sizeof(kData), kData, &out);
   EXPECT_TRUE(CborOutOverflowed(&out));
   CborOutInit(buffer, sizeof(buffer), &out);
+  EXPECT_EQ(nullptr, CborAllocBstr(sizeof(kData), &out));
+  EXPECT_TRUE(CborOutOverflowed(&out));
+  CborOutInit(buffer, sizeof(buffer), &out);
   CborWriteTstr("Buffer too small", &out);
   EXPECT_TRUE(CborOutOverflowed(&out));
   CborOutInit(buffer, sizeof(buffer), &out);
+  EXPECT_EQ(nullptr, CborAllocTstr(16, &out));
+  EXPECT_TRUE(CborOutOverflowed(&out));
+  CborOutInit(buffer, sizeof(buffer), &out);
   CborWriteArray(/*num_elements=*/563, &out);
   EXPECT_TRUE(CborOutOverflowed(&out));
   CborOutInit(buffer, sizeof(buffer), &out);
@@ -260,10 +298,18 @@
   EXPECT_TRUE(CborOutOverflowed(&out));
   CborOutInit(buffer, sizeof(buffer), &out);
   CborWriteBstr(sizeof(buffer) - 3, zeros, &out);
+  EXPECT_EQ(nullptr, CborAllocBstr(sizeof(kData), &out));
+  EXPECT_TRUE(CborOutOverflowed(&out));
+  CborOutInit(buffer, sizeof(buffer), &out);
+  CborWriteBstr(sizeof(buffer) - 3, zeros, &out);
   CborWriteTstr("Won't fit", &out);
   EXPECT_TRUE(CborOutOverflowed(&out));
   CborOutInit(buffer, sizeof(buffer), &out);
   CborWriteBstr(sizeof(buffer) - 3, zeros, &out);
+  EXPECT_EQ(nullptr, CborAllocTstr(4, &out));
+  EXPECT_TRUE(CborOutOverflowed(&out));
+  CborOutInit(buffer, sizeof(buffer), &out);
+  CborWriteBstr(sizeof(buffer) - 3, zeros, &out);
   CborWriteArray(/*num_elements=*/352, &out);
   EXPECT_TRUE(CborOutOverflowed(&out));
   CborOutInit(buffer, sizeof(buffer), &out);
@@ -298,10 +344,18 @@
   EXPECT_TRUE(CborOutOverflowed(&out));
   CborOutInit(buffer, sizeof(buffer), &out);
   CborWriteBstr(SIZE_MAX - 10, nullptr, &out);
+  EXPECT_EQ(nullptr, CborAllocBstr(sizeof(kData), &out));
+  EXPECT_TRUE(CborOutOverflowed(&out));
+  CborOutInit(buffer, sizeof(buffer), &out);
+  CborWriteBstr(SIZE_MAX - 10, nullptr, &out);
   CborWriteTstr("Overflow", &out);
   EXPECT_TRUE(CborOutOverflowed(&out));
   CborOutInit(buffer, sizeof(buffer), &out);
   CborWriteBstr(SIZE_MAX - 10, nullptr, &out);
+  EXPECT_EQ(nullptr, CborAllocTstr(4, &out));
+  EXPECT_TRUE(CborOutOverflowed(&out));
+  CborOutInit(buffer, sizeof(buffer), &out);
+  CborWriteBstr(SIZE_MAX - 10, nullptr, &out);
   CborWriteArray(/*num_elements=*/41, &out);
   EXPECT_TRUE(CborOutOverflowed(&out));
   CborOutInit(buffer, sizeof(buffer), &out);
@@ -335,10 +389,18 @@
   EXPECT_TRUE(CborOutOverflowed(&out));
   CborOutInit(nullptr, 0, &out);
   CborWriteBstr(SIZE_MAX - 10, nullptr, &out);
+  EXPECT_EQ(nullptr, CborAllocBstr(sizeof(kData), &out));
+  EXPECT_TRUE(CborOutOverflowed(&out));
+  CborOutInit(nullptr, 0, &out);
+  CborWriteBstr(SIZE_MAX - 10, nullptr, &out);
   CborWriteTstr("Measured overflow", &out);
   EXPECT_TRUE(CborOutOverflowed(&out));
   CborOutInit(nullptr, 0, &out);
   CborWriteBstr(SIZE_MAX - 10, nullptr, &out);
+  EXPECT_EQ(nullptr, CborAllocTstr(6, &out));
+  EXPECT_TRUE(CborOutOverflowed(&out));
+  CborOutInit(nullptr, 0, &out);
+  CborWriteBstr(SIZE_MAX - 10, nullptr, &out);
   CborWriteArray(/*num_elements=*/8368257314, &out);
   EXPECT_TRUE(CborOutOverflowed(&out));
   CborOutInit(nullptr, 0, &out);