Add back `SerializedMessage{Reader,Rewriter}::OnLengthUnchecked()`.
It turns out to be significantly faster after all: the `ScopedLimiter` overhead
is too high.
This time it checks the stored length against the remaining source length, to
make processing untrusted input safer. In `OnLengthDelimited()` the analogous
check is performed by `ScopedLimiter` with the `exact` parameter.
Simplify `ScopedLimiter` constructor: use `LimitingReader::max_length()` and
`LimitingReader::set_max_length()` instead of converting the length to the
position first (which may overflow) and then working in terms of the position.
This delegates an overflow check to `set_max_length()`.
Tweak `Limiting{Reader,Writer,BackwardWriter}` failure messages to report
lengths instead of positions when appropriate.
PiperOrigin-RevId: 811779742
diff --git a/riegeli/bytes/limiting_backward_writer.cc b/riegeli/bytes/limiting_backward_writer.cc
index 35d1e82..20a33e7 100644
--- a/riegeli/bytes/limiting_backward_writer.cc
+++ b/riegeli/bytes/limiting_backward_writer.cc
@@ -85,7 +85,8 @@
// Do not call `Fail()` because `AnnotateStatusImpl()` synchronizes the
// buffer again.
FailWithoutAnnotation(dest.AnnotateStatus(absl::InvalidArgumentError(
- absl::StrCat("Not enough data: expected ", max_pos_))));
+ absl::StrCat("Not enough data: expected ", max_pos(), " or ",
+ max_length(), " more"))));
}
BackwardWriter::Done();
}
@@ -106,9 +107,9 @@
inline void LimitingBackwardWriterBase::FailLengthOverflow(
Position max_length) {
- Fail(absl::InvalidArgumentError(
- absl::StrCat("Not enough data: expected ", pos(), " + ", max_length,
- " which overflows the BackwardWriter position")));
+ Fail(absl::InvalidArgumentError(absl::StrCat(
+ "Not enough data: expected ", max_length,
+ " more, which overflows the BackwardWriter position range")));
}
absl::Status LimitingBackwardWriterBase::AnnotateStatusImpl(
diff --git a/riegeli/bytes/limiting_reader.cc b/riegeli/bytes/limiting_reader.cc
index 233ea2d..2feba7d 100644
--- a/riegeli/bytes/limiting_reader.cc
+++ b/riegeli/bytes/limiting_reader.cc
@@ -103,19 +103,35 @@
inline bool LimitingReaderBase::FailNotEnough() {
return Fail(absl::InvalidArgumentError(
- absl::StrCat("Not enough data: expected at least ", max_pos_)));
+ max_pos() == std::numeric_limits<Position>::max()
+ ? "Not enough data: expected impossibly much"
+ : absl::StrCat("Not enough data: expected at least ", max_pos(),
+ " or ", max_length(), " more")));
}
-inline void LimitingReaderBase::FailNotEnoughEarly(Position expected) {
+inline void LimitingReaderBase::FailNotEnoughAtPos(Position expected_pos) {
Fail(absl::InvalidArgumentError(
- absl::StrCat("Not enough data: expected at least ", expected,
- ", will have at most ", max_pos_)));
+ absl::StrCat("Not enough data: expected at least ", expected_pos,
+ ", will have at most ", max_pos())));
+}
+
+inline void LimitingReaderBase::FailNotEnoughAtLength(
+ Position expected_length) {
+ Fail(absl::InvalidArgumentError(
+ absl::StrCat("Not enough data: expected at least ", expected_length,
+ " more, will have at most ", max_length(), " more")));
+}
+
+inline void LimitingReaderBase::FailNotEnoughAtEnd() {
+ Fail(absl::InvalidArgumentError(absl::StrCat(
+ "Not enough data: expected impossibly much, will have at most ",
+ max_pos())));
}
inline void LimitingReaderBase::FailLengthOverflow(Position max_length) {
Fail(absl::InvalidArgumentError(
- absl::StrCat("Not enough data: expected at least ", pos(), " + ",
- max_length, " which overflows the Reader position")));
+ absl::StrCat("Not enough data: expected at least ", max_length,
+ " more, which overflows the Reader position range")));
}
inline void LimitingReaderBase::FailPositionLimitExceeded() {
@@ -376,27 +392,21 @@
fail_if_longer_(options.fail_if_longer()) {
if (options.max_pos() != std::nullopt) {
if (ABSL_PREDICT_FALSE(*options.max_pos() > reader_->max_pos())) {
- if (options.exact()) reader_->FailNotEnoughEarly(*options.max_pos());
+ if (options.exact()) reader_->FailNotEnoughAtPos(*options.max_pos());
} else {
reader_->set_max_pos(*options.max_pos());
}
} else if (options.max_length() != std::nullopt) {
- if (ABSL_PREDICT_FALSE(*options.max_length() >
- std::numeric_limits<Position>::max() -
- reader_->pos())) {
- if (options.exact()) reader_->FailLengthOverflow(*options.max_length());
+ if (ABSL_PREDICT_FALSE(*options.max_length() > reader_->max_length())) {
+ if (options.exact())
+ reader_->FailNotEnoughAtLength(*options.max_length());
} else {
- const Position max_pos = reader_->pos() + *options.max_length();
- if (ABSL_PREDICT_FALSE(max_pos > reader_->max_pos())) {
- if (options.exact()) reader_->FailNotEnoughEarly(max_pos);
- } else {
- reader_->set_max_pos(max_pos);
- }
+ reader_->set_max_length(*options.max_length());
}
} else if (ABSL_PREDICT_FALSE(reader_->max_pos() <
std::numeric_limits<Position>::max() &&
options.exact())) {
- reader_->FailNotEnoughEarly(std::numeric_limits<Position>::max());
+ reader_->FailNotEnoughAtEnd();
}
reader_->exact_ |= options.exact();
}
diff --git a/riegeli/bytes/limiting_reader.h b/riegeli/bytes/limiting_reader.h
index b0d8974..436b365 100644
--- a/riegeli/bytes/limiting_reader.h
+++ b/riegeli/bytes/limiting_reader.h
@@ -254,13 +254,16 @@
std::unique_ptr<Reader> NewReaderImpl(Position initial_pos) override;
private:
- // For `FailNotEnoughEarly()`, `FailLengthOverflow()`, and
+ // For `FailNotEnoughAtPos()`, `FailNotEnoughAtLength()`,
+ // `FailNotEnoughAtEnd()`, `FailLengthOverflow()`, and
// `FailPositionLimitExceeded()`.
friend class ScopedLimiter;
bool CheckEnough();
ABSL_ATTRIBUTE_COLD bool FailNotEnough();
- ABSL_ATTRIBUTE_COLD void FailNotEnoughEarly(Position expected);
+ ABSL_ATTRIBUTE_COLD void FailNotEnoughAtPos(Position expected_pos);
+ ABSL_ATTRIBUTE_COLD void FailNotEnoughAtLength(Position expected_length);
+ ABSL_ATTRIBUTE_COLD void FailNotEnoughAtEnd();
ABSL_ATTRIBUTE_COLD void FailLengthOverflow(Position max_length);
ABSL_ATTRIBUTE_COLD void FailPositionLimitExceeded();
diff --git a/riegeli/bytes/limiting_writer.cc b/riegeli/bytes/limiting_writer.cc
index eaa30a5..117e9f3 100644
--- a/riegeli/bytes/limiting_writer.cc
+++ b/riegeli/bytes/limiting_writer.cc
@@ -84,7 +84,8 @@
// Do not call `Fail()` because `AnnotateStatusImpl()` synchronizes the
// buffer again.
FailWithoutAnnotation(dest.AnnotateStatus(absl::InvalidArgumentError(
- absl::StrCat("Not enough data: expected ", max_pos_))));
+ absl::StrCat("Not enough data: expected ", max_pos(), " or ",
+ max_length(), " more"))));
}
Writer::Done();
}
@@ -105,8 +106,8 @@
inline void LimitingWriterBase::FailLengthOverflow(Position max_length) {
Fail(absl::InvalidArgumentError(
- absl::StrCat("Not enough data: expected ", pos(), " + ", max_length,
- " which overflows the Writer position")));
+ absl::StrCat("Not enough data: expected ", max_length,
+ " more, which overflows the Writer position range")));
}
absl::Status LimitingWriterBase::AnnotateStatusImpl(absl::Status status) {
diff --git a/riegeli/messages/BUILD b/riegeli/messages/BUILD
index 4b52bbf..451992f 100644
--- a/riegeli/messages/BUILD
+++ b/riegeli/messages/BUILD
@@ -174,7 +174,6 @@
":message_wire_format",
":parse_message",
"//riegeli/base:any",
- "//riegeli/base:arithmetic",
"//riegeli/base:assert",
"//riegeli/base:chain",
"//riegeli/base:global",
diff --git a/riegeli/messages/serialized_message_reader.cc b/riegeli/messages/serialized_message_reader.cc
index 8ae37c7..9dd64f9 100644
--- a/riegeli/messages/serialized_message_reader.cc
+++ b/riegeli/messages/serialized_message_reader.cc
@@ -32,7 +32,6 @@
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "riegeli/base/any.h"
-#include "riegeli/base/arithmetic.h"
#include "riegeli/base/assert.h"
#include "riegeli/base/chain.h"
#include "riegeli/base/type_erased_ref.h"
@@ -344,16 +343,15 @@
std::function<absl::Status(absl::string_view value, LimitingReaderBase& src,
TypeErasedRef context)>
action) {
- OnLengthDelimited(
- field_path, [action = std::move(action)](LimitingReaderBase& src,
- TypeErasedRef context) {
- absl::string_view value;
- if (ABSL_PREDICT_FALSE(
- !src.Read(IntCast<size_t>(src.max_length()), value))) {
- return ReadLengthDelimitedError(src);
- }
- return action(value, src, context);
- });
+ OnLengthUnchecked(field_path, [action = std::move(action)](
+ size_t length, LimitingReaderBase& src,
+ TypeErasedRef context) {
+ absl::string_view value;
+ if (ABSL_PREDICT_FALSE(!src.Read(length, value))) {
+ return ReadLengthDelimitedError(src);
+ }
+ return action(value, src, context);
+ });
}
void SerializedMessageReaderBase::OnString(
@@ -361,16 +359,15 @@
std::function<absl::Status(std::string&& value, LimitingReaderBase& src,
TypeErasedRef context)>
action) {
- OnLengthDelimited(
- field_path, [action = std::move(action)](LimitingReaderBase& src,
- TypeErasedRef context) {
- std::string value;
- if (ABSL_PREDICT_FALSE(
- !src.Read(IntCast<size_t>(src.max_length()), value))) {
- return ReadLengthDelimitedError(src);
- }
- return action(std::move(value), src, context);
- });
+ OnLengthUnchecked(field_path, [action = std::move(action)](
+ size_t length, LimitingReaderBase& src,
+ TypeErasedRef context) {
+ std::string value;
+ if (ABSL_PREDICT_FALSE(!src.Read(length, value))) {
+ return ReadLengthDelimitedError(src);
+ }
+ return action(std::move(value), src, context);
+ });
}
void SerializedMessageReaderBase::OnChain(
@@ -378,16 +375,15 @@
std::function<absl::Status(Chain&& value, LimitingReaderBase& src,
TypeErasedRef context)>
action) {
- OnLengthDelimited(
- field_path, [action = std::move(action)](LimitingReaderBase& src,
- TypeErasedRef context) {
- Chain value;
- if (ABSL_PREDICT_FALSE(
- !src.Read(IntCast<size_t>(src.max_length()), value))) {
- return ReadLengthDelimitedError(src);
- }
- return action(std::move(value), src, context);
- });
+ OnLengthUnchecked(field_path, [action = std::move(action)](
+ size_t length, LimitingReaderBase& src,
+ TypeErasedRef context) {
+ Chain value;
+ if (ABSL_PREDICT_FALSE(!src.Read(length, value))) {
+ return ReadLengthDelimitedError(src);
+ }
+ return action(std::move(value), src, context);
+ });
}
void SerializedMessageReaderBase::OnCord(
@@ -395,16 +391,15 @@
std::function<absl::Status(absl::Cord&& value, LimitingReaderBase& src,
TypeErasedRef context)>
action) {
- OnLengthDelimited(
- field_path, [action = std::move(action)](LimitingReaderBase& src,
- TypeErasedRef context) {
- absl::Cord value;
- if (ABSL_PREDICT_FALSE(
- !src.Read(IntCast<size_t>(src.max_length()), value))) {
- return ReadLengthDelimitedError(src);
- }
- return action(std::move(value), src, context);
- });
+ OnLengthUnchecked(field_path, [action = std::move(action)](
+ size_t length, LimitingReaderBase& src,
+ TypeErasedRef context) {
+ absl::Cord value;
+ if (ABSL_PREDICT_FALSE(!src.Read(length, value))) {
+ return ReadLengthDelimitedError(src);
+ }
+ return action(std::move(value), src, context);
+ });
}
void SerializedMessageReaderBase::OnLengthDelimited(
@@ -429,6 +424,31 @@
});
}
+void SerializedMessageReaderBase::OnLengthUnchecked(
+ absl::Span<const int> field_path,
+ std::function<absl::Status(size_t length, LimitingReaderBase& src,
+ TypeErasedRef context)>
+ action) {
+ SetAction(
+ field_path, WireType::kLengthDelimited,
+ [action = std::move(action)](LimitingReaderBase& src,
+ TypeErasedRef context) {
+ uint32_t length;
+ if (ABSL_PREDICT_FALSE(
+ !ReadVarint32(src, length) ||
+ length > uint32_t{std::numeric_limits<int32_t>::max()})) {
+ return src.StatusOrAnnotate(absl::InvalidArgumentError(
+ "Could not read a length-delimited field length"));
+ }
+ if (ABSL_PREDICT_FALSE(length > src.max_length())) {
+ return src.StatusOrAnnotate(absl::InvalidArgumentError(absl::StrCat(
+ "Not enough data: expected at least ", length,
+ " more, will have at most ", src.max_length(), " more")));
+ }
+ return action(size_t{length}, src, context);
+ });
+}
+
void SerializedMessageReaderBase::BeforeMessage(
absl::Span<const int> field_path,
std::function<absl::Status(LimitingReaderBase& src, TypeErasedRef context)>
@@ -502,12 +522,13 @@
if (src.IsOwning()) src->SetReadAllHint(true);
LimitingReaderBase::Options options;
if (src->SupportsSize()) {
- // Make `OnLengthDelimited()` safer for processing untrusted input.
- // It will fail early if the stored length exceeds the remaining size.
+ // Make `OnLengthDelimited()` and `OnLengthUnchecked()` safer for processing
+ // untrusted input. They will fail early if the stored length exceeds the
+ // remaining size.
const std::optional<Position> size = src->Size();
if (ABSL_PREDICT_TRUE(size != std::nullopt)) options.set_max_pos(*size);
}
- LimitingReader src_limiting(std::move(src), std::move(options));
+ LimitingReader src_limiting(std::move(src), options);
absl::Status status = ReadRootMessage(src_limiting, context);
if (ABSL_PREDICT_TRUE(status.ok())) src_limiting.VerifyEnd();
if (ABSL_PREDICT_FALSE(!src_limiting.Close())) {
diff --git a/riegeli/messages/serialized_message_reader.h b/riegeli/messages/serialized_message_reader.h
index a5a46b2..6e4be2a 100644
--- a/riegeli/messages/serialized_message_reader.h
+++ b/riegeli/messages/serialized_message_reader.h
@@ -118,6 +118,11 @@
std::function<absl::Status(LimitingReaderBase& src,
TypeErasedRef context)>
action);
+ void OnLengthUnchecked(
+ absl::Span<const int> field_path,
+ std::function<absl::Status(size_t length, LimitingReaderBase& src,
+ TypeErasedRef context)>
+ action);
void BeforeMessage(absl::Span<const int> field_path,
std::function<absl::Status(LimitingReaderBase& src,
TypeErasedRef context)>
@@ -226,7 +231,7 @@
// if the action is invocable with them):
// * parameters specific to the action type
// * `LimitingReaderBase& src` or `Reader& src`
-// (optional; required in `OnOther()`)
+// (optional; required in `OnLengthUnchecked()` and `OnOther()`)
// * `Context& context` (optional; always absent if `Context` is `void`)
//
// An action returns `absl::Status`, non-OK causing an early exit.
@@ -429,6 +434,24 @@
int> = 0>
void OnLengthDelimited(absl::Span<const int> field_path, Action action);
+ // Sets the action to be performed when encountering a length-delimited field
+ // identified by `field_path` of field numbers from the root through
+ // submessages.
+ //
+ // `action` is invoked with `length`, and `src` from which the value will be
+ // read. The first `length` bytes of `src` will contain the field contents.
+ // `action` must read exactly `length` bytes from `src`, unless it fails.
+ // This is unchecked.
+ //
+ // `OnLengthUnchecked()` is more efficient than `OnLengthDelimited()`.
+ //
+ // Precondition: `!field_path.empty()`
+ template <typename Action,
+ std::enable_if_t<serialized_message_internal::IsActionWithSrc<
+ Context, Action, size_t>::value,
+ int> = 0>
+ void OnLengthUnchecked(absl::Span<const int> field_path, Action action);
+
// Sets the action to be performed when encountering a submessage field
// identified by `field_path` of field numbers from the root through
// submessages. An empty `field_path` specified the root message.
@@ -879,12 +902,13 @@
inline void SerializedMessageReader<Context>::OnParsedMessage(
absl::Span<const int> field_path, Action action,
ParseMessageOptions parse_options) {
- SerializedMessageReaderBase::OnLengthDelimited(
- field_path, [action = std::move(action), parse_options](
- LimitingReaderBase& src, TypeErasedRef context) {
+ SerializedMessageReaderBase::OnLengthUnchecked(
+ field_path,
+ [action = std::move(action), parse_options](
+ size_t length, LimitingReaderBase& src, TypeErasedRef context) {
MessageType message;
- if (absl::Status status =
- riegeli::ParseMessage(src, message, parse_options);
+ if (absl::Status status = riegeli::ParseMessageWithLength(
+ src, length, message, parse_options);
ABSL_PREDICT_FALSE(!status.ok())) {
return status;
}
@@ -910,6 +934,22 @@
template <typename Context>
template <typename Action,
+ std::enable_if_t<serialized_message_internal::IsActionWithSrc<
+ Context, Action, size_t>::value,
+ int>>
+inline void SerializedMessageReader<Context>::OnLengthUnchecked(
+ absl::Span<const int> field_path, Action action) {
+ SerializedMessageReaderBase::OnLengthUnchecked(
+ field_path,
+ [action = std::move(action)](size_t length, LimitingReaderBase& src,
+ TypeErasedRef context) {
+ return serialized_message_internal::InvokeActionWithSrc<Context>(
+ src, context, action, length);
+ });
+}
+
+template <typename Context>
+template <typename Action,
std::enable_if_t<serialized_message_internal::IsActionWithOptionalSrc<
Context, Action>::value,
int>>
diff --git a/riegeli/messages/serialized_message_rewriter.h b/riegeli/messages/serialized_message_rewriter.h
index fb57a7f..12bd942 100644
--- a/riegeli/messages/serialized_message_rewriter.h
+++ b/riegeli/messages/serialized_message_rewriter.h
@@ -131,7 +131,8 @@
// Parameters of the actions are as follows (optional parameters are passed
// if the action is invocable with them):
// * parameters specific to the action type
-// * `LimitingReaderBase& src` or `Reader& src` (optional)
+// * `LimitingReaderBase& src` or `Reader& src`
+// (optional; required in `OnLengthUnchecked()`)
// * `SerializedMessageWriter& dest` (optional)
// * `Context& context` (optional; always absent if `Context` is `void`)
@@ -344,6 +345,29 @@
int> = 0>
void OnLengthDelimited(absl::Span<const int> field_path, Action action);
+ // Sets the action to be performed when encountering a length-delimited field
+ // identified by `field_path` of field numbers from the root through
+ // submessages.
+ //
+ // `action` is invoked with `length`, `src` from which the value will be read,
+ // and `dest` positioned between fields. The first `length` bytes of `src`
+ // will contain the field contents. `action` must read exactly `length` bytes
+ // from `src`, unless it fails. This is unchecked.
+ //
+ // The field will not be implicitly copied. `action` can write replacement
+ // fields to `dest`, or do nothing to remove the field.
+ //
+ // `OnLengthUnchecked()` is more efficient than `OnLengthDelimited()`.
+ //
+ // Precondition: `!field_path.empty()`
+ template <
+ typename Action,
+ std::enable_if_t<
+ serialized_message_internal::IsActionWithRequiredSrcAndOptionalDest<
+ Context, Action, size_t>::value,
+ int> = 0>
+ void OnLengthUnchecked(absl::Span<const int> field_path, Action action);
+
// Sets the action to be performed when encountering a submessage field
// identified by `field_path` of field numbers from the root through
// submessages. An empty `field_path` specified the root message.
@@ -916,12 +940,13 @@
inline void SerializedMessageRewriter<Context>::OnParsedMessage(
absl::Span<const int> field_path, Action action,
ParseMessageOptions parse_options) {
- message_reader_.OnLengthDelimited(
+ message_reader_.OnLengthUnchecked(
field_path, [action = std::move(action), parse_options](
- LimitingReaderBase& src, MessageReaderContext& context) {
+ size_t length, LimitingReaderBase& src,
+ MessageReaderContext& context) {
MessageType message;
- if (absl::Status status =
- riegeli::ParseMessage(src, message, parse_options);
+ if (absl::Status status = riegeli::ParseMessageWithLength(
+ src, length, message, parse_options);
ABSL_PREDICT_FALSE(!status.ok())) {
return status;
}
@@ -960,6 +985,29 @@
template <typename Context>
template <
typename Action,
+ std::enable_if_t<
+ serialized_message_internal::IsActionWithRequiredSrcAndOptionalDest<
+ Context, Action, size_t>::value,
+ int>>
+inline void SerializedMessageRewriter<Context>::OnLengthUnchecked(
+ absl::Span<const int> field_path, Action action) {
+ message_reader_.OnLengthUnchecked(
+ field_path,
+ [action = std::move(action)](size_t length, LimitingReaderBase& src,
+ MessageReaderContext& context) {
+ if (absl::Status status = context.CommitUnchanged(src, length);
+ ABSL_PREDICT_FALSE(!status.ok())) {
+ return status;
+ }
+ return serialized_message_internal::InvokeActionWithSrcAndDest<Context>(
+ src, context.message_writer(), context.message_rewriter_context(),
+ action, length);
+ });
+}
+
+template <typename Context>
+template <
+ typename Action,
std::enable_if_t<serialized_message_internal::
IsActionWithOptionalSrcAndDest<Context, Action>::value,
int>>