blob: e084fbfd7d4705b3ad627cebdb53f3d317cd7c6c [file] [log] [blame]
// Copyright 2021 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
import {Status} from '@pigweed/pw_status';
import {Message} from 'google-protobuf';
import WaitQueue = require('wait-queue');
import {PendingCalls, Rpc} from './rpc_classes';
/**
 * Signature of the handlers a `Call` stores as `onNext`, `onCompleted` and
 * `onError`. Each handler receives a single argument (a response message for
 * `onNext`, a `Status` for the other two); `Call` never reads the value the
 * handler returns.
 */
export type Callback = (a: any) => any;
/**
 * Error describing an RPC that terminated with a non-OK status.
 *
 * The message has the form `<method name> failed with error <STATUS NAME>`,
 * followed by an explanatory suffix for NOT_FOUND and DATA_LOSS.
 */
class RpcError extends Error {
  /** Status code the RPC failed with. */
  status: Status;

  /**
   * @param rpc The RPC that failed; only `rpc.method.name` is read.
   * @param status The status code to report.
   */
  constructor(rpc: Rpc, status: Status) {
    // Only two status codes carry an additional human-readable hint.
    let detail: string;
    switch (status) {
      case Status.NOT_FOUND:
        detail = ': the RPC server does not support this RPC';
        break;
      case Status.DATA_LOSS:
        detail = ': an error occurred while decoding the RPC payload';
        break;
      default:
        detail = '';
    }

    super(`${rpc.method.name} failed with error ${Status[status]}${detail}`);
    this.status = status;
  }
}
/**
 * Error raised when no response arrived within the allotted time.
 *
 * The message has the form `<method name> timed out after <timeoutMs> ms`.
 */
class RpcTimeout extends Error {
  /**
   * @param rpc The RPC that was being waited on; exposed as `rpc`.
   * @param timeoutMs The wait limit in milliseconds; exposed as `timeoutMs`.
   */
  constructor(
    readonly rpc: Rpc,
    readonly timeoutMs: number
  ) {
    super(`${rpc.method.name} timed out after ${timeoutMs} ms`);
  }
}
/**
 * Represents an in-progress or completed RPC call.
 *
 * Incoming events are fed in through `handleResponse`, `handleCompletion` and
 * `handleError`; each one records state, enqueues an item for consumers of
 * `getResponses`, and then invokes the matching user callback.
 */
export class Call {
  // Responses ordered by arrival time. Undefined signifies stream completion.
  private responseQueue = new WaitQueue<Message | undefined>();
  /** Every response received so far, in arrival order. */
  protected responses: Message[] = [];

  private rpcs: PendingCalls;
  private rpc: Rpc;
  private onNext: Callback;
  private onCompleted: Callback;
  private onError: Callback;

  /** Set by `handleCompletion` once the call has finished. */
  status?: Status;
  /** Set by `handleError` or `cancel` when the call terminates abnormally. */
  error?: Status;
  /** The most recent `Error` thrown by a user callback, if any. */
  callbackException?: Error;

  /**
   * @param rpcs Object used to send requests, cancellations and client-stream
   *     packets for this call.
   * @param rpc The RPC this call belongs to.
   * @param onNext Invoked with each response passed to `handleResponse`.
   * @param onCompleted Invoked with the status passed to `handleCompletion`.
   * @param onError Invoked with the status passed to `handleError`.
   */
  constructor(
    rpcs: PendingCalls,
    rpc: Rpc,
    onNext: Callback,
    onCompleted: Callback,
    onError: Callback
  ) {
    this.rpcs = rpcs;
    this.rpc = rpc;
    this.onNext = onNext;
    this.onCompleted = onCompleted;
    this.onError = onError;
  }

  /**
   * Calls the RPC. This must be called immediately after construction.
   *
   * @param request Optional request message forwarded to `sendRequest`.
   * @param ignoreErrors Forwarded unchanged to `sendRequest`.
   */
  invoke(request?: Message, ignoreErrors = false): void {
    const previous = this.rpcs.sendRequest(
      this.rpc,
      this,
      ignoreErrors,
      request
    );

    // If sendRequest hands back an earlier call that has not finished yet,
    // terminate it with CANCELLED.
    if (previous !== undefined && !previous.completed) {
      previous.handleError(Status.CANCELLED);
    }
  }

  /** True once either a completion status or an error has been recorded. */
  get completed(): boolean {
    return this.status !== undefined || this.error !== undefined;
  }

