blob: f4e8189f760e67d50e1a85f0d334065b161f9f51 [file] [log] [blame]
// Copyright 2022 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
package dev.pigweed.pw_transfer;
import com.google.common.util.concurrent.ListenableFuture;
import dev.pigweed.pw_log.Logger;
import dev.pigweed.pw_rpc.Call;
import dev.pigweed.pw_rpc.ChannelOutputException;
import dev.pigweed.pw_rpc.MethodClient;
import dev.pigweed.pw_rpc.Status;
import dev.pigweed.pw_rpc.StreamObserver;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.time.temporal.TemporalUnit;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import javax.annotation.Nullable;
/** Manages the active transfers and dispatches events to them. */
class TransferEventHandler {
private static final Logger logger = Logger.forClass(TransferEventHandler.class);
// Instant and BlockingQueue use different time unit types.
private static final TemporalUnit TIME_UNIT = ChronoUnit.MICROS;
private static final TimeUnit POLL_TIME_UNIT = TimeUnit.MICROSECONDS;
private final MethodClient readMethod;
private final MethodClient writeMethod;
private final BlockingQueue<Event> events = new LinkedBlockingQueue<>();
// Map resource ID to transfer, and session ID to resource ID.
private final Map<Integer, Transfer<?>> resourceIdToTransfer = new HashMap<>();
private final Map<Integer, Integer> sessionToResourceId = new HashMap<>();
@Nullable private Call.ClientStreaming<Chunk> readStream = null;
@Nullable private Call.ClientStreaming<Chunk> writeStream = null;
private boolean processEvents = true;
TransferEventHandler(MethodClient readMethod, MethodClient writeMethod) {
this.readMethod = readMethod;
this.writeMethod = writeMethod;
}
ListenableFuture<Void> startWriteTransferAsClient(int resourceId,
ProtocolVersion desiredProtocolVersion,
TransferTimeoutSettings settings,
byte[] data,
Consumer<TransferProgress> progressCallback,
BooleanSupplier shouldAbortCallback) {
WriteTransfer transfer =
new WriteTransfer(resourceId, desiredProtocolVersion, new TransferInterface() {
@Override
Call.ClientStreaming<Chunk> getStream() throws ChannelOutputException {
if (writeStream == null) {
writeStream = writeMethod.invokeBidirectionalStreaming(new ChunkHandler() {
@Override
void resetStream() {
writeStream = null;
}
});
}
return writeStream;
}
}, settings, data, progressCallback, shouldAbortCallback);
startTransferAsClient(transfer);
return transfer.getFuture();
}
ListenableFuture<byte[]> startReadTransferAsClient(int resourceId,
ProtocolVersion desiredProtocolVersion,
TransferTimeoutSettings settings,
TransferParameters parameters,
Consumer<TransferProgress> progressCallback,
BooleanSupplier shouldAbortCallback) {
ReadTransfer transfer =
new ReadTransfer(resourceId, desiredProtocolVersion, new TransferInterface() {
@Override
Call.ClientStreaming<Chunk> getStream() throws ChannelOutputException {
if (readStream == null) {
readStream = readMethod.invokeBidirectionalStreaming(new ChunkHandler() {
@Override
void resetStream() {
readStream = null;
}
});
}
return readStream;
}
}, settings, parameters, progressCallback, shouldAbortCallback);
startTransferAsClient(transfer);
return transfer.getFuture();
}
private void startTransferAsClient(Transfer<?> transfer) {
enqueueEvent(() -> {
if (resourceIdToTransfer.containsKey(transfer.getResourceId())) {
transfer.terminate(new TransferError("A transfer for resource ID "
+ transfer.getResourceId()
+ " is already in progress! Only one read/write transfer per resource is supported at a time",
Status.ALREADY_EXISTS));
return;
}
resourceIdToTransfer.put(transfer.getResourceId(), transfer);
transfer.start();
});
}
/** Handles events until stop() is called. */
void run() {
while (processEvents) {
handleNextEvent();
handleTimeouts();
}
}
/**
* Test version of run() that processes all enqueued events before checking for timeouts.
*
* Tests that need to time out should process all enqueued events first to prevent flaky failures.
* If handling one of several queued packets takes longer than the timeout (which must be short
* for a unit test), then the test may fail spuriously.
*
* This run function is not used outside of tests because processing all incoming packets before
* checking for timeouts could delay the transfer client's outgoing write packets if there are
* lots of inbound packets. This could delay transfers and cause unnecessary timeouts.
*/
void runForTestsThatMustTimeOut() {
while (processEvents) {
while (!events.isEmpty()) {
handleNextEvent();
}
handleTimeouts();
}
}
/** Stops the transfer event handler from processing events. */
void stop() {
enqueueEvent(() -> {
logger.atFine().log("Terminating TransferEventHandler");
resourceIdToTransfer.values().forEach(Transfer::handleTermination);
processEvents = false;
});
}
/** Blocks until all events currently in the queue are processed; for test use only. */
void waitUntilEventsAreProcessedForTest() {
Semaphore semaphore = new Semaphore(0);
enqueueEvent(semaphore::release);
try {
semaphore.acquire();
} catch (InterruptedException e) {
throw new AssertionError("Unexpectedly interrupted", e);
}
}
private void enqueueEvent(Event event) {
while (true) {
try {
events.put(event);
return;
} catch (InterruptedException e) {
// Ignore and keep trying.
}
}
}
private void handleNextEvent() {
final long sleepFor = TIME_UNIT.between(Instant.now(), getNextTimeout());
try {
Event event = events.poll(sleepFor, POLL_TIME_UNIT);
if (event != null) {
event.handle();
}
} catch (InterruptedException e) {
// If interrupted, continue around the loop.
}
}
private void handleTimeouts() {
for (Transfer<?> transfer : resourceIdToTransfer.values()) {
transfer.handleTimeoutIfDeadlineExceeded();
}
}
private Instant getNextTimeout() {
Optional<Transfer<?>> transfer =
resourceIdToTransfer.values().stream().min(Comparator.comparing(Transfer::getDeadline));
return transfer.isPresent() ? transfer.get().getDeadline() : Transfer.NO_TIMEOUT;
}
/** This interface gives a Transfer access to the TransferEventHandler. */
abstract class TransferInterface {
private TransferInterface() {}
/**
* Sends the provided transfer chunk.
*
* Must be called on the transfer thread.
*/
void sendChunk(Chunk chunk) throws TransferError {
try {
getStream().write(chunk);
} catch (ChannelOutputException e) {
throw new TransferError("Failed to send chunk for write transfer", e);
}
}
/**
* Associates the transfer's session ID with its resource ID.
*
* Must be called on the transfer thread.
*/
void assignSessionId(Transfer<?> transfer) {
sessionToResourceId.put(transfer.getSessionId(), transfer.getResourceId());
}
/**
* Removes this transfer from the list of active transfers.
*
* Must be called on the transfer thread.
*/
void unregisterTransfer(Transfer<?> transfer) {
resourceIdToTransfer.remove(transfer.getResourceId());
sessionToResourceId.remove(transfer.getSessionId());
}
/**
* Initiates the cancellation process for the provided transfer.
*
* May be called from any thread.
*/
void cancelTransfer(Transfer<?> transfer) {
enqueueEvent(transfer::handleCancellation);
}
/** Gets either the read or write stream. */
abstract Call.ClientStreaming<Chunk> getStream() throws ChannelOutputException;
}
/** Handles responses on the pw_transfer RPCs. */
private abstract class ChunkHandler implements StreamObserver<Chunk> {
@Override
public final void onNext(Chunk chunkProto) {
VersionedChunk chunk = VersionedChunk.fromMessage(chunkProto);
enqueueEvent(() -> {
Transfer<?> transfer = null;
if (chunk.resourceId().isPresent()) {
transfer = resourceIdToTransfer.get(chunk.resourceId().getAsInt());
} else {
Integer resourceId = sessionToResourceId.get(chunk.sessionId());
if (resourceId != null) {
transfer = resourceIdToTransfer.get(resourceId);
}
}
if (transfer != null) {
logger.atFinest().log("%s received chunk: %s", transfer, chunk);
transfer.handleChunk(chunk);
} else {
logger.atInfo().log("Ignoring unrecognized transfer chunk: %s", chunk);
}
});
}
@Override
public final void onCompleted(Status status) {
onError(Status.INTERNAL); // This RPC should never complete: treat as an internal error.
}
@Override
public final void onError(Status status) {
enqueueEvent(() -> {
resetStream();
// The transfers remove themselves from the Map during cleanup, iterate over a copied list.
List<Transfer<?>> activeTransfers = new ArrayList<>(resourceIdToTransfer.values());
// FAILED_PRECONDITION indicates that the stream packet was not recognized as the stream is
// not open. This could occur if the server resets. Notify pending transfers that this has
// occurred so they can restart.
if (status.equals(Status.FAILED_PRECONDITION)) {
activeTransfers.forEach(Transfer::handleDisconnection);
} else {
TransferError error = new TransferError(
"Transfer stream RPC closed unexpectedly with status " + status, Status.INTERNAL);
activeTransfers.forEach(t -> t.terminate(error));
}
});
}
abstract void resetStream();
}
// Represents an event that occurs during a transfer
private interface Event {
void handle();
}
}