| // Copyright 2022 The Pigweed Authors |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); you may not |
| // use this file except in compliance with the License. You may obtain a copy of |
| // the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| // License for the specific language governing permissions and limitations under |
| // the License. |
| |
| #define PW_LOG_MODULE_NAME "TRN" |
| |
| #include "pw_transfer/transfer_thread.h" |
| |
| #include "pw_assert/check.h" |
| #include "pw_log/log.h" |
| #include "pw_transfer/internal/chunk.h" |
| |
| PW_MODIFY_DIAGNOSTICS_PUSH(); |
| PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers"); |
| |
| namespace pw::transfer::internal { |
| |
| void TransferThread::Terminate() { |
| next_event_ownership_.acquire(); |
| next_event_.type = EventType::kTerminate; |
| event_notification_.release(); |
| } |
| |
| void TransferThread::SimulateTimeout(EventType type, uint32_t session_id) { |
| next_event_ownership_.acquire(); |
| |
| next_event_.type = type; |
| next_event_.chunk = {}; |
| next_event_.chunk.context_identifier = session_id; |
| |
| event_notification_.release(); |
| |
| WaitUntilEventIsProcessed(); |
| } |
| |
| void TransferThread::Run() { |
| // Next event starts freed. |
| next_event_ownership_.release(); |
| |
| while (true) { |
| if (event_notification_.try_acquire_until(GetNextTransferTimeout())) { |
| HandleEvent(next_event_); |
| |
| // Sample event type before we release ownership of next_event_. |
| bool is_terminating = next_event_.type == EventType::kTerminate; |
| |
| // Finished processing the event. Allow the next_event struct to be |
| // overwritten. |
| next_event_ownership_.release(); |
| |
| if (is_terminating) { |
| return; |
| } |
| } |
| |
| // Regardless of whether an event was received or not, check for any |
| // transfers which have timed out and process them if so. |
| for (Context& context : client_transfers_) { |
| if (context.timed_out()) { |
| context.HandleEvent({.type = EventType::kClientTimeout}); |
| } |
| } |
| for (Context& context : server_transfers_) { |
| if (context.timed_out()) { |
| context.HandleEvent({.type = EventType::kServerTimeout}); |
| } |
| } |
| } |
| } |
| |
| chrono::SystemClock::time_point TransferThread::GetNextTransferTimeout() const { |
| chrono::SystemClock::time_point timeout = |
| chrono::SystemClock::TimePointAfterAtLeast(kMaxTimeout); |
| |
| for (Context& context : client_transfers_) { |
| auto ctx_timeout = context.timeout(); |
| if (ctx_timeout.has_value() && ctx_timeout.value() < timeout) { |
| timeout = ctx_timeout.value(); |
| } |
| } |
| for (Context& context : server_transfers_) { |
| auto ctx_timeout = context.timeout(); |
| if (ctx_timeout.has_value() && ctx_timeout.value() < timeout) { |
| timeout = ctx_timeout.value(); |
| } |
| } |
| |
| return timeout; |
| } |
| |
| void TransferThread::StartTransfer(TransferType type, |
| ProtocolVersion version, |
| uint32_t session_id, |
| uint32_t resource_id, |
| ConstByteSpan raw_chunk, |
| stream::Stream* stream, |
| const TransferParameters& max_parameters, |
| Function<void(Status)>&& on_completion, |
| chrono::SystemClock::duration timeout, |
| uint8_t max_retries) { |
| // Block until the last event has been processed. |
| next_event_ownership_.acquire(); |
| |
| bool is_client_transfer = stream != nullptr; |
| |
| next_event_.type = is_client_transfer ? EventType::kNewClientTransfer |
| : EventType::kNewServerTransfer; |
| |
| if (!raw_chunk.empty()) { |
| std::memcpy(chunk_buffer_.data(), raw_chunk.data(), raw_chunk.size()); |
| } |
| |
| next_event_.new_transfer = { |
| .type = type, |
| .protocol_version = version, |
| .session_id = session_id, |
| .resource_id = resource_id, |
| .max_parameters = &max_parameters, |
| .timeout = timeout, |
| .max_retries = max_retries, |
| .transfer_thread = this, |
| .raw_chunk_data = chunk_buffer_.data(), |
| .raw_chunk_size = raw_chunk.size(), |
| }; |
| |
| staged_on_completion_ = std::move(on_completion); |
| |
| // The transfer is initialized with either a stream (client-side) or a handler |
| // (server-side). If no stream is provided, try to find a registered handler |
| // with the specified ID. |
| if (is_client_transfer) { |
| next_event_.new_transfer.stream = stream; |
| next_event_.new_transfer.rpc_writer = &static_cast<rpc::Writer&>( |
| type == TransferType::kTransmit ? client_write_stream_ |
| : client_read_stream_); |
| } else { |
| auto handler = std::find_if(handlers_.begin(), |
| handlers_.end(), |
| [&](auto& h) { return h.id() == resource_id; }); |
| if (handler != handlers_.end()) { |
| next_event_.new_transfer.handler = &*handler; |
| next_event_.new_transfer.rpc_writer = &static_cast<rpc::Writer&>( |
| type == TransferType::kTransmit ? server_read_stream_ |
| : server_write_stream_); |
| } else { |
| // No handler exists for the transfer: return a NOT_FOUND. |
| next_event_.type = EventType::kSendStatusChunk; |
| next_event_.send_status_chunk = { |
| .session_id = session_id, |
| .protocol_version = version, |
| .status = Status::NotFound().code(), |
| .stream = type == TransferType::kTransmit |
| ? TransferStream::kServerRead |
| : TransferStream::kServerWrite, |
| }; |
| } |
| } |
| |
| event_notification_.release(); |
| } |
| |
| void TransferThread::ProcessChunk(EventType type, ConstByteSpan chunk) { |
| // If this assert is hit, there is a bug in the transfer implementation. |
| // Contexts' max_chunk_size_bytes fields should be set based on the size of |
| // chunk_buffer_. |
| PW_CHECK(chunk.size() <= chunk_buffer_.size(), |
| "Transfer received a larger chunk than it can handle."); |
| |
| Result<uint32_t> identifier = Chunk::ExtractIdentifier(chunk); |
| if (!identifier.ok()) { |
| PW_LOG_ERROR("Received a malformed chunk without a context identifier"); |
| return; |
| } |
| |
| // Block until the last event has been processed. |
| next_event_ownership_.acquire(); |
| |
| std::memcpy(chunk_buffer_.data(), chunk.data(), chunk.size()); |
| |
| next_event_.type = type; |
| next_event_.chunk = { |
| .context_identifier = *identifier, |
| .data = chunk_buffer_.data(), |
| .size = chunk.size(), |
| }; |
| |
| event_notification_.release(); |
| } |
| |
| void TransferThread::EndTransfer(EventType type, |
| uint32_t session_id, |
| Status status, |
| bool send_status_chunk) { |
| // Block until the last event has been processed. |
| next_event_ownership_.acquire(); |
| |
| next_event_.type = type; |
| next_event_.end_transfer = { |
| .session_id = session_id, |
| .status = status.code(), |
| .send_status_chunk = send_status_chunk, |
| }; |
| |
| event_notification_.release(); |
| } |
| |
| void TransferThread::SetClientStream(TransferStream type, |
| rpc::RawClientReaderWriter& stream) { |
| // Block until the last event has been processed. |
| next_event_ownership_.acquire(); |
| |
| next_event_.type = EventType::kSetTransferStream; |
| next_event_.set_transfer_stream = type; |
| staged_client_stream_ = std::move(stream); |
| |
| event_notification_.release(); |
| } |
| |
| void TransferThread::SetServerStream(TransferStream type, |
| rpc::RawServerReaderWriter& stream) { |
| // Block until the last event has been processed. |
| next_event_ownership_.acquire(); |
| |
| next_event_.type = EventType::kSetTransferStream; |
| next_event_.set_transfer_stream = type; |
| staged_server_stream_ = std::move(stream); |
| |
| event_notification_.release(); |
| } |
| |
| void TransferThread::TransferHandlerEvent(EventType type, Handler& handler) { |
| // Block until the last event has been processed. |
| next_event_ownership_.acquire(); |
| |
| next_event_.type = type; |
| if (type == EventType::kAddTransferHandler) { |
| next_event_.add_transfer_handler = &handler; |
| } else { |
| next_event_.remove_transfer_handler = &handler; |
| } |
| |
| event_notification_.release(); |
| } |
| |
| void TransferThread::HandleEvent(const internal::Event& event) { |
| switch (event.type) { |
| case EventType::kTerminate: |
| // Terminate server contexts. |
| for (ServerContext& server_context : server_transfers_) { |
| server_context.HandleEvent({ |
| .type = EventType::kServerEndTransfer, |
| .end_transfer = |
| { |
| .session_id = server_context.session_id(), |
| .status = Status::Aborted().code(), |
| .send_status_chunk = false, |
| }, |
| }); |
| } |
| |
| // Terminate client contexts. |
| for (ClientContext& client_context : client_transfers_) { |
| client_context.HandleEvent({ |
| .type = EventType::kClientEndTransfer, |
| .end_transfer = |
| { |
| .session_id = client_context.session_id(), |
| .status = Status::Aborted().code(), |
| .send_status_chunk = false, |
| }, |
| }); |
| } |
| |
| // Cancel/Finish streams. |
| client_read_stream_.Cancel().IgnoreError(); |
| client_write_stream_.Cancel().IgnoreError(); |
| server_read_stream_.Finish(Status::Aborted()).IgnoreError(); |
| server_write_stream_.Finish(Status::Aborted()).IgnoreError(); |
| return; |
| |
| case EventType::kSendStatusChunk: |
| SendStatusChunk(event.send_status_chunk); |
| break; |
| |
| case EventType::kSetTransferStream: |
| switch (event.set_transfer_stream) { |
| case TransferStream::kClientRead: |
| client_read_stream_ = std::move(staged_client_stream_); |
| break; |
| |
| case TransferStream::kClientWrite: |
| client_write_stream_ = std::move(staged_client_stream_); |
| break; |
| |
| case TransferStream::kServerRead: |
| server_read_stream_ = std::move(staged_server_stream_); |
| break; |
| |
| case TransferStream::kServerWrite: |
| server_write_stream_ = std::move(staged_server_stream_); |
| break; |
| } |
| return; |
| |
| case EventType::kAddTransferHandler: |
| handlers_.push_front(*event.add_transfer_handler); |
| return; |
| |
| case EventType::kRemoveTransferHandler: |
| for (ServerContext& server_context : server_transfers_) { |
| if (server_context.handler() == event.remove_transfer_handler) { |
| server_context.HandleEvent({ |
| .type = EventType::kServerEndTransfer, |
| .end_transfer = |
| { |
| .session_id = server_context.session_id(), |
| .status = Status::Aborted().code(), |
| .send_status_chunk = false, |
| }, |
| }); |
| } |
| } |
| handlers_.remove(*event.remove_transfer_handler); |
| return; |
| |
| case EventType::kNewClientTransfer: |
| case EventType::kNewServerTransfer: |
| case EventType::kClientChunk: |
| case EventType::kServerChunk: |
| case EventType::kClientTimeout: |
| case EventType::kServerTimeout: |
| case EventType::kClientEndTransfer: |
| case EventType::kServerEndTransfer: |
| default: |
| // Other events are handled by individual transfer contexts. |
| break; |
| } |
| |
| Context* ctx = FindContextForEvent(event); |
| if (ctx == nullptr) { |
| // No context was found. For new transfer events, report a |
| // RESOURCE_EXHAUSTED error with starting the transfer. |
| if (event.type == EventType::kNewClientTransfer) { |
| // On the client, invoke the completion callback directly. |
| staged_on_completion_(Status::ResourceExhausted()); |
| } else if (event.type == EventType::kNewServerTransfer) { |
| // On the server, send a status chunk back to the client. |
| SendStatusChunk( |
| {.session_id = event.new_transfer.session_id, |
| .protocol_version = event.new_transfer.protocol_version, |
| .status = Status::ResourceExhausted().code(), |
| .stream = event.new_transfer.type == TransferType::kTransmit |
| ? TransferStream::kServerRead |
| : TransferStream::kServerWrite}); |
| } |
| return; |
| } |
| |
| if (event.type == EventType::kNewClientTransfer) { |
| // TODO(frolv): This is terrible. |
| static_cast<ClientContext*>(ctx)->set_on_completion( |
| std::move(staged_on_completion_)); |
| } |
| |
| ctx->HandleEvent(event); |
| } |
| |
| Context* TransferThread::FindContextForEvent( |
| const internal::Event& event) const { |
| switch (event.type) { |
| case EventType::kNewClientTransfer: |
| return FindNewTransfer(client_transfers_, event.new_transfer.session_id); |
| case EventType::kNewServerTransfer: |
| return FindNewTransfer(server_transfers_, event.new_transfer.session_id); |
| |
| case EventType::kClientChunk: |
| return FindActiveTransfer(client_transfers_, |
| event.chunk.context_identifier); |
| case EventType::kServerChunk: |
| return FindActiveTransfer(server_transfers_, |
| event.chunk.context_identifier); |
| |
| case EventType::kClientTimeout: // Manually triggered client timeout |
| return FindActiveTransfer(client_transfers_, |
| event.chunk.context_identifier); |
| case EventType::kServerTimeout: // Manually triggered server timeout |
| return FindActiveTransfer(server_transfers_, |
| event.chunk.context_identifier); |
| |
| case EventType::kClientEndTransfer: |
| return FindActiveTransfer(client_transfers_, |
| event.end_transfer.session_id); |
| case EventType::kServerEndTransfer: |
| return FindActiveTransfer(server_transfers_, |
| event.end_transfer.session_id); |
| |
| case EventType::kSendStatusChunk: |
| case EventType::kSetTransferStream: |
| case EventType::kAddTransferHandler: |
| case EventType::kRemoveTransferHandler: |
| case EventType::kTerminate: |
| default: |
| return nullptr; |
| } |
| } |
| |
| void TransferThread::SendStatusChunk( |
| const internal::SendStatusChunkEvent& event) { |
| rpc::Writer& destination = stream_for(event.stream); |
| |
| Result<ConstByteSpan> result = |
| Chunk::Final(event.protocol_version, event.session_id, event.status) |
| .Encode(chunk_buffer_); |
| |
| if (!result.ok()) { |
| PW_LOG_ERROR("Failed to encode final chunk for transfer %u", |
| static_cast<unsigned>(event.session_id)); |
| return; |
| } |
| |
| if (!destination.Write(result.value()).ok()) { |
| PW_LOG_ERROR("Failed to send final chunk for transfer %u", |
| static_cast<unsigned>(event.session_id)); |
| return; |
| } |
| } |
| |
| } // namespace pw::transfer::internal |
| |
| PW_MODIFY_DIAGNOSTICS_POP(); |