  /**
   * Runs a user callback, making sure nothing it throws escapes.
   *
   * A thrown `Error` is logged and remembered in `callbackException` so that
   * `checkErrors` can rethrow it later. Anything else that is thrown is only
   * logged. Exactly one message is logged per failure.
   */
  private invokeCallback(func: () => unknown): void {
    try {
      func();
    } catch (err: unknown) {
      if (err instanceof Error) {
        console.error(
          `An exception was raised while invoking a callback: ${err}`
        );
        this.callbackException = err;
      } else {
        console.error(`Unexpected item thrown while invoking callback: ${err}`);
      }
    }
  }

  /** Records a response, enqueues it for consumers, then calls `onNext`. */
  handleResponse(response: Message): void {
    this.responses.push(response);
    this.responseQueue.push(response);
    this.invokeCallback(() => this.onNext(response));
  }

  /** Records the final status, ends the stream, then calls `onCompleted`. */
  handleCompletion(status: Status): void {
    this.status = status;
    // Undefined marks the end of the stream for queue consumers.
    this.responseQueue.push(undefined);
    this.invokeCallback(() => this.onCompleted(status));
  }

  /** Records the error status, ends the stream, then calls `onError`. */
  handleError(error: Status): void {
    this.error = error;
    // Undefined marks the end of the stream for queue consumers.
    this.responseQueue.push(undefined);
    this.invokeCallback(() => this.onError(error));
  }

  /**
   * Takes the next item from the response queue, rejecting with `RpcTimeout`
   * if none becomes available within `timeoutMs` milliseconds.
   *
   * If an item shows up after the timeout already fired, it is put back at the
   * front of the queue so that a later reader still sees it.
   */
  private queuePopWithTimeout(
    timeoutMs: number
  ): Promise<Message | undefined> {
    return new Promise<Message | undefined>((resolve, reject) => {
      let timeoutExpired = false;
      const timeoutWatcher = setTimeout(() => {
        timeoutExpired = true;
        reject(new RpcTimeout(this.rpc, timeoutMs));
      }, timeoutMs);

      this.responseQueue.shift().then(
        (response: Message | undefined) => {
          if (timeoutExpired) {
            // The caller has already been rejected; do not lose the item.
            this.responseQueue.unshift(response);
            return;
          }
          clearTimeout(timeoutWatcher);
          resolve(response);
        },
        (err: unknown) => {
          // Surface queue failures instead of leaving the promise pending.
          clearTimeout(timeoutWatcher);
          reject(err);
        }
      );
    });
  }

  /**
   * Yields responses up to the specified count as they are added.
   *
   * Throws an error as soon as it is received even if there are still
   * responses in the queue.
   *
   * Usage
   * ```
   * for await (const response of call.getResponses(5)) {
   *   console.log(response);
   * }
   * ```
   *
   * @param count The number of responses to read before returning.
   *     If no value is specified, getResponses will block until the stream
   *     either ends or hits an error.
   * @param timeoutMs The number of milliseconds to wait for each response
   *     before throwing an error. If omitted, waits indefinitely.
   * @throws RpcError if the call has recorded an error status.
   * @throws RpcTimeout if `timeoutMs` elapses before the next response.
   * @throws The stored `callbackException`, if a user callback threw an Error.
   */
  async *getResponses(
    count?: number,
    timeoutMs?: number
  ): AsyncGenerator<Message> {
    this.checkErrors();

    // Nothing left to deliver for a finished call with a drained queue.
    if (this.completed && this.responseQueue.length === 0) {
      return;
    }

    let remaining = count ?? Number.POSITIVE_INFINITY;
    while (remaining > 0) {
      const response =
        timeoutMs === undefined
          ? await this.responseQueue.shift()
          : await this.queuePopWithTimeout(timeoutMs);

      this.checkErrors();

      // Undefined is the end-of-stream marker.
      if (response === undefined) {
        return;
      }

      yield response;
      remaining -= 1;
    }
  }

  /**
   * Cancels the call if it has not finished yet.
   *
   * Marks the call as failed with CANCELLED, wakes any consumer currently
   * waiting on the response queue (so it observes the error instead of
   * waiting for a response that will never be delivered), and asks `rpcs` to
   * send the cancellation.
   *
   * @returns false if the call had already completed; otherwise whatever
   *     `sendCancel` returns.
   */
  cancel(): boolean {
    if (this.completed) {
      return false;
    }

    this.error = Status.CANCELLED;
    // Undefined marks the end of the stream for queue consumers.
    this.responseQueue.push(undefined);
    return this.rpcs.sendCancel(this.rpc);
  }

  /**
   * Throws the stored callback exception, or an `RpcError` for the stored
   * error status, if either is present. The callback exception wins.
   */
  private checkErrors(): void {
    if (this.callbackException !== undefined) {
      throw this.callbackException;
    }
    if (this.error !== undefined) {
      throw new RpcError(this.rpc, this.error);
    }
  }

