Merge pull request #8143 from JamesNK/jamesnk/messageadapter
Optimize MapField serialization by removing MessageAdapter
diff --git a/csharp/src/Google.Protobuf.Test/Collections/MapFieldTest.cs b/csharp/src/Google.Protobuf.Test/Collections/MapFieldTest.cs
index d8cdee0..d4c63dc 100644
--- a/csharp/src/Google.Protobuf.Test/Collections/MapFieldTest.cs
+++ b/csharp/src/Google.Protobuf.Test/Collections/MapFieldTest.cs
@@ -611,6 +611,32 @@
Assert.IsTrue(input.IsAtEnd);
}
+ [Test]
+ public void AddEntriesFrom_CodedInputStream_MissingKey()
+ {
+ // Encode a single map entry (string key, string value) that contains only the value field.
+ var keyTag = WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited);
+ var valueTag = WireFormat.MakeTag(2, WireFormat.WireType.LengthDelimited);
+
+ var memoryStream = new MemoryStream();
+ var output = new CodedOutputStream(memoryStream);
+ output.WriteLength(11); // entry length: valueTag (1) + string length (1) + "the_value" (9)
+ output.WriteTag(valueTag);
+ output.WriteString("the_value");
+ output.Flush();
+
+ Assert.AreEqual(12L, memoryStream.Length); // sanity check: entry length prefix (1) + entry body (11)
+
+ var field = new MapField<string, string>();
+ var mapCodec = new MapField<string, string>.Codec(FieldCodec.ForString(keyTag, ""), FieldCodec.ForString(valueTag, ""), 10);
+ var input = new CodedInputStream(memoryStream.ToArray());
+
+ field.AddEntriesFrom(input, mapCodec);
+ CollectionAssert.AreEquivalent(new[] { "" }, field.Keys); // missing key falls back to the key codec's default ("")
+ CollectionAssert.AreEquivalent(new[] { "the_value" }, field.Values);
+ Assert.IsTrue(input.IsAtEnd);
+ }
+
#if !NET35
[Test]
public void IDictionaryKeys_Equals_IReadOnlyDictionaryKeys()
diff --git a/csharp/src/Google.Protobuf/Collections/MapField.cs b/csharp/src/Google.Protobuf/Collections/MapField.cs
index d60ebc5..6b7d0f1 100644
--- a/csharp/src/Google.Protobuf/Collections/MapField.cs
+++ b/csharp/src/Google.Protobuf/Collections/MapField.cs
@@ -448,12 +448,10 @@
[SecuritySafeCritical]
public void AddEntriesFrom(ref ParseContext ctx, Codec codec)
{
- var adapter = new Codec.MessageAdapter(codec);
do
{
- adapter.Reset();
- ctx.ReadMessage(adapter);
- this[adapter.Key] = adapter.Value;
+ KeyValuePair<TKey, TValue> entry = ParsingPrimitivesMessages.ReadMapEntry(ref ctx, codec); // reads one entry: its length prefix, then its key/value fields
+ this[entry.Key] = entry.Value;
} while (ParsingPrimitives.MaybeConsumeTag(ref ctx.buffer, ref ctx.state, codec.MapTag));
}
@@ -485,13 +483,13 @@
[SecuritySafeCritical]
public void WriteTo(ref WriteContext ctx, Codec codec)
{
- var message = new Codec.MessageAdapter(codec);
foreach (var entry in list)
{
- message.Key = entry.Key;
- message.Value = entry.Value;
ctx.WriteTag(codec.MapTag);
- ctx.WriteMessage(message);
+ // Each entry body is length-delimited: write its size first, then the key and value fields.
+ WritingPrimitives.WriteLength(ref ctx.buffer, ref ctx.state, CalculateEntrySize(codec, entry));
+ codec.KeyCodec.WriteTagAndValue(ref ctx, entry.Key);
+ codec.ValueCodec.WriteTagAndValue(ref ctx, entry.Value);
}
}
@@ -506,18 +504,22 @@
{
return 0;
}
- var message = new Codec.MessageAdapter(codec);
int size = 0;
foreach (var entry in list)
{
- message.Key = entry.Key;
- message.Value = entry.Value;
+ int entrySize = CalculateEntrySize(codec, entry);
+
size += CodedOutputStream.ComputeRawVarint32Size(codec.MapTag);
- size += CodedOutputStream.ComputeMessageSize(message);
+ size += CodedOutputStream.ComputeLengthSize(entrySize) + entrySize;
}
return size;
}
+ private static int CalculateEntrySize(Codec codec, KeyValuePair<TKey, TValue> entry)
+ {
+ return codec.KeyCodec.CalculateSizeWithTag(entry.Key) + codec.ValueCodec.CalculateSizeWithTag(entry.Value); // entry body only (key and value fields with their tags); excludes the map tag and length prefix
+ }
+
/// <summary>
/// Returns a string representation of this repeated field, in the same
/// way as it would be represented by the default JSON formatter.
@@ -655,100 +657,19 @@
}
/// <summary>
- /// The tag used in the enclosing message to indicate map entries.
+ /// The codec for the key field of each map entry; its Tag identifies the key within an entry.
/// </summary>
- internal uint MapTag { get { return mapTag; } }
+ internal FieldCodec<TKey> KeyCodec => keyCodec;
/// <summary>
- /// A mutable message class, used for parsing and serializing. This
- /// delegates the work to a codec, but implements the <see cref="IMessage"/> interface
- /// for interop with <see cref="CodedInputStream"/> and <see cref="CodedOutputStream"/>.
- /// This is nested inside Codec as it's tightly coupled to the associated codec,
- /// and it's simpler if it has direct access to all its fields.
+ /// The codec for the value field of each map entry; its Tag identifies the value within an entry.
/// </summary>
- internal class MessageAdapter : IMessage, IBufferMessage
- {
- private static readonly byte[] ZeroLengthMessageStreamData = new byte[] { 0 };
+ internal FieldCodec<TValue> ValueCodec => valueCodec;
- private readonly Codec codec;
- internal TKey Key { get; set; }
- internal TValue Value { get; set; }
-
- internal MessageAdapter(Codec codec)
- {
- this.codec = codec;
- }
-
- internal void Reset()
- {
- Key = codec.keyCodec.DefaultValue;
- Value = codec.valueCodec.DefaultValue;
- }
-
- public void MergeFrom(CodedInputStream input)
- {
- // Message adapter is an internal class and we know that all the parsing will happen via InternalMergeFrom.
- throw new NotImplementedException();
- }
-
- [SecuritySafeCritical]
- public void InternalMergeFrom(ref ParseContext ctx)
- {
- uint tag;
- while ((tag = ctx.ReadTag()) != 0)
- {
- if (tag == codec.keyCodec.Tag)
- {
- Key = codec.keyCodec.Read(ref ctx);
- }
- else if (tag == codec.valueCodec.Tag)
- {
- Value = codec.valueCodec.Read(ref ctx);
- }
- else
- {
- ParsingPrimitivesMessages.SkipLastField(ref ctx.buffer, ref ctx.state);
- }
- }
-
- // Corner case: a map entry with a key but no value, where the value type is a message.
- // Read it as if we'd seen input with no data (i.e. create a "default" message).
- if (Value == null)
- {
- if (ctx.state.CodedInputStream != null)
- {
- // the decoded message might not support parsing from ParseContext, so
- // we need to allow fallback to the legacy MergeFrom(CodedInputStream) parsing.
- Value = codec.valueCodec.Read(new CodedInputStream(ZeroLengthMessageStreamData));
- }
- else
- {
- ParseContext.Initialize(new ReadOnlySequence<byte>(ZeroLengthMessageStreamData), out ParseContext zeroLengthCtx);
- Value = codec.valueCodec.Read(ref zeroLengthCtx);
- }
- }
- }
-
- public void WriteTo(CodedOutputStream output)
- {
- // Message adapter is an internal class and we know that all the writing will happen via InternalWriteTo.
- throw new NotImplementedException();
- }
-
- [SecuritySafeCritical]
- public void InternalWriteTo(ref WriteContext ctx)
- {
- codec.keyCodec.WriteTagAndValue(ref ctx, Key);
- codec.valueCodec.WriteTagAndValue(ref ctx, Value);
- }
-
- public int CalculateSize()
- {
- return codec.keyCodec.CalculateSizeWithTag(Key) + codec.valueCodec.CalculateSizeWithTag(Value);
- }
-
- MessageDescriptor IMessage.Descriptor { get { return null; } }
- }
+ /// <summary>
+ /// The tag used in the enclosing message to indicate map entries; each entry is written as this tag followed by a length-delimited key/value body.
+ /// </summary>
+ internal uint MapTag => mapTag;
}
private class MapView<T> : ICollection<T>, ICollection
diff --git a/csharp/src/Google.Protobuf/ParsingPrimitivesMessages.cs b/csharp/src/Google.Protobuf/ParsingPrimitivesMessages.cs
index b7097a2..eabaf96 100644
--- a/csharp/src/Google.Protobuf/ParsingPrimitivesMessages.cs
+++ b/csharp/src/Google.Protobuf/ParsingPrimitivesMessages.cs
@@ -32,9 +32,11 @@
using System;
using System.Buffers;
+using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Security;
+using Google.Protobuf.Collections;
namespace Google.Protobuf
{
@@ -44,6 +46,8 @@
[SecuritySafeCritical]
internal static class ParsingPrimitivesMessages
{
+ private static readonly byte[] ZeroLengthMessageStreamData = new byte[] { 0 }; // single zero byte; ReadMapEntry reads a value from it when a map entry has no value and the value default is null
+
public static void SkipLastField(ref ReadOnlySpan<byte> buffer, ref ParserInternalState state)
{
if (state.lastTag == 0)
@@ -134,6 +138,65 @@
SegmentedBufferHelper.PopLimit(ref ctx.state, oldLimit);
}
+ public static KeyValuePair<TKey, TValue> ReadMapEntry<TKey, TValue>(ref ParseContext ctx, MapField<TKey, TValue>.Codec codec)
+ {
+ int length = ParsingPrimitives.ParseLength(ref ctx.buffer, ref ctx.state); // length prefix of this map entry's body
+ if (ctx.state.recursionDepth >= ctx.state.recursionLimit)
+ {
+ throw InvalidProtocolBufferException.RecursionLimitExceeded();
+ }
+ int oldLimit = SegmentedBufferHelper.PushLimit(ref ctx.state, length);
+ ++ctx.state.recursionDepth;
+ // Fields absent from the entry keep the codecs' DefaultValue (e.g. a missing key reads as the key codec's default).
+ TKey key = codec.KeyCodec.DefaultValue;
+ TValue value = codec.ValueCodec.DefaultValue;
+ // Read fields until ReadTag() returns 0; tags matching neither codec are skipped.
+ uint tag;
+ while ((tag = ctx.ReadTag()) != 0)
+ {
+ if (tag == codec.KeyCodec.Tag)
+ {
+ key = codec.KeyCodec.Read(ref ctx);
+ }
+ else if (tag == codec.ValueCodec.Tag)
+ {
+ value = codec.ValueCodec.Read(ref ctx);
+ }
+ else
+ {
+ SkipLastField(ref ctx.buffer, ref ctx.state);
+ }
+ }
+
+ // Corner case: a map entry with a key but no value, where the value type is a message.
+ // Read it as if we'd seen input with no data (i.e. create a "default" message).
+ if (value == null)
+ {
+ if (ctx.state.CodedInputStream != null)
+ {
+ // the decoded message might not support parsing from ParseContext, so
+ // we need to allow fallback to the legacy MergeFrom(CodedInputStream) parsing.
+ value = codec.ValueCodec.Read(new CodedInputStream(ZeroLengthMessageStreamData));
+ }
+ else
+ {
+ ParseContext.Initialize(new ReadOnlySequence<byte>(ZeroLengthMessageStreamData), out ParseContext zeroLengthCtx);
+ value = codec.ValueCodec.Read(ref zeroLengthCtx);
+ }
+ }
+
+ CheckReadEndOfStreamTag(ref ctx.state);
+ // Check that we've read exactly as much data as expected.
+ if (!SegmentedBufferHelper.IsReachedLimit(ref ctx.state))
+ {
+ throw InvalidProtocolBufferException.TruncatedMessage();
+ }
+ --ctx.state.recursionDepth;
+ SegmentedBufferHelper.PopLimit(ref ctx.state, oldLimit);
+
+ return new KeyValuePair<TKey, TValue>(key, value);
+ }
+
public static void ReadGroup(ref ParseContext ctx, IMessage message)
{
if (ctx.state.recursionDepth >= ctx.state.recursionLimit)