blob: 6670dcf0240f93e0f9ad7679beba6914b107c5b4 [file] [log] [blame]
// Copyright 2021 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
#include "pw_protobuf/stream_decoder.h"
#include "pw_assert/check.h"
#include "pw_status/status_with_size.h"
#include "pw_status/try.h"
#include "pw_varint/varint.h"
namespace pw::protobuf {
// Repositions the reader within the bytes field.
//
// `offset` is interpreted relative to `origin`, where "beginning" and "end"
// refer to the bounds of the bytes field (start_offset_ / end_offset_) rather
// than the bounds of the underlying proto stream.
//
// Returns:
//   - the reader's stored status, if it is not OK;
//   - INVALID_ARGUMENT if the computed absolute position is negative;
//   - OUT_OF_RANGE if the computed position lies outside the bytes field;
//   - otherwise, the result of seeking the underlying stream.
Status StreamDecoder::BytesReader::DoSeek(ssize_t offset, Whence origin) {
  PW_TRY(status_);

  // Start from a negative sentinel so that an unhandled `origin` value is
  // rejected as INVALID_ARGUMENT by the check below.
  ssize_t absolute_position = -1;

  // Convert from the position within the bytes field to the position within
  // the proto stream.
  switch (origin) {
    case Whence::kBeginning:
      absolute_position = start_offset_ + offset;
      break;

    case Whence::kCurrent:
      absolute_position = decoder_.reader_.Tell() + offset;
      break;

    case Whence::kEnd:
      absolute_position = end_offset_ + offset;
      break;
  }

  if (absolute_position < 0) {
    return Status::InvalidArgument();
  }

  // The end of the field is itself a valid position (e.g. Seek(0, kEnd)); only
  // positions strictly past it are out of range.
  const size_t target = static_cast<size_t>(absolute_position);
  if (target < start_offset_ || target > end_offset_) {
    return Status::OutOfRange();
  }

  return decoder_.reader_.Seek(absolute_position, Whence::kBeginning);
}
// Reads from the bytes field into `destination`, never reading beyond the end
// of the field.
//
// Returns the reader's stored status (with size 0) if it is not OK, the status
// of the underlying read (with size 0) if that fails, or the number of bytes
// read on success.
StatusWithSize StreamDecoder::BytesReader::DoRead(ByteSpan destination) {
  if (!status_.ok()) {
    return StatusWithSize(status_, 0);
  }

  // Clamp the output buffer to the bytes remaining in the field.
  const size_t remaining = end_offset_ - decoder_.reader_.Tell();
  ByteSpan bounded = destination;
  if (bounded.size() > remaining) {
    bounded = bounded.first(remaining);
  }

  pw::Result<ByteSpan> read = decoder_.reader_.Read(bounded);
  return read.ok() ? StatusWithSize(read.value().size())
                   : StatusWithSize(read.status(), 0);
}
// On destruction, a decoder that still has a parent notifies that parent so it
// can close the nested decoder.
StreamDecoder::~StreamDecoder() {
  if (parent_ == nullptr) {
    return;
  }
  parent_->CloseNestedDecoder(*this);
}
// Advances the decoder to the next field in the stream.
//
// Crashes if a nested reader is still open. Returns the stored status if the
// decoder is already in an error state, the status of skipping an unread
// field if that fails, OUT_OF_RANGE once the stream position has reached the
// upper bound of this decoder, and otherwise the result of reading the next
// field key (which is also stored as the decoder's status).
Status StreamDecoder::Next() {
  PW_CHECK(!nested_reader_open_,
           "Cannot use parent decoder while a nested one is open");

  PW_TRY(status_);

  // A field that was never read must be skipped before the next key.
  if (!field_consumed_) {
    PW_TRY(SkipField());
  }

  const bool at_end = reader_.Tell() >= stream_bounds_.high;
  if (at_end) {
    return Status::OutOfRange();
  }

  return status_ = ReadFieldKey();
}
// Reads the current varint field as an int32.
//
// Returns the error from ReadVarintField() if the field cannot be read, or
// OUT_OF_RANGE if the decoded value does not fit in an int32_t.
Result<int32_t> StreamDecoder::ReadInt32() {
  uint64_t raw = 0;
  PW_TRY(ReadVarintField(&raw));

  const int64_t value = static_cast<int64_t>(raw);
  const bool fits = value >= std::numeric_limits<int32_t>::min() &&
                    value <= std::numeric_limits<int32_t>::max();
  if (!fits) {
    return Status::OutOfRange();
  }

  return static_cast<int32_t>(value);
}
// Reads the current varint field as a uint32.
//
// Returns the error from ReadVarintField() if the field cannot be read, or
// OUT_OF_RANGE if the decoded value exceeds the uint32_t maximum.
Result<uint32_t> StreamDecoder::ReadUint32() {
  uint64_t raw = 0;
  PW_TRY(ReadVarintField(&raw));

  if (raw <= std::numeric_limits<uint32_t>::max()) {
    return static_cast<uint32_t>(raw);
  }
  return Status::OutOfRange();
}
// Reads the current varint field as an int64. The 64-bit varint is converted
// to its signed counterpart without any range check.
Result<int64_t> StreamDecoder::ReadInt64() {
  uint64_t raw = 0;
  PW_TRY(ReadVarintField(&raw));
  return static_cast<int64_t>(raw);
}
// Reads the current varint field as a zig-zag encoded sint32.
//
// Returns the error from ReadVarintField() if the field cannot be read, or
// OUT_OF_RANGE if the zig-zag decoded value does not fit in an int32_t.
Result<int32_t> StreamDecoder::ReadSint32() {
  uint64_t raw = 0;
  PW_TRY(ReadVarintField(&raw));

  const int64_t decoded = varint::ZigZagDecode(raw);
  const bool fits = decoded >= std::numeric_limits<int32_t>::min() &&
                    decoded <= std::numeric_limits<int32_t>::max();
  if (!fits) {
    return Status::OutOfRange();
  }

  return static_cast<int32_t>(decoded);
}
// Reads the current varint field as a zig-zag encoded sint64.
Result<int64_t> StreamDecoder::ReadSint64() {
  uint64_t raw = 0;
  PW_TRY(ReadVarintField(&raw));

  const int64_t decoded = varint::ZigZagDecode(raw);
  return decoded;
}
// Reads the current varint field as a bool: any non-zero varint is true.
Result<bool> StreamDecoder::ReadBool() {
  uint64_t raw = 0;
  PW_TRY(ReadVarintField(&raw));
  return raw != 0;
}
// Creates a BytesReader over the current length-delimited field.
//
// The decoder is always marked as having a nested reader open, even on
// failure. If the field cannot be read, or the stream reports fewer readable
// bytes than the field's size (DATA_LOSS), the returned reader is constructed
// with that error status instead of with field bounds.
StreamDecoder::BytesReader StreamDecoder::GetBytesReader() {
  Status status = CheckOkToRead(WireType::kDelimited);

  if (delimited_field_size_ > reader_.ConservativeReadLimit()) {
    status.Update(Status::DataLoss());
  }

  nested_reader_open_ = true;

  if (status.ok()) {
    const size_t begin = reader_.Tell();
    const size_t end = begin + delimited_field_size_;
    return BytesReader(*this, begin, end);
  }

  return BytesReader(*this, status);
}
// Creates a nested StreamDecoder over the current length-delimited field.
//
// The decoder is always marked as having a nested reader open, even on
// failure. If the field cannot be read, or the stream reports fewer readable
// bytes than the field's size (DATA_LOSS), the nested decoder is constructed
// with that error status instead of with field bounds.
StreamDecoder StreamDecoder::GetNestedDecoder() {
  Status status = CheckOkToRead(WireType::kDelimited);

  if (delimited_field_size_ > reader_.ConservativeReadLimit()) {
    status.Update(Status::DataLoss());
  }

  nested_reader_open_ = true;

  if (status.ok()) {
    const size_t begin = reader_.Tell();
    const size_t end = begin + delimited_field_size_;
    return StreamDecoder(reader_, this, begin, end);
  }

  return StreamDecoder(reader_, this, status);
}
// Called when a BytesReader over the current field is closed.
//
// Adopts the reader's status as the decoder's own. When that status is OK, the
// stream is moved to the end of the bytes field (crashing if the seek fails).
// In all cases the field is marked consumed and the decoder becomes usable
// again.
void StreamDecoder::CloseBytesReader(BytesReader& reader) {
  status_ = reader.status_;

  if (status_.ok()) {
    // Skip whatever portion of the bytes field was left unread.
    PW_CHECK(reader_.Seek(reader.end_offset_).ok());
  }

  nested_reader_open_ = false;
  field_consumed_ = true;
}
// Called when a nested decoder over the current field is closed.
//
// Crashes unless `nested` is a child of this decoder. The nested decoder is
// detached from its parent and flagged as having a nested reader open.
// NOTE(review): that flag presumably blocks further use of the nested decoder
// via the PW_CHECKs in Next() / CheckOkToRead() -- confirm intent.
//
// This decoder adopts the nested decoder's status. When that status is OK, the
// stream is moved to the end of the nested message (crashing if the seek
// fails). In all cases the field is marked consumed and this decoder becomes
// usable again.
void StreamDecoder::CloseNestedDecoder(StreamDecoder& nested) {
  PW_CHECK_PTR_EQ(nested.parent_, this);

  nested.nested_reader_open_ = true;
  nested.parent_ = nullptr;

  status_ = nested.status_;

  if (status_.ok()) {
    // Skip whatever portion of the nested message was left unread.
    PW_CHECK(reader_.Seek(nested.stream_bounds_.high).ok());
  }

  nested_reader_open_ = false;
  field_consumed_ = true;
}
// Reads the key of the next field from the stream and records it as the
// current field. For length-delimited fields, the length prefix is read as
// well, and the field's size and payload offset are recorded.
//
// Returns:
//   - the status of the key read, if it fails (OUT_OF_RANGE here is passed
//     through unchanged, as no value is required at this point);
//   - DATA_LOSS if the key is invalid, if the length prefix of a delimited
//     field is missing or malformed, or if that length exceeds 32 bits;
//   - OK otherwise, with the field marked as not yet consumed.
Status StreamDecoder::ReadFieldKey() {
  PW_DCHECK(field_consumed_);

  uint64_t varint = 0;
  if (StatusWithSize sws = ReadVarint(&varint); !sws.ok()) {
    return sws.status();
  }

  if (!FieldKey::IsValidKey(varint)) {
    return Status::DataLoss();
  }

  current_field_ = FieldKey(varint);

  if (current_field_.wire_type() == WireType::kDelimited) {
    // Read the length varint of length-delimited fields immediately to simplify
    // later processing of the field.
    StatusWithSize sws = ReadVarint(&varint);
    if (sws.IsOutOfRange()) {
      // A key was read, so a length must follow. Running out of data here
      // means the proto is truncated or corrupt; it must not be reported as
      // OUT_OF_RANGE, which Next() also returns at the end of the stream.
      return Status::DataLoss();
    }
    if (!sws.ok()) {
      return sws.status();
    }

    if (varint > std::numeric_limits<uint32_t>::max()) {
      return Status::DataLoss();
    }

    delimited_field_size_ = varint;
    delimited_field_offset_ = reader_.Tell();
  }

  field_consumed_ = false;
  return OkStatus();
}
// Returns the [low, high) stream offsets of the current length-delimited
// field's payload, or the error from CheckOkToRead() if the current field
// cannot be read as a delimited field.
Result<StreamDecoder::Bounds> StreamDecoder::GetLengthDelimitedPayloadBounds() {
  PW_TRY(CheckOkToRead(WireType::kDelimited));

  const auto payload_begin = delimited_field_offset_;
  const auto payload_end = delimited_field_size_ + delimited_field_offset_;
  return StreamDecoder::Bounds{payload_begin, payload_end};
}
// Consumes the current protobuf field, advancing the stream to the key of the
// next field (if one exists).
//
// Returns:
//   - DATA_LOSS (also stored as the decoder's status) if the stream does not
//     contain the complete field;
//   - the status of the underlying read or seek, if it fails otherwise;
//   - OK once the field has been skipped and marked consumed.
Status StreamDecoder::SkipField() {
  PW_DCHECK(!field_consumed_);

  size_t bytes_to_skip = 0;
  uint64_t value = 0;

  switch (current_field_.wire_type()) {
    case WireType::kVarint: {
      // Consume the varint field; nothing more to skip afterward.
      StatusWithSize sws = ReadVarint(&value);
      if (sws.IsOutOfRange()) {
        // The key promised a value, so running out of data means the proto is
        // truncated or corrupt. Report DATA_LOSS rather than OUT_OF_RANGE,
        // which Next() also returns at the end of the stream.
        status_ = Status::DataLoss();
        return status_;
      }
      PW_TRY(sws);
      break;
    }

    case WireType::kDelimited:
      bytes_to_skip = delimited_field_size_;
      break;

    case WireType::kFixed32:
      bytes_to_skip = sizeof(uint32_t);
      break;

    case WireType::kFixed64:
      bytes_to_skip = sizeof(uint64_t);
      break;
  }

  if (bytes_to_skip > 0) {
    // Check if the stream has the field available. If not, report it as a
    // DATA_LOSS since the proto is invalid (as opposed to OUT_OF_BOUNDS if we
    // just tried to seek beyond the end).
    if (reader_.ConservativeReadLimit() < bytes_to_skip) {
      status_ = Status::DataLoss();
      return status_;
    }

    PW_TRY(reader_.Seek(bytes_to_skip, stream::Stream::kCurrent));
  }

  field_consumed_ = true;
  return OkStatus();
}
// Reads the value of the current varint field into `*out`.
//
// Returns the error from CheckOkToRead() if the current field is not a
// readable varint field, DATA_LOSS (also stored as the decoder's status) if
// the varint read reports OUT_OF_RANGE, or any other read error unchanged.
// `*out` is only written, and the field only marked consumed, on success.
Status StreamDecoder::ReadVarintField(uint64_t* out) {
  PW_TRY(CheckOkToRead(WireType::kVarint));

  uint64_t decoded = 0;
  StatusWithSize read = ReadVarint(&decoded);

  if (read.IsOutOfRange()) {
    // Out of range indicates the end of the stream. As a value is expected
    // here, report it as a data loss and terminate the decode operation.
    status_ = Status::DataLoss();
    return status_;
  }

  if (!read.ok()) {
    return read.status();
  }

  *out = decoded;
  field_consumed_ = true;
  return OkStatus();
}
// Reads the current fixed-width field into `out`.
//
// The expected wire type is derived from the size of `out`: a 4-byte buffer
// reads a fixed32 field, and any other size is treated as fixed64. Returns the
// error from CheckOkToRead() if the field cannot be read, DATA_LOSS (also
// stored as the decoder's status) if the stream reports fewer readable bytes
// than `out` holds, or the status of the underlying read if it fails.
Status StreamDecoder::ReadFixedField(std::span<std::byte> out) {
  const bool is_32_bit = out.size() == sizeof(uint32_t);
  const WireType expected_wire_type =
      is_32_bit ? WireType::kFixed32 : WireType::kFixed64;
  PW_TRY(CheckOkToRead(expected_wire_type));

  if (out.size() > reader_.ConservativeReadLimit()) {
    status_ = Status::DataLoss();
    return status_;
  }

  PW_TRY(reader_.Read(out));

  field_consumed_ = true;
  return OkStatus();
}
// Reads the payload of the current length-delimited field into `out`.
//
// Exactly `delimited_field_size_` bytes are requested from the stream, even
// when `out` is larger, so that the read never runs into the fields that
// follow.
//
// Returns:
//   - the error from CheckOkToRead() if the current field is not a readable
//     delimited field;
//   - DATA_LOSS (also stored as the decoder's status) if the stream reports
//     fewer readable bytes than the field's size;
//   - RESOURCE_EXHAUSTED if `out` is smaller than the field, leaving the field
//     unconsumed;
//   - the status of the underlying read, if it fails;
//   - otherwise, OK with the number of bytes read.
StatusWithSize StreamDecoder::ReadDelimitedField(std::span<std::byte> out) {
  if (Status status = CheckOkToRead(WireType::kDelimited); !status.ok()) {
    return StatusWithSize(status, 0);
  }

  if (reader_.ConservativeReadLimit() < delimited_field_size_) {
    status_ = Status::DataLoss();
    return StatusWithSize(status_, 0);
  }

  if (out.size() < delimited_field_size_) {
    // Value can't fit into the provided buffer. Don't advance the cursor so
    // that the field can be re-read with a larger buffer or through the stream
    // API.
    return StatusWithSize::ResourceExhausted();
  }

  // Bound the read to the field itself; `out` may be larger than the payload.
  Result<ByteSpan> result = reader_.Read(out.first(delimited_field_size_));
  if (!result.ok()) {
    return StatusWithSize(result.status(), 0);
  }

  field_consumed_ = true;
  return StatusWithSize(result.value().size());
}
// Decodes a varint from the stream one byte at a time.
//
// Each byte contributes its low seven bits, least-significant group first; a
// clear high bit marks the final byte. On success, writes the value to
// `*output` and returns the number of bytes consumed. Returns the status of
// the underlying read if a byte cannot be read, or OUT_OF_RANGE if
// varint::kMaxVarint64SizeBytes bytes were read without reaching a final byte.
// `*output` is untouched on failure.
StatusWithSize StreamDecoder::ReadVarint(uint64_t* output) {
  uint64_t accumulated = 0;

  for (size_t index = 0; index < varint::kMaxVarint64SizeBytes; ++index) {
    std::byte current;
    if (auto read = reader_.Read(std::span(&current, 1)); !read.ok()) {
      return StatusWithSize(read.status(), 0);
    }

    accumulated |= static_cast<uint64_t>(current & std::byte(0x7f))
                   << (7 * index);

    const bool is_last_byte = (current & std::byte(0x80)) == std::byte(0);
    if (is_last_byte) {
      *output = accumulated;
      return StatusWithSize(index + 1);
    }
  }

  // Varint can't fit a uint64_t.
  return StatusWithSize::OutOfRange();
}
// Verifies that the current field may be read as wire type `type`.
//
// Crashes if a nested reader is open or if the current field has already been
// consumed. If the current field's wire type differs from `type`, the
// decoder's status is set to NOT_FOUND. Returns the decoder's status.
Status StreamDecoder::CheckOkToRead(WireType type) {
  PW_CHECK(!nested_reader_open_,
           "Cannot read from a decoder while a nested decoder is open");
  PW_CHECK(
      !field_consumed_,
      "Attempting to read from protobuf decoder without first calling Next()");

  // A wire type mismatch is usually a programmer error, but corrupt data can
  // cause it too. Crashing on bad data is undesirable, so it is reported as
  // NOT_FOUND, which sets it apart from the other corruption cases.
  const bool wire_type_matches = current_field_.wire_type() == type;
  if (!wire_type_matches) {
    status_ = Status::NotFound();
  }

  return status_;
}
} // namespace pw::protobuf