  /**
   * Waits for a single response and returns it together with the status.
   *
   * @param timeoutMs Optional per-response timeout in milliseconds.
   * @throws Error if no status was recorded or if the number of recorded
   *     responses is not exactly one.
   */
  protected async unaryWait(timeoutMs?: number): Promise<[Status, Message]> {
    for await (const _response of this.getResponses(1, timeoutMs)) {
      // Responses are already stored in this.responses; just drain.
    }

    if (this.status === undefined) {
      throw Error('Unexpected undefined status at end of stream');
    }
    if (this.responses.length !== 1) {
      throw Error(`Unexpected number of responses: ${this.responses.length}`);
    }
    return [this.status, this.responses[0]];
  }

  /**
   * Waits for the stream to end and returns the status with all responses.
   *
   * @param timeoutMs Optional per-response timeout in milliseconds.
   * @throws Error if the stream ended without a status being recorded.
   */
  protected async streamWait(timeoutMs?: number): Promise<[Status, Message[]]> {
    for await (const _response of this.getResponses(undefined, timeoutMs)) {
      // Responses are already stored in this.responses; just drain.
    }

    if (this.status === undefined) {
      throw Error('Unexpected undefined status at end of stream');
    }
    return [this.status, this.responses];
  }

  /**
   * Sends one client-stream message.
   *
   * @throws RpcError with FAILED_PRECONDITION if the call already completed
   *     with a status, or any error reported by `checkErrors`.
   */
  protected sendClientStream(request: Message) {
    this.checkErrors();
    if (this.status !== undefined) {
      throw new RpcError(this.rpc, Status.FAILED_PRECONDITION);
    }
    this.rpcs.sendClientStream(this.rpc, request);
  }

  /**
   * Sends the given messages, then signals the end of the client stream
   * unless the call has already completed.
   */
  protected finishClientStream(requests: Message[]) {
    for (const request of requests) {
      this.sendClientStream(request);
    }

    if (!this.completed) {
      this.rpcs.sendClientStreamEnd(this.rpc);
    }
  }
}
/** Call state for a unary RPC: one request, one response. */
export class UnaryCall extends Call {
  /**
   * Waits for the server's response.
   *
   * @param timeoutMs Optional timeout in milliseconds, forwarded to
   *     `unaryWait`.
   * @returns The final status paired with the single response message.
   */
  async complete(timeoutMs?: number): Promise<[Status, Message]> {
    const outcome = await this.unaryWait(timeoutMs);
    return outcome;
  }
}
/** Call state for a client streaming RPC. */
export class ClientStreamingCall extends Call {
  /** The most recently received server message, or undefined if none yet. */
  get response(): Message | undefined {
    const received = this.responses.length;
    if (received === 0) {
      return undefined;
    }
    return this.responses[received - 1];
  }

  /** Sends one message on the client stream. */
  send(request: Message) {
    this.sendClientStream(request);
  }

  /**
   * Sends any remaining messages, closes the client stream and waits for the
   * RPC to finish.
   *
   * @param requests Messages to send before closing the stream.
   * @param timeoutMs Optional timeout in milliseconds, forwarded to
   *     `streamWait`.
   * @returns The final status paired with all responses received.
   */
  async finishAndWait(
    requests: Message[] = [],
    timeoutMs?: number
  ): Promise<[Status, Message[]]> {
    this.finishClientStream(requests);
    const outcome = await this.streamWait(timeoutMs);
    return outcome;
  }
}
/** Call state for a server streaming RPC. */
export class ServerStreamingCall extends Call {
  /**
   * Waits for the server stream to end.
   *
   * @param timeoutMs Optional timeout in milliseconds, forwarded to
   *     `streamWait`.
   * @returns A promise for the final status paired with all responses
   *     received.
   */
  complete(timeoutMs?: number): Promise<[Status, Message[]]> {
    const pending = this.streamWait(timeoutMs);
    return pending;
  }
}
/** Call state for a bidirectional streaming RPC. */
export class BidirectionalStreamingCall extends Call {
  /** Sends one message on the client stream. */
  send(request: Message) {
    this.sendClientStream(request);
  }

  /**
   * Sends any remaining messages, closes the client stream and waits for the
   * RPC to finish.
   *
   * @param requests Messages to send before closing the stream.
   * @param timeoutMs Optional timeout in milliseconds, forwarded to
   *     `streamWait`.
   * @returns The final status paired with all responses received.
   */
  async finishAndWait(
    requests: Array<Message> = [],
    timeoutMs?: number
  ): Promise<[Status, Array<Message>]> {
    this.finishClientStream(requests);
    const outcome = await this.streamWait(timeoutMs);
    return outcome;
  }
}