blob: 9657b02b5febb2469a8bebdd268b13c46a8c0d52 [file] [log] [blame]
// Copyright 2021 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
package dev.pigweed.pw_rpc;
import com.google.common.flogger.FluentLogger;
import com.google.protobuf.ExtensionRegistryLite;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.MessageLite;
import dev.pigweed.pw.rpc.internal.Packet.PacketType;
import dev.pigweed.pw.rpc.internal.Packet.RpcPacket;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
/**
 * A client for a pw_rpc server. Invokes RPCs through a MethodClient and handles RPC responses
 * through the processPacket function.
 */
public class Client {
  private static final FluentLogger logger = FluentLogger.forEnclosingClass();

  // Lookup tables keyed by ID, built once from the constructor arguments.
  private final Map<Integer, Channel> channels;
  private final Map<Integer, Service> services;
  // Cache so repeated lookups of the same RPC return the same MethodClient instance.
  private final Map<PendingRpc, MethodClient> methodClients;
  private final RpcManager rpcs;
  private final Function<PendingRpc, StreamObserver<MessageLite>> defaultObserverFactory;

  /**
   * Creates a new RPC client.
   *
   * @param channels supported channels, which are used to send requests to the server
   * @param services which RPC services this client supports; used to handle encoding and decoding
   * @param defaultObserverFactory creates the observer given to each new MethodClient
   * @throws IllegalStateException if two channels or two services share the same ID
   */
  private Client(List<Channel> channels,
      List<Service> services,
      Function<PendingRpc, StreamObserver<MessageLite>> defaultObserverFactory) {
    this.channels = channels.stream().collect(Collectors.toMap(Channel::id, c -> c));
    this.services = services.stream().collect(Collectors.toMap(Service::id, s -> s));
    this.methodClients = new HashMap<>();
    this.rpcs = new RpcManager();
    this.defaultObserverFactory = defaultObserverFactory;
  }

  /**
   * Creates a new pw_rpc client.
   *
   * @param channels the set of channels for the client to send requests over
   * @param services the services to support on this client
   * @param defaultObserverFactory function that creates a default observer for each RPC
   * @return the new pw.rpc.Client
   * @throws IllegalStateException if two channels or two services share the same ID
   */
  public static Client create(List<Channel> channels,
      List<Service> services,
      Function<PendingRpc, StreamObserver<MessageLite>> defaultObserverFactory) {
    return new Client(channels, services, defaultObserverFactory);
  }

  /** Creates a new pw_rpc client that logs responses when no observer is provided to calls. */
  public static Client create(List<Channel> channels, List<Service> services) {
    return create(channels, services, (rpc) -> new StreamObserver<MessageLite>() {
      @Override
      public void onNext(MessageLite value) {
        logger.atFine().log("%s received response: %s", rpc, value);
      }

      @Override
      public void onCompleted(Status status) {
        logger.atInfo().log("%s completed with status %s", rpc, status);
      }

      @Override
      public void onError(Status status) {
        logger.atWarning().log("%s terminated with error %s", rpc, status);
      }
    });
  }

  /**
   * Returns a MethodClient with the given name for the provided channel ID.
   *
   * <p>The service and method names are split at the last '/' if one is present; otherwise they
   * are split at the last '.'.
   *
   * @param channelId the ID for the channel through which to invoke the RPC
   * @param fullMethodName the method name as "package.Service.Method" or "package.Service/Method"
   * @throws IllegalArgumentException if the name has no delimiter or the RPC is unknown
   */
  public MethodClient method(int channelId, String fullMethodName) {
    for (char delimiter : new char[] {'/', '.'}) {
      int index = fullMethodName.lastIndexOf(delimiter);
      if (index != -1) {
        return method(
            channelId, fullMethodName.substring(0, index), fullMethodName.substring(index + 1));
      }
    }
    throw new IllegalArgumentException("Invalid method name '" + fullMethodName
        + "'; does not match required package.Service/Method format");
  }

  /**
   * Returns a MethodClient on the provided channel using separate arguments for "package.Service"
   * and "Method".
   *
   * @throws IllegalArgumentException if the channel, service, or method is unknown; the exception
   *     describing the specific lookup failure is attached as the cause
   */
  public MethodClient method(int channelId, String fullServiceName, String methodName) {
    try {
      return method(channelId, Ids.calculate(fullServiceName), Ids.calculate(methodName));
    } catch (IllegalArgumentException e) {
      // Rethrow the exception with the service and method name instead of the ID.
      throw new IllegalArgumentException("Unknown RPC " + fullServiceName + '/' + methodName, e);
    }
  }

  /**
   * Returns a MethodClient with the provided service and method IDs.
   *
   * <p>MethodClients are cached, so repeated calls for the same channel, service, and method return
   * the same instance.
   *
   * @throws IllegalArgumentException if the channel, service, or method ID is unknown
   */
  public MethodClient method(int channelId, int serviceId, int methodId) {
    Channel channel = channels.get(channelId);
    if (channel == null) {
      throw new IllegalArgumentException("Unknown channel ID " + channelId);
    }
    Service service = services.get(serviceId);
    if (service == null) {
      throw new IllegalArgumentException("Unknown service ID " + serviceId);
    }
    Method method = service.methods().get(methodId);
    if (method == null) {
      throw new IllegalArgumentException("Unknown method ID " + methodId);
    }

    PendingRpc rpc = PendingRpc.create(channel, service, method);
    // Explicit get/put rather than computeIfAbsent: the user-supplied observer factory could call
    // back into this client and modify methodClients, which computeIfAbsent does not allow.
    MethodClient client = methodClients.get(rpc);
    if (client == null) {
      client = new MethodClient(rpcs, rpc, defaultObserverFactory.apply(rpc));
      methodClients.put(rpc, client);
    }
    return client;
  }

