// Copyright 2022 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
package dev.pigweed.pw_transfer;
import com.google.common.util.concurrent.ListenableFuture;
import dev.pigweed.pw_log.Logger;
import dev.pigweed.pw_rpc.Call;
import dev.pigweed.pw_rpc.ChannelOutputException;
import dev.pigweed.pw_rpc.MethodClient;
import dev.pigweed.pw_rpc.RpcError;
import dev.pigweed.pw_rpc.Status;
import dev.pigweed.pw_rpc.StreamObserver;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.time.temporal.TemporalUnit;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import javax.annotation.Nullable;
/** Manages the active transfers and dispatches events to them. */
class TransferEventHandler {
private static final Logger logger = Logger.forClass(TransferEventHandler.class);
// Instant and BlockingQueue use different time unit types.
private static final TemporalUnit TIME_UNIT = ChronoUnit.MICROS;
private static final TimeUnit POLL_TIME_UNIT = TimeUnit.MICROSECONDS;
private final MethodClient readMethod;
private final MethodClient writeMethod;
private final BlockingQueue<Event> events = new LinkedBlockingQueue<>();
// Maps resource ID to transfer
private final Map<Integer, Transfer<?>> transfers = new HashMap<>();
@Nullable private Call.ClientStreaming<Chunk> readStream = null;
@Nullable private Call.ClientStreaming<Chunk> writeStream = null;
private boolean processEvents = true;
TransferEventHandler(MethodClient readMethod, MethodClient writeMethod) {
this.readMethod = readMethod;
this.writeMethod = writeMethod;
}
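// Usage sketch (not part of the original file): run() blocks until stop() is
// called, so the handler is expected to be driven from a dedicated thread. The
// thread setup below is an illustrative assumption.
//
//   TransferEventHandler handler = new TransferEventHandler(readMethod, writeMethod);
//   Thread eventThread = new Thread(handler::run, "pw_transfer-events");
//   eventThread.start();
//   // ... start read/write transfers from any thread (see the examples below) ...
//   handler.stop();        // Enqueues a termination event; run() returns once it is handled.
//   eventThread.join();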
ListenableFuture<Void> startWriteTransferAsClient(int resourceId,
int transferTimeoutMillis,
int initialTransferTimeoutMillis,
int maxRetries,
byte[] data,
Consumer<TransferProgress> progressCallback,
BooleanSupplier shouldAbortCallback) {
WriteTransfer transfer = new WriteTransfer(resourceId,
new TransferInterface() {
@Override
Call.ClientStreaming<Chunk> getStream() throws ChannelOutputException {
if (writeStream == null) {
writeStream = writeMethod.invokeBidirectionalStreaming(new ChunkHandler() {
@Override
void resetStream() {
writeStream = null;
}
});
}
return writeStream;
}
},
transferTimeoutMillis,
initialTransferTimeoutMillis,
maxRetries,
data,
progressCallback,
shouldAbortCallback);
startTransferAsClient(transfer);
return transfer.getFuture();
}
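// Example call (sketch): the resource ID, timeout, and retry values below are
// illustrative assumptions, not defaults defined by this class.
//
//   byte[] dataToSend = ...;  // Payload for the write transfer.
//   ListenableFuture<Void> done = handler.startWriteTransferAsClient(
//       /* resourceId= */ 3,
//       /* transferTimeoutMillis= */ 2000,
//       /* initialTransferTimeoutMillis= */ 4000,
//       /* maxRetries= */ 3,
//       dataToSend,
//       progress -> logger.atFine().log("Write progress: %s", progress),
//       () -> false /* never abort */);
//   // On failure, the future is expected to complete with a TransferError.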
ListenableFuture<byte[]> startReadTransferAsClient(int resourceId,
int transferTimeoutMillis,
int initialTransferTimeoutMillis,
int maxRetries,
TransferParameters parameters,
Consumer<TransferProgress> progressCallback,
BooleanSupplier shouldAbortCallback) {
ReadTransfer transfer = new ReadTransfer(resourceId,
new TransferInterface() {
@Override
Call.ClientStreaming<Chunk> getStream() throws ChannelOutputException {
if (readStream == null) {
readStream = readMethod.invokeBidirectionalStreaming(new ChunkHandler() {
@Override
void resetStream() {
readStream = null;
}
});
}
return readStream;
}
},
transferTimeoutMillis,
initialTransferTimeoutMillis,
maxRetries,
parameters,
progressCallback,
shouldAbortCallback);
startTransferAsClient(transfer);
return transfer.getFuture();
}
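// Example call (sketch): mirrors the write case above. How TransferParameters
// is constructed is outside this file, so it is left as a placeholder.
//
//   TransferParameters parameters = ...;  // Window/chunk sizing for the read.
//   ListenableFuture<byte[]> received = handler.startReadTransferAsClient(
//       /* resourceId= */ 3,
//       /* transferTimeoutMillis= */ 2000,
//       /* initialTransferTimeoutMillis= */ 4000,
//       /* maxRetries= */ 3,
//       parameters,
//       progress -> {},
//       () -> false);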
private void startTransferAsClient(Transfer<?> transfer) {
enqueueEvent(() -> {
if (transfers.put(transfer.getResourceId(), transfer) != null) {
transfer.cleanUp(new TransferError("A transfer for resource ID " + transfer.getResourceId()
+ " is already in progress! Only one read/write transfer per resource is supported at a time",
Status.ALREADY_EXISTS));
return;
}
transfer.start();
});
}
/** Handles events until stop() is called. */
void run() {
while (processEvents) {
handleNextEvent();
}
}
/** Stops the transfer event handler from processing events. */
void stop() {
enqueueEvent(() -> {
logger.atFine().log("Terminating TransferEventHandler");
transfers.values().forEach(Transfer::handleTermination);
processEvents = false;
});
}
/** Blocks until all events currently in the queue are processed; for test use only. */
void waitUntilEventsAreProcessedForTest() {
Semaphore semaphore = new Semaphore(0);
enqueueEvent(semaphore::release);
try {
semaphore.acquire();
} catch (InterruptedException e) {
throw new AssertionError("Unexpectedly interrupted", e);
}
}
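// Test-only usage sketch: after enqueueing work, a test can synchronize with the
// event thread before asserting. The surrounding test structure is assumed.
//
//   handler.startReadTransferAsClient(...);
//   handler.waitUntilEventsAreProcessedForTest();
//   // The start event has now been handled on the event thread.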
private void enqueueEvent(Event event) {
while (true) {
try {
events.put(event);
return;
} catch (InterruptedException e) {
// Ignore and keep trying.
}
}
}
private void handleNextEvent() {
// Wait for the next event, but no longer than the time until the earliest transfer deadline.
final long sleepFor = Math.max(0, Instant.now().until(getNextTimeout(), TIME_UNIT));
try {
Event event = events.poll(sleepFor, POLL_TIME_UNIT);
if (event != null) {
event.handle();
}
} catch (InterruptedException e) {
// If interrupted, check for timeouts anyway.
}
for (Transfer<?> transfer : transfers.values()) {
transfer.handleTimeoutIfDeadlineExceeded();
}
}
private Instant getNextTimeout() {
Optional<Transfer<?>> transfer =
transfers.values().stream().min(Comparator.comparing(Transfer::getDeadline));
return transfer.isPresent() ? transfer.get().getDeadline() : Transfer.NO_TIMEOUT;
}
/** This interface gives a Transfer access to the TransferEventHandler. */
abstract class TransferInterface {
private TransferInterface() {}
/**
* Sends the provided transfer chunk.
*
* Must be called on the transfer thread.
*/
void sendChunk(Chunk chunk) throws TransferError {
try {
getStream().send(chunk);
} catch (ChannelOutputException | RpcError e) {
throw new TransferError("Failed to send chunk for write transfer", e);
}
}
/**
* Removes this transfer from the list of active transfers.
*
* Must be called on the transfer thread.
*/
void unregisterTransfer(int resourceId) {
transfers.remove(resourceId);
}
/**
* Initiates the cancellation process for the provided transfer.
*
* May be called from any thread.
*/
void cancelTransfer(Transfer<?> transfer) {
enqueueEvent(transfer::handleCancellation);
}
/** Gets either the read or write stream. */
abstract Call.ClientStreaming<Chunk> getStream() throws ChannelOutputException;
}
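// Sketch of how a Transfer implementation (defined elsewhere) is expected to use
// this interface; the Transfer fields and methods shown are assumptions.
//
//   // On the transfer thread, e.g. while handling a chunk or a timeout:
//   eventHandler.sendChunk(nextChunk);
//   // On the transfer thread, once the transfer finishes or is cleaned up:
//   eventHandler.unregisterTransfer(resourceId);
//   // From any thread, e.g. when the returned future is cancelled:
//   eventHandler.cancelTransfer(this);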
/** Handles responses on the pw_transfer RPCs. */
private abstract class ChunkHandler implements StreamObserver<Chunk> {
@Override
public final void onNext(Chunk chunk) {
enqueueEvent(() -> {
Transfer<?> transfer = transfers.get(chunk.getTransferId());
if (transfer != null) {
logger.atFinest().log(
"Transfer %d received chunk: %s", transfer.getSessionId(), chunkToString(chunk));
transfer.handleChunk(chunk);
} else {
logger.atWarning().log(
"Ignoring unrecognized transfer session ID %d", chunk.getTransferId());
}
});
}
@Override
public final void onCompleted(Status status) {
onError(Status.INTERNAL); // This RPC should never complete: treat as an internal error.
}
@Override
public final void onError(Status status) {
enqueueEvent(() -> {
resetStream();
// The transfers remove themselves from the map during cleanup, so iterate over a copied list.
List<Transfer<?>> activeTransfers = new ArrayList<>(transfers.values());
// FAILED_PRECONDITION indicates that the stream packet was not recognized as the stream is
// not open. This could occur if the server resets. Notify pending transfers that this has
// occurred so they can restart.
if (status.equals(Status.FAILED_PRECONDITION)) {
activeTransfers.forEach(Transfer::handleDisconnection);
} else {
TransferError error = new TransferError(
"Transfer stream RPC closed unexpectedly with status " + status, Status.INTERNAL);
activeTransfers.forEach(t -> t.cleanUp(error));
}
});
}
abstract void resetStream();
}
private static String chunkToString(Chunk chunk) {
StringBuilder str = new StringBuilder();
str.append("transferId:").append(chunk.getTransferId()).append(" ");
str.append("windowEndOffset:").append(chunk.getWindowEndOffset()).append(" ");
str.append("offset:").append(chunk.getOffset()).append(" ");
// Don't include the actual data; it's too much.
str.append("len(data):").append(chunk.getData().size()).append(" ");
if (chunk.hasPendingBytes()) {
str.append("pendingBytes:").append(chunk.getPendingBytes()).append(" ");
}
if (chunk.hasMaxChunkSizeBytes()) {
str.append("maxChunkSizeBytes:").append(chunk.getMaxChunkSizeBytes()).append(" ");
}
if (chunk.hasMinDelayMicroseconds()) {
str.append("minDelayMicroseconds:").append(chunk.getMinDelayMicroseconds()).append(" ");
}
if (chunk.hasRemainingBytes()) {
str.append("remainingBytes:").append(chunk.getRemainingBytes()).append(" ");
}
if (chunk.hasStatus()) {
str.append("status:").append(chunk.getStatus()).append(" ");
}
if (chunk.hasType()) {
str.append("type:").append(chunk.getTypeValue()).append(" ");
}
if (chunk.hasResourceId()) {
str.append("resourceId:").append(chunk.getSessionId()).append(" ");
}
if (chunk.hasSessionId()) {
str.append("sessionId:").append(chunk.getSessionId()).append(" ");
}
return str.toString();
}
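// Example of the resulting log string (values are illustrative):
//   "transferId:3 windowEndOffset:1024 offset:0 len(data):0 pendingBytes:1024 maxChunkSizeBytes:128 "
// Optional fields are appended only when present, and each field is followed by a space.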
// Represents an event that occurs during a transfer
private interface Event {
void handle();
}
}