blob: bd828e94b01b3164095891a1213c68f5ac4f9b52 [file] [log] [blame]
// Copyright 2021 The Pigweed Authors
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
package dev.pigweed.pw_rpc;
import dev.pigweed.pw_rpc.internal.Packet.PacketType;
import dev.pigweed.pw_rpc.internal.Packet.RpcPacket;
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
/**
 * Wraps a StreamObserverMethodClient for use in tests. Provides methods for simulating the server
 * interactions with the client.
 */
public class TestClient {
// ID of the single channel created in the constructor; also used when looking up methods.
private static final int CHANNEL_ID = 1;
// The client under test; assigned once in the constructor.
private final Client client;
// NOTE(review): presumably records the packets sent through the channel output; confirm where
// this list is populated.
private final List<RpcPacket> sentPackets = new ArrayList<>();
// Packets queued for later processing; the channel output copies this list when it is non-empty.
private final List<RpcPacket> enqueuedPackets = new ArrayList<>();
// Countdown checked by the channel output: decremented on a send while it is above 1. Starts at 1
// and is assigned from the afterPackets argument of enqueueServerStream.
private int receiveEnqueuedPacketsAfter = 1;
// For each packet type, how many payloads sentPayloads has already returned.
final Map<PacketType, Integer> sentPayloadIndices = new EnumMap<>(PacketType.class);
// When non-null, the channel output throws this exception instead of taking the normal path.
@Nullable ChannelOutputException channelOutputException = null;
/**
 * Creates a TestClient whose Client has one channel, identified by CHANNEL_ID.
 *
 * @param services services passed through to Client.create
 */
public TestClient(List<Service> services) {
// Channel output for the single test channel; it receives each outgoing packet.
Channel.Output channelOutput = packet -> {
if (channelOutputException == null) {
} else {
// An exception configured through setChannelOutputException is thrown from the send.
throw channelOutputException;
// Process any enqueued packets.
// A send decrements the countdown while it is still above 1.
if (receiveEnqueuedPacketsAfter > 1) {
receiveEnqueuedPacketsAfter -= 1;
if (!enqueuedPackets.isEmpty()) {
// Work from a copy of the queue rather than the live list.
List<RpcPacket> packetsToProcess = new ArrayList<>(enqueuedPackets);
// The client gets exactly one channel, wired to the output defined above.
client = Client.create(ImmutableList.of(new Channel(CHANNEL_ID, channelOutput)), services);
/** Returns the Client that this TestClient created in its constructor. */
public Client client() {
return client;
/**
 * Sets the exception to throw the next time a packet is sent. Set to null to accept the packet
 * without errors.
 *
 * <p>When Channel.Output throws an exception, TestClient does not store those outgoing packets.
 */
public void setChannelOutputException(@Nullable ChannelOutputException exception) {
// The channel output built in the constructor checks this field when a packet is sent and
// throws it if it is non-null.
this.channelOutputException = exception;
/**
 * Returns all payloads that were sent since the previous lastClientStreams call.
 *
 * @param payloadType class that each returned payload is cast to
 * @return the CLIENT_STREAM payloads reported by sentPayloads
 */
public <T extends MessageLite> List<T> lastClientStreams(Class<T> payloadType) {
return sentPayloads(payloadType, PacketType.CLIENT_STREAM);
/**
 * Simulates receiving SERVER_STREAM packets from the server.
 *
 * @param service service name used to start the packet
 * @param method method name used to start the packet
 * @param payloads payloads to iterate over, one loop pass per payload
 */
public void receiveServerStream(String service, String method, MessageLiteOrBuilder... payloads) {
// Build the shared SERVER_STREAM packet once, before the payload loop.
RpcPacket base = startPacket(service, method, PacketType.SERVER_STREAM).build();
for (MessageLiteOrBuilder payload : payloads) {
/**
 * Enqueues a SERVER_STREAM packet so that the client receives it after a packet is sent.
 *
 * @param afterPackets Wait until this many packets have been sent before the client receives
 *     these stream packets. The minimum value (and the default) is 1.
 */
public void enqueueServerStream(
String service, String method, int afterPackets, MessageLiteOrBuilder... payloads) {
// Reject counts below the minimum of 1.
if (afterPackets < 1) {
throw new IllegalArgumentException("afterPackets must be at least 1");
// A non-default afterPackets is refused when the stored countdown has already moved off its
// initial value of 1.
if (afterPackets != 1 && receiveEnqueuedPacketsAfter != 1) {
throw new AssertionError(
"May only set afterPackets once before enqueued packets are processed");
// Store the count that the channel output decrements as packets are sent.
receiveEnqueuedPacketsAfter = afterPackets;
// Build the shared SERVER_STREAM packet once, before the payload loop.
RpcPacket base = startPacket(service, method, PacketType.SERVER_STREAM).build();
for (MessageLiteOrBuilder payload : payloads) {
/**
 * Simulates receiving a SERVER_ERROR packet from the server.
 *
 * @param service service name used to start the packet
 * @param method method name used to start the packet
 * @param error status whose code is set on the packet
 */
public void receiveServerError(String service, String method, Status error) {
processPacket(startPacket(service, method, PacketType.SERVER_ERROR).setStatus(error.code()));
/**
 * Parses sent payloads for the given type of packet.
 *
 * @param payloadType class that each parsed payload is cast to
 * @param type packet type to select
 * @return the newly parsed payloads
 */
private <T extends MessageLite> List<T> sentPayloads(Class<T> payloadType, PacketType type) {
// Number of payloads of this type already handed out; 0 if none have been read yet.
int sentPayloadIndex = sentPayloadIndices.getOrDefault(type, 0);
// Filter only the specified packets.
List<T> newPayloads =
.filter(packet -> packet.getType().equals(type))
.map(p -> parseRequestPayload(payloadType, p))
// Store the index of the last read payload. Could drop the viewed packets instead to reduce
// memory usage, but that probably won't matter in practice.
sentPayloadIndices.put(type, sentPayloadIndex + newPayloads.size());
return newPayloads;
/**
 * Serializes the packet and passes the bytes to the client's processPacket.
 *
 * @param packet packet to deliver to the client
 * @throws AssertionError if the client's processPacket returns false
 */
private void processPacket(RpcPacket packet) {
if (!client.processPacket(packet.toByteArray())) {
throw new AssertionError("TestClient failed to process a packet!");
/** Overload of processPacket that accepts a packet builder rather than a built packet. */
private void processPacket(RpcPacket.Builder packet) {
/**
 * Starts a new RpcPacket builder for a packet of the given type.
 *
 * <p>NOTE(review): presumably the builder is populated from service, method and type; confirm
 * against the full builder chain.
 */
private static RpcPacket.Builder startPacket(String service, String method, PacketType type) {
return RpcPacket.newBuilder()
/**
 * Parses serialized bytes into an RpcPacket.
 *
 * @param packet serialized packet bytes
 * @return the parsed packet
 * @throws AssertionError wrapping the InvalidProtocolBufferException if parsing fails
 */
private static RpcPacket parsePacket(byte[] packet) {
try {
return RpcPacket.parseFrom(packet);
} catch (InvalidProtocolBufferException e) {
// A parse failure is reported as an assertion failure, keeping the original cause.
throw new AssertionError("Decoding sent packet failed", e);
/**
 * Decodes a packet's payload and casts the result to payloadType.
 *
 * @param payloadType class the decoded message is cast to
 * @param packet packet whose service and method IDs select the method to look up
 * @throws AssertionError wrapping the InvalidProtocolBufferException if decoding fails
 */
private <T extends MessageLite> T parseRequestPayload(Class<T> payloadType, RpcPacket packet) {
try {
return payloadType.cast(Method.decodeProtobuf(
// Request type of the method found on the test channel for this packet's service/method IDs.
client.method(CHANNEL_ID, packet.getServiceId(), packet.getMethodId()).method().request(),
} catch (InvalidProtocolBufferException e) {
throw new AssertionError("Decoding sent packet payload failed", e);
/**
 * Converts a MessageLiteOrBuilder into a MessageLite.
 *
 * @param messageOrBuilder a message or a message builder
 * @return the built message for a builder, or the argument itself if it is already a message
 * @throws AssertionError if the argument is neither a MessageLite.Builder nor a MessageLite
 */
private MessageLite getMessage(MessageLiteOrBuilder messageOrBuilder) {
// Builders are built into a message first.
if (messageOrBuilder instanceof MessageLite.Builder) {
return ((MessageLite.Builder) messageOrBuilder).build();
// Messages are returned unchanged.
if (messageOrBuilder instanceof MessageLite) {
return (MessageLite) messageOrBuilder;
throw new AssertionError("Unexpected MessageLiteOrBuilder class");