  /**
   * Processes a single RPC packet.
   *
   * @param data a single, binary encoded RPC packet
   * @return true if the packet was decoded and processed by this client; returns false for invalid
   *     packets or packets for a server or unrecognized channel
   */
  public boolean processPacket(byte[] data) {
    return processPacket(ByteBuffer.wrap(data));
  }

  /**
   * Processes a single RPC packet read from the remaining bytes of a buffer.
   *
   * <p>Packets for unknown services or methods, or for RPCs that are not pending, are answered with
   * an error packet on the packet's channel and still count as handled.
   *
   * @param data a buffer containing a single, binary encoded RPC packet
   * @return true if the packet was decoded and processed by this client; returns false for invalid
   *     packets or packets for a server or unrecognized channel
   */
  public boolean processPacket(ByteBuffer data) {
    RpcPacket packet;
    try {
      packet = RpcPacket.parseFrom(data, ExtensionRegistryLite.getEmptyRegistry());
    } catch (InvalidProtocolBufferException e) {
      logger.atWarning().withCause(e).log("Failed to decode packet");
      return false;
    }

    if (packet.getChannelId() == 0 || packet.getServiceId() == 0 || packet.getMethodId() == 0) {
      logger.atWarning().log("Received corrupt packet with unset IDs");
      return false;
    }

    // Packets for the server use even type values.
    if (packet.getTypeValue() % 2 == 0) {
      logger.atFine().log("Ignoring %s packet for server", packet.getType().name());
      return false;
    }

    Channel channel = channels.get(packet.getChannelId());
    if (channel == null) {
      logger.atWarning().log("Received packet for unrecognized channel %d", packet.getChannelId());
      return false;
    }

    PendingRpc rpc = lookupRpc(channel, packet);
    if (rpc == null) {
      logger.atInfo().log("Ignoring packet for unknown service method");
      sendError(channel, packet, Status.NOT_FOUND);
      return true; // true since the packet was handled, even though it was invalid.
    }

    PacketType type = packet.getType();

    // Any packet type other than SERVER_STREAM indicates that this is the last packet for this RPC,
    // so the call is removed from the pending set rather than just looked up.
    StreamObserverCall<?, ?> call =
        type == PacketType.SERVER_STREAM ? rpcs.getPending(rpc) : rpcs.clear(rpc);
    if (call == null) {
      logger.atInfo().log(
          "Ignoring packet for RPC (%s) that isn't pending. Pending RPCs are: %s", rpc, rpcs);
      sendError(channel, packet, Status.FAILED_PRECONDITION);
      return true;
    }

    switch (type) {
      case SERVER_ERROR:
        Status status = decodeStatus(packet);
        logger.atWarning().log("RPC %s failed with error %s", rpc, status);
        call.onError(status);
        break;
      case RESPONSE:
        // Only RPCs that are not server streaming (e.g. unary) carry a payload in their RESPONSE
        // packet; server streaming RPCs deliver their payloads in SERVER_STREAM packets.
        if (!rpc.method().isServerStreaming()) {
          call.onNext(packet.getPayload());
        }
        call.onCompleted(decodeStatus(packet));
        break;
      case SERVER_STREAM:
        call.onNext(packet.getPayload());
        break;
      default:
        // Use the raw wire value: for types unknown to this build, getType() is UNRECOGNIZED,
        // and calling getNumber() on UNRECOGNIZED throws IllegalArgumentException.
        logger.atWarning().log(
            "Unexpected PacketType %d for RPC %s", packet.getTypeValue(), rpc);
    }

    return true;
  }

  /** Sends an error packet in reply to {@code packet}; send failures are logged, not thrown. */
  private static void sendError(Channel channel, RpcPacket packet, Status status) {
    try {
      channel.send(Packets.error(packet, status));
    } catch (ChannelOutputException e) {
      logger.atWarning().withCause(e).log("Failed to send error packet");
    }
  }

  /** Returns the RPC addressed by the packet's service and method IDs, or null if unknown. */
  @Nullable
  private PendingRpc lookupRpc(Channel channel, RpcPacket packet) {
    Service service = services.get(packet.getServiceId());
    if (service != null) {
      Method method = service.methods().get(packet.getMethodId());
      if (method != null) {
        return PendingRpc.create(channel, service, method);
      }
    }
    return null;
  }

  /** Decodes the packet's status code, falling back to Status.UNKNOWN for unrecognized codes. */
  private static Status decodeStatus(RpcPacket packet) {
    Status status = Status.fromCode(packet.getStatus());
    if (status == null) {
      logger.atWarning().log(
          "Illegal status code %d in packet; using Status.UNKNOWN ", packet.getStatus());
      return Status.UNKNOWN;
    }
    return status;
  }
}