pw_rpc: Allow simultaneous calls to one method
Calls already include a call_id that, prior to this CL, was largely
unused. This CL looks up calls based on their ID and uses it to
allow multiple outstanding calls on the same channel+service+method
at once.
Note that we still use all of channel+service+method+call to look
up the record of a given call, even though existing clients should
use IDs only once across a channel (so channel+call IDs would be
enough). In the future, this may be relaxed in order to allow
omitting service+method IDs in response messages.
Bug: 237418397
Change-Id: I220b8cb702baaf56867ad3ccdcdcb77a5b8fa625
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/109077
Reviewed-by: Wyatt Hepler <hepler@google.com>
Pigweed-Auto-Submit: Taylor Cramer <cramertj@google.com>
Commit-Queue: Auto-Submit <auto-submit@pigweed.google.com.iam.gserviceaccount.com>
diff --git a/pw_rpc/client.cc b/pw_rpc/client.cc
index be2dd32..e579db8 100644
--- a/pw_rpc/client.cc
+++ b/pw_rpc/client.cc
@@ -47,7 +47,7 @@
return Status::Unavailable();
}
- if (call == nullptr || call->id() != packet.call_id()) {
+ if (call == nullptr) {
// The call for the packet does not exist. If the packet is a server stream
// message, notify the server so that it can kill the stream. Otherwise,
// silently drop the packet (as it would terminate the RPC anyway).
diff --git a/pw_rpc/endpoint.cc b/pw_rpc/endpoint.cc
index fccdf26..c9e40af 100644
--- a/pw_rpc/endpoint.cc
+++ b/pw_rpc/endpoint.cc
@@ -62,9 +62,9 @@
void Endpoint::RegisterCall(Call& call) {
Call* const existing_call = FindCallById(
- call.channel_id_locked(), call.service_id(), call.method_id());
+ call.channel_id_locked(), call.service_id(), call.method_id(), call.id());
- RegisterUniqueCall(call);
+ calls_.push_front(call);
if (existing_call != nullptr) {
// TODO(b/234876851): Ensure call object is locked when calling callback.
@@ -77,11 +77,20 @@
Call* Endpoint::FindCallById(uint32_t channel_id,
uint32_t service_id,
- uint32_t method_id) {
+ uint32_t method_id,
+ uint32_t call_id) {
for (Call& call : calls_) {
if (channel_id == call.channel_id_locked() &&
service_id == call.service_id() && method_id == call.method_id()) {
- return &call;
+ if (call_id == call.id() || call_id == kOpenCallId) {
+ return &call;
+ }
+ if (call.id() == kOpenCallId) {
+ // Calls with ID of `kOpenCallId` were unrequested, and
+ // are updated to have the call ID of the first matching request.
+ call.set_id(call_id);
+ return &call;
+ }
}
}
return nullptr;
diff --git a/pw_rpc/public/pw_rpc/internal/call.h b/pw_rpc/public/pw_rpc/internal/call.h
index 6e72257..4a26a61 100644
--- a/pw_rpc/public/pw_rpc/internal/call.h
+++ b/pw_rpc/public/pw_rpc/internal/call.h
@@ -15,6 +15,7 @@
#include <cassert>
#include <cstddef>
+#include <limits>
#include <utility>
#include "pw_containers/intrusive_list.h"
@@ -40,6 +41,11 @@
class LockedEndpoint;
class Packet;
+// Unrequested RPCs always use this call ID. When a subsequent request
+// or response is sent with a matching channel + service + method,
+// it will match a calls with this ID if one exists.
+constexpr uint32_t kOpenCallId = std::numeric_limits<uint32_t>::max();
+
// Internal RPC Call class. The Call is used to respond to any type of RPC.
// Public classes like ServerWriters inherit from it with private inheritance
// and provide a public API for their use case. The Call's public API is used by
@@ -72,6 +78,8 @@
uint32_t id() const PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) { return id_; }
+ void set_id(uint32_t id) PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) { id_ = id; }
+
uint32_t channel_id() const PW_LOCKS_EXCLUDED(rpc_lock()) {
LockGuard lock(rpc_lock());
return channel_id_locked();
@@ -255,6 +263,7 @@
const bool invoke = on_error_ != nullptr;
// TODO(b/234876851): Ensure on_error_ is properly guarded.
+
rpc_lock().unlock();
if (invoke) {
on_error_(error);
diff --git a/pw_rpc/public/pw_rpc/internal/client_call.h b/pw_rpc/public/pw_rpc/internal/client_call.h
index 2120933..5a23c1c 100644
--- a/pw_rpc/public/pw_rpc/internal/client_call.h
+++ b/pw_rpc/public/pw_rpc/internal/client_call.h
@@ -14,6 +14,7 @@
#pragma once
#include <cstdint>
+#include <mutex>
#include "pw_bytes/span.h"
#include "pw_function/function.h"
@@ -32,6 +33,11 @@
rpc_lock().unlock();
}
+ uint32_t id() PW_LOCKS_EXCLUDED(rpc_lock()) {
+ std::lock_guard lock(rpc_lock());
+ return Call::id();
+ }
+
protected:
constexpr ClientCall() = default;
diff --git a/pw_rpc/public/pw_rpc/internal/endpoint.h b/pw_rpc/public/pw_rpc/internal/endpoint.h
index c8caca5..48074c1 100644
--- a/pw_rpc/public/pw_rpc/internal/endpoint.h
+++ b/pw_rpc/public/pw_rpc/internal/endpoint.h
@@ -96,8 +96,10 @@
// Finds a call object for an ongoing call associated with this packet, if
// any. Returns nullptr if no matching call exists.
Call* FindCall(const Packet& packet) PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) {
- return FindCallById(
- packet.channel_id(), packet.service_id(), packet.method_id());
+ return FindCallById(packet.channel_id(),
+ packet.service_id(),
+ packet.method_id(),
+ packet.call_id());
}
void AbortCallsForService(const Service& service)
@@ -141,8 +143,8 @@
Call* FindCallById(uint32_t channel_id,
uint32_t service_id,
- uint32_t method_id)
- PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock());
+ uint32_t method_id,
+ uint32_t call_id) PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock());
ChannelList channels_ PW_GUARDED_BY(rpc_lock());
IntrusiveList<Call> calls_ PW_GUARDED_BY(rpc_lock());
diff --git a/pw_rpc/public/pw_rpc/server.h b/pw_rpc/public/pw_rpc/server.h
index 9411564..9adfa9c 100644
--- a/pw_rpc/public/pw_rpc/server.h
+++ b/pw_rpc/public/pw_rpc/server.h
@@ -18,6 +18,7 @@
#include "pw_containers/intrusive_list.h"
#include "pw_rpc/channel.h"
+#include "pw_rpc/internal/call.h"
#include "pw_rpc/internal/channel.h"
#include "pw_rpc/internal/endpoint.h"
#include "pw_rpc/internal/lock.h"
@@ -126,12 +127,8 @@
"streaming RPCs.");
}
- // Unrequested RPCs always use 0 as the call ID. When an actual request is
- // sent, the call will be replaced with its real ID.
- constexpr uint32_t kOpenCallId = 0;
-
return internal::CallContext(
- *this, channel_id, service, method, kOpenCallId);
+ *this, channel_id, service, method, internal::kOpenCallId);
}
std::tuple<Service*, const internal::Method*> FindMethod(
diff --git a/pw_rpc/raw/client_reader_writer_test.cc b/pw_rpc/raw/client_reader_writer_test.cc
index b3220a8..68ff35a 100644
--- a/pw_rpc/raw/client_reader_writer_test.cc
+++ b/pw_rpc/raw/client_reader_writer_test.cc
@@ -231,27 +231,6 @@
EXPECT_TRUE(active_call_2.active());
}
-TEST(RawClientReaderWriter, NewCallCancelsPreviousAndCallsErrorCallback) {
- RawClientTestContext ctx;
-
- Status error;
- RawClientReaderWriter active_call_1 = TestService::TestBidirectionalStreamRpc(
- ctx.client(),
- ctx.channel().id(),
- FailIfOnNextCalled,
- FailIfCalled,
- [&error](Status status) { error = status; });
-
- ASSERT_TRUE(active_call_1.active());
-
- RawClientReaderWriter active_call_2 =
- TestService::TestBidirectionalStreamRpc(ctx.client(), ctx.channel().id());
-
- EXPECT_FALSE(active_call_1.active());
- EXPECT_TRUE(active_call_2.active());
- EXPECT_EQ(error, Status::Cancelled());
-}
-
TEST(RawClientReader, NoClientStream_OutOfScope_SilentlyCloses) {
RawClientTestContext ctx;
@@ -329,5 +308,49 @@
kWriterData);
}
+const char* span_as_cstr(ConstByteSpan span) {
+ return reinterpret_cast<const char*>(span.data());
+}
+
+TEST(RawClientReaderWriter,
+ MultipleCallsToSameMethodOkAndReceiveSeparateResponses) {
+ RawClientTestContext ctx;
+
+ ConstByteSpan data_1 = as_bytes(span("data_1_unset"));
+ ConstByteSpan data_2 = as_bytes(span("data_2_unset"));
+
+ Status error;
+ auto set_error = [&error](Status status) { error.Update(status); };
+ RawClientReaderWriter active_call_1 = TestService::TestBidirectionalStreamRpc(
+ ctx.client(),
+ ctx.channel().id(),
+ [&data_1](ConstByteSpan payload) { data_1 = payload; },
+ FailIfCalled,
+ set_error);
+
+ EXPECT_TRUE(active_call_1.active());
+
+ RawClientReaderWriter active_call_2 = TestService::TestBidirectionalStreamRpc(
+ ctx.client(),
+ ctx.channel().id(),
+ [&data_2](ConstByteSpan payload) { data_2 = payload; },
+ FailIfCalled,
+ set_error);
+
+ EXPECT_TRUE(active_call_1.active());
+ EXPECT_TRUE(active_call_2.active());
+ EXPECT_EQ(error, OkStatus());
+
+ ConstByteSpan message_1 = as_bytes(span("hello_1"));
+ ConstByteSpan message_2 = as_bytes(span("hello_2"));
+
+ ctx.server().SendServerStream<TestService::TestBidirectionalStreamRpc>(
+ message_2, active_call_2.id());
+ EXPECT_STREQ(span_as_cstr(data_2), span_as_cstr(message_2));
+ ctx.server().SendServerStream<TestService::TestBidirectionalStreamRpc>(
+ message_1, active_call_1.id());
+ EXPECT_STREQ(span_as_cstr(data_1), span_as_cstr(message_1));
+}
+
} // namespace
} // namespace pw::rpc
diff --git a/pw_rpc/raw/client_testing.cc b/pw_rpc/raw/client_testing.cc
index 3057dce..310d730 100644
--- a/pw_rpc/raw/client_testing.cc
+++ b/pw_rpc/raw/client_testing.cc
@@ -28,10 +28,11 @@
void FakeServer::CheckProcessPacket(internal::PacketType type,
uint32_t service_id,
uint32_t method_id,
+ std::optional<uint32_t> call_id,
ConstByteSpan payload,
Status status) const {
if (Status process_packet_status =
- ProcessPacket(type, service_id, method_id, payload, status);
+ ProcessPacket(type, service_id, method_id, call_id, payload, status);
!process_packet_status.ok()) {
PW_LOG_CRITICAL("Failed to process packet in pw::rpc::FakeServer");
PW_LOG_CRITICAL(
@@ -50,10 +51,10 @@
Status FakeServer::ProcessPacket(internal::PacketType type,
uint32_t service_id,
uint32_t method_id,
+ std::optional<uint32_t> call_id,
ConstByteSpan payload,
Status status) const {
- uint32_t call_id = 0;
- {
+ if (!call_id.has_value()) {
internal::LockGuard lock(output_.mutex_);
auto view = internal::test::PacketsView(
output_.packets(),
@@ -71,7 +72,7 @@
auto packet_encoding_result =
internal::Packet(
- type, channel_id_, service_id, method_id, call_id, payload, status)
+ type, channel_id_, service_id, method_id, *call_id, payload, status)
.Encode(packet_buffer_);
PW_CHECK_OK(packet_encoding_result.status());
return client_.ProcessPacket(*packet_encoding_result);
diff --git a/pw_rpc/raw/public/pw_rpc/raw/client_reader_writer.h b/pw_rpc/raw/public/pw_rpc/raw/client_reader_writer.h
index 1dbae7d..1239989 100644
--- a/pw_rpc/raw/public/pw_rpc/raw/client_reader_writer.h
+++ b/pw_rpc/raw/public/pw_rpc/raw/client_reader_writer.h
@@ -1,4 +1,4 @@
-// Copyright 2021 The Pigweed Authors
+// Copyright 2022 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
@@ -38,6 +38,8 @@
using internal::Call::active;
using internal::Call::channel_id;
+ using internal::ClientCall::id;
+
// Functions for setting the callbacks.
using internal::StreamResponseClientCall::set_on_completed;
using internal::StreamResponseClientCall::set_on_error;
diff --git a/pw_rpc/raw/public/pw_rpc/raw/client_testing.h b/pw_rpc/raw/public/pw_rpc/raw/client_testing.h
index 692d1fd..89cefe0 100644
--- a/pw_rpc/raw/public/pw_rpc/raw/client_testing.h
+++ b/pw_rpc/raw/public/pw_rpc/raw/client_testing.h
@@ -45,8 +45,9 @@
template <auto kMethod,
typename = std::enable_if_t<
HasServerStream(internal::MethodInfo<kMethod>::kType)>>
- void SendResponse(Status status) const {
- SendPacket<kMethod>(internal::PacketType::RESPONSE, {}, status);
+ void SendResponse(Status status,
+ std::optional<uint32_t> call_id = std::nullopt) const {
+ SendPacket<kMethod>(internal::PacketType::RESPONSE, {}, status, call_id);
}
// Sends a response packet for a unary or client streaming streaming RPC to
@@ -54,45 +55,54 @@
template <auto kMethod,
typename = std::enable_if_t<
!HasServerStream(internal::MethodInfo<kMethod>::kType)>>
- void SendResponse(ConstByteSpan payload, Status status) const {
- SendPacket<kMethod>(internal::PacketType::RESPONSE, payload, status);
+ void SendResponse(ConstByteSpan payload,
+ Status status,
+ std::optional<uint32_t> call_id = std::nullopt) const {
+ SendPacket<kMethod>(
+ internal::PacketType::RESPONSE, payload, status, call_id);
}
// Sends a stream packet for a server or bidirectional streaming RPC to the
// client.
template <auto kMethod>
- void SendServerStream(ConstByteSpan payload) const {
+ void SendServerStream(ConstByteSpan payload,
+ std::optional<uint32_t> call_id = std::nullopt) const {
static_assert(HasServerStream(internal::MethodInfo<kMethod>::kType),
"Only server and bidirectional streaming methods can receive "
"server stream packets");
- SendPacket<kMethod>(internal::PacketType::SERVER_STREAM, payload);
+ SendPacket<kMethod>(
+ internal::PacketType::SERVER_STREAM, payload, OkStatus(), call_id);
}
// Sends a server error packet to the client.
template <auto kMethod>
- void SendServerError(Status error) const {
- SendPacket<kMethod>(internal::PacketType::SERVER_ERROR, {}, error);
+ void SendServerError(Status error,
+ std::optional<uint32_t> call_id = std::nullopt) const {
+ SendPacket<kMethod>(internal::PacketType::SERVER_ERROR, {}, error, call_id);
}
private:
template <auto kMethod>
void SendPacket(internal::PacketType type,
- ConstByteSpan payload = {},
- Status status = OkStatus()) const {
+ ConstByteSpan payload,
+ Status status,
+ std::optional<uint32_t> call_id) const {
using Info = internal::MethodInfo<kMethod>;
CheckProcessPacket(
- type, Info::kServiceId, Info::kMethodId, payload, status);
+ type, Info::kServiceId, Info::kMethodId, call_id, payload, status);
}
void CheckProcessPacket(internal::PacketType type,
uint32_t service_id,
uint32_t method_id,
+ std::optional<uint32_t> call_id,
ConstByteSpan payload,
Status status) const;
Status ProcessPacket(internal::PacketType type,
uint32_t service_id,
uint32_t method_id,
+ std::optional<uint32_t> call_id,
ConstByteSpan payload,
Status status) const;
diff --git a/pw_rpc/server.cc b/pw_rpc/server.cc
index 9ef2c29..270674a 100644
--- a/pw_rpc/server.cc
+++ b/pw_rpc/server.cc
@@ -38,8 +38,6 @@
Endpoint::ProcessPacket(packet_data, Packet::kServer));
internal::rpc_lock().lock();
- internal::ServerCall* const call =
- static_cast<internal::ServerCall*>(FindCall(packet));
// Verbose log for debugging.
// PW_LOG_DEBUG("RPC server received packet type %u for %u:%08x/%08x",
@@ -72,10 +70,11 @@
return OkStatus(); // OK since the packet was handled.
}
+ internal::ServerCall* const call =
+ static_cast<internal::ServerCall*>(FindCall(packet));
+
switch (packet.type()) {
case PacketType::REQUEST: {
- // If the REQUEST is for an ongoing RPC, the existing call will be
- // cancelled when the new call object is created.
const internal::CallContext context(
*this, channel->id(), *service, *method, packet.call_id());
method->Invoke(context, packet);
@@ -86,7 +85,7 @@
break;
case PacketType::CLIENT_ERROR:
case PacketType::DEPRECATED_CANCEL:
- if (call != nullptr && call->id() == packet.call_id()) {
+ if (call != nullptr) {
call->HandleError(packet.status());
} else {
internal::rpc_lock().unlock();
@@ -125,7 +124,7 @@
void Server::HandleClientStreamPacket(const internal::Packet& packet,
internal::Channel& channel,
internal::ServerCall* call) const {
- if (call == nullptr || call->id() != packet.call_id()) {
+ if (call == nullptr) {
channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
.IgnoreError(); // Errors are logged in Channel::Send.
internal::rpc_lock().unlock();
diff --git a/pw_rpc/server_test.cc b/pw_rpc/server_test.cc
index 2814d40..e82bd2b 100644
--- a/pw_rpc/server_test.cc
+++ b/pw_rpc/server_test.cc
@@ -19,6 +19,7 @@
#include "gtest/gtest.h"
#include "pw_assert/check.h"
+#include "pw_rpc/internal/call.h"
#include "pw_rpc/internal/method.h"
#include "pw_rpc/internal/packet.h"
#include "pw_rpc/internal/test_method.h"
@@ -67,6 +68,8 @@
static constexpr std::array<TestMethodUnion, 0> methods_ = {};
};
+uint32_t kDefaultCallId = 24601;
+
class BasicServer : public ::testing::Test {
protected:
static constexpr byte kDefaultPayload[] = {
@@ -88,31 +91,36 @@
uint32_t channel_id,
uint32_t service_id,
uint32_t method_id,
- span<const byte> payload = kDefaultPayload,
- Status status = OkStatus()) {
- auto result =
- Packet(type, channel_id, service_id, method_id, 0, payload, status)
- .Encode(request_buffer_);
- EXPECT_EQ(OkStatus(), result.status());
- return result.value_or(ConstByteSpan());
+ uint32_t call_id = kDefaultCallId) {
+ return EncodePacketWithBody(type,
+ channel_id,
+ service_id,
+ method_id,
+ call_id,
+ kDefaultPayload,
+ OkStatus());
}
span<const byte> EncodeCancel(uint32_t channel_id = 1,
uint32_t service_id = 42,
- uint32_t method_id = 100) {
- return EncodePacket(PacketType::CLIENT_ERROR,
- channel_id,
- service_id,
- method_id,
- {},
- Status::Cancelled());
+ uint32_t method_id = 100,
+ uint32_t call_id = kDefaultCallId) {
+ return EncodePacketWithBody(PacketType::CLIENT_ERROR,
+ channel_id,
+ service_id,
+ method_id,
+ call_id,
+ {},
+ Status::Cancelled());
}
template <typename T = ConstByteSpan>
ConstByteSpan PacketForRpc(PacketType type,
Status status = OkStatus(),
- T&& payload = {}) {
- return EncodePacket(type, 1, 42, 100, as_bytes(span(payload)), status);
+ T&& payload = {},
+ uint32_t call_id = kDefaultCallId) {
+ return EncodePacketWithBody(
+ type, 1, 42, 100, call_id, as_bytes(span(payload)), status);
}
RawFakeChannelOutput<2> output_;
@@ -124,6 +132,21 @@
private:
byte request_buffer_[64];
+
+ span<const byte> EncodePacketWithBody(PacketType type,
+ uint32_t channel_id,
+ uint32_t service_id,
+ uint32_t method_id,
+ uint32_t call_id,
+ span<const byte> payload,
+ Status status) {
+ auto result =
+ Packet(
+ type, channel_id, service_id, method_id, call_id, payload, status)
+ .Encode(request_buffer_);
+ EXPECT_EQ(OkStatus(), result.status());
+ return result.value_or(ConstByteSpan());
+ }
};
TEST_F(BasicServer, ProcessPacket_ValidMethodInService1_InvokesMethod) {
@@ -155,11 +178,12 @@
}
TEST_F(BasicServer, UnregisterService_CannotCallMethod) {
+ const uint32_t kCallId = 8675309;
server_.UnregisterService(service_1_, service_42_);
- EXPECT_EQ(
- OkStatus(),
- server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 1, 1, 100)));
+ EXPECT_EQ(OkStatus(),
+ server_.ProcessPacket(
+ EncodePacket(PacketType::REQUEST, 1, 1, 100, kCallId)));
const Packet& packet =
static_cast<internal::test::FakeChannelOutput&>(output_).last_packet();
@@ -167,6 +191,7 @@
EXPECT_EQ(packet.channel_id(), 1u);
EXPECT_EQ(packet.service_id(), 1u);
EXPECT_EQ(packet.method_id(), 100u);
+ EXPECT_EQ(packet.call_id(), kCallId);
EXPECT_EQ(packet.status(), Status::NotFound());
}
@@ -369,8 +394,11 @@
protected:
BidiMethod() {
internal::rpc_lock().lock();
- internal::CallContext context(
- server_, channels_[0].id(), service_42_, service_42_.method(100), 0);
+ internal::CallContext context(server_,
+ channels_[0].id(),
+ service_42_,
+ service_42_.method(100),
+ kDefaultCallId);
// A local temporary is required since the constructor requires a lock,
// but the *move* constructor takes out the lock.
internal::test::FakeServerReaderWriter responder_temp(
@@ -383,7 +411,7 @@
internal::test::FakeServerReaderWriter responder_;
};
-TEST_F(BidiMethod, DuplicateCall_CancelsExistingThenCallsAgain) {
+TEST_F(BidiMethod, DuplicateCallId_CancelsExistingThenCallsAgain) {
int cancelled = 0;
responder_.set_on_error([&cancelled](Status error) {
if (error.IsCancelled()) {
@@ -401,6 +429,65 @@
EXPECT_EQ(method.invocations(), 1u);
}
+TEST_F(BidiMethod, DuplicateMethodDifferentCallId_NotCancelled) {
+ int cancelled = 0;
+ responder_.set_on_error([&cancelled](Status error) {
+ if (error.IsCancelled()) {
+ cancelled += 1;
+ }
+ });
+
+ const uint32_t kSecondCallId = 1625;
+ EXPECT_EQ(OkStatus(),
+ server_.ProcessPacket(PacketForRpc(
+ PacketType::REQUEST, OkStatus(), {}, kSecondCallId)));
+
+ EXPECT_EQ(cancelled, 0);
+}
+
+const char* span_as_cstr(ConstByteSpan span) {
+ return reinterpret_cast<const char*>(span.data());
+}
+
+TEST_F(BidiMethod, DuplicateMethodDifferentCallIdEachCallGetsSeparateResponse) {
+ const uint32_t kSecondCallId = 1625;
+
+ internal::rpc_lock().lock();
+ internal::test::FakeServerReaderWriter responder_2(
+ internal::CallContext(server_,
+ channels_[0].id(),
+ service_42_,
+ service_42_.method(100),
+ kSecondCallId)
+ .ClaimLocked());
+ internal::rpc_lock().unlock();
+
+ ConstByteSpan data_1 = as_bytes(span("data_1_unset"));
+ responder_.set_on_next(
+ [&data_1](ConstByteSpan payload) { data_1 = payload; });
+
+ ConstByteSpan data_2 = as_bytes(span("data_2_unset"));
+ responder_2.set_on_next(
+ [&data_2](ConstByteSpan payload) { data_2 = payload; });
+
+ const char* kMessage1 = "hello_1";
+ const char* kMessage2 = "hello_2";
+
+ EXPECT_EQ(
+ OkStatus(),
+ server_.ProcessPacket(PacketForRpc(
+ PacketType::CLIENT_STREAM, OkStatus(), "hello_2", kSecondCallId)));
+
+ EXPECT_STREQ(span_as_cstr(data_2), kMessage2);
+
+ EXPECT_EQ(
+ OkStatus(),
+ server_.ProcessPacket(PacketForRpc(
+ PacketType::CLIENT_STREAM, OkStatus(), "hello_1", kDefaultCallId)));
+
+ EXPECT_STREQ(span_as_cstr(data_1), kMessage1);
+}
+
TEST_F(BidiMethod, Cancel_ClosesServerWriter) {
EXPECT_EQ(OkStatus(), server_.ProcessPacket(EncodeCancel()));
@@ -468,7 +555,47 @@
PacketForRpc(PacketType::CLIENT_STREAM, {}, "hello")));
EXPECT_EQ(output_.total_packets(), 0u);
- EXPECT_STREQ(reinterpret_cast<const char*>(data.data()), "hello");
+ EXPECT_STREQ(span_as_cstr(data), "hello");
+}
+
+TEST_F(BidiMethod, ClientStream_CallsCallbackOnCallWithOpenId) {
+ ConstByteSpan data = as_bytes(span("?"));
+ responder_.set_on_next([&data](ConstByteSpan payload) { data = payload; });
+
+ ASSERT_EQ(
+ OkStatus(),
+ server_.ProcessPacket(PacketForRpc(
+ PacketType::CLIENT_STREAM, {}, "hello", internal::kOpenCallId)));
+
+ EXPECT_EQ(output_.total_packets(), 0u);
+ EXPECT_STREQ(span_as_cstr(data), "hello");
+}
+
+TEST_F(BidiMethod, ClientStream_CallsOpenIdOnCallWithDifferentId) {
+ const uint32_t kSecondCallId = 1625;
+ internal::CallContext context(server_,
+ channels_[0].id(),
+ service_42_,
+ service_42_.method(100),
+ internal::kOpenCallId);
+ internal::rpc_lock().lock();
+ auto temp_responder =
+ internal::test::FakeServerReaderWriter(context.ClaimLocked());
+ internal::rpc_lock().unlock();
+ responder_ = std::move(temp_responder);
+
+ ConstByteSpan data = as_bytes(span("?"));
+ responder_.set_on_next([&data](ConstByteSpan payload) { data = payload; });
+
+ ASSERT_EQ(OkStatus(),
+ server_.ProcessPacket(PacketForRpc(
+ PacketType::CLIENT_STREAM, {}, "hello", kSecondCallId)));
+
+ EXPECT_EQ(output_.total_packets(), 0u);
+ EXPECT_STREQ(span_as_cstr(data), "hello");
+
+ std::lock_guard lock(internal::rpc_lock());
+ EXPECT_EQ(responder_.as_server_call().id(), kSecondCallId);
}
TEST_F(BidiMethod, UnregsiterService_AbortsActiveCalls) {
@@ -518,8 +645,11 @@
class ServerStreamingMethod : public BasicServer {
protected:
ServerStreamingMethod() {
- internal::CallContext context(
- server_, channels_[0].id(), service_42_, service_42_.method(100), 0);
+ internal::CallContext context(server_,
+ channels_[0].id(),
+ service_42_,
+ service_42_.method(100),
+ kDefaultCallId);
internal::rpc_lock().lock();
internal::test::FakeServerWriter responder_temp(context.ClaimLocked());
internal::rpc_lock().unlock();