| // Copyright 2022 The Pigweed Authors |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); you may not |
| // use this file except in compliance with the License. You may obtain a copy of |
| // the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| // License for the specific language governing permissions and limitations under |
| // the License. |
| #pragma once |
| |
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <utility>

#include "pw_assert/assert.h"
#include "pw_chrono/system_clock.h"
#include "pw_function/function.h"
#include "pw_preprocessor/compiler.h"
#include "pw_rpc/raw/client_reader_writer.h"
#include "pw_rpc/raw/server_reader_writer.h"
#include "pw_span/span.h"
#include "pw_sync/binary_semaphore.h"
#include "pw_sync/timed_thread_notification.h"
#include "pw_thread/thread_core.h"
#include "pw_transfer/handler.h"
#include "pw_transfer/internal/client_context.h"
#include "pw_transfer/internal/context.h"
#include "pw_transfer/internal/event.h"
#include "pw_transfer/internal/server_context.h"
| |
| namespace pw::transfer { |
| namespace internal { |
| |
| class TransferThread : public thread::ThreadCore { |
| public: |
| TransferThread(span<ClientContext> client_transfers, |
| span<ServerContext> server_transfers, |
| ByteSpan chunk_buffer, |
| ByteSpan encode_buffer) |
| : client_transfers_(client_transfers), |
| server_transfers_(server_transfers), |
| chunk_buffer_(chunk_buffer), |
| encode_buffer_(encode_buffer) {} |
| |
| void StartClientTransfer(TransferType type, |
| ProtocolVersion version, |
| uint32_t resource_id, |
| stream::Stream* stream, |
| const TransferParameters& max_parameters, |
| Function<void(Status)>&& on_completion, |
| chrono::SystemClock::duration timeout, |
| uint8_t max_retries) { |
| uint32_t session_id = version == ProtocolVersion::kLegacy |
| ? resource_id |
| : Context::kUnassignedSessionId; |
| StartTransfer(type, |
| version, |
| session_id, |
| resource_id, |
| /*raw_chunk=*/{}, |
| stream, |
| max_parameters, |
| std::move(on_completion), |
| timeout, |
| max_retries); |
| } |
| |
| void StartServerTransfer(TransferType type, |
| ProtocolVersion version, |
| uint32_t session_id, |
| uint32_t resource_id, |
| ConstByteSpan raw_chunk, |
| const TransferParameters& max_parameters, |
| chrono::SystemClock::duration timeout, |
| uint8_t max_retries) { |
| StartTransfer(type, |
| version, |
| session_id, |
| resource_id, |
| raw_chunk, |
| /*stream=*/nullptr, |
| max_parameters, |
| /*on_completion=*/nullptr, |
| timeout, |
| max_retries); |
| } |
| |
| void ProcessClientChunk(ConstByteSpan chunk) { |
| ProcessChunk(EventType::kClientChunk, chunk); |
| } |
| |
| void ProcessServerChunk(ConstByteSpan chunk) { |
| ProcessChunk(EventType::kServerChunk, chunk); |
| } |
| |
| void EndClientTransfer(uint32_t session_id, |
| Status status, |
| bool send_status_chunk = false) { |
| EndTransfer( |
| EventType::kClientEndTransfer, session_id, status, send_status_chunk); |
| } |
| |
| void EndServerTransfer(uint32_t session_id, |
| Status status, |
| bool send_status_chunk = false) { |
| EndTransfer( |
| EventType::kServerEndTransfer, session_id, status, send_status_chunk); |
| } |
| |
| void SetClientReadStream(rpc::RawClientReaderWriter& read_stream) { |
| SetClientStream(TransferStream::kClientRead, read_stream); |
| } |
| |
| void SetClientWriteStream(rpc::RawClientReaderWriter& write_stream) { |
| SetClientStream(TransferStream::kClientWrite, write_stream); |
| } |
| |
| void SetServerReadStream(rpc::RawServerReaderWriter& read_stream) { |
| SetServerStream(TransferStream::kServerRead, read_stream); |
| } |
| |
| void SetServerWriteStream(rpc::RawServerReaderWriter& write_stream) { |
| SetServerStream(TransferStream::kServerWrite, write_stream); |
| } |
| |
| void AddTransferHandler(Handler& handler) { |
| TransferHandlerEvent(EventType::kAddTransferHandler, handler); |
| } |
| |
| void RemoveTransferHandler(Handler& handler) { |
| TransferHandlerEvent(EventType::kRemoveTransferHandler, handler); |
| // Ensure this function blocks until the transfer handler is fully cleaned |
| // up. |
| WaitUntilEventIsProcessed(); |
| } |
| |
| size_t max_chunk_size() const { return chunk_buffer_.size(); } |
| |
| // For testing only: terminates the transfer thread with a kTerminate event. |
| void Terminate(); |
| |
| // For testing only: blocks until the next event can be acquired, which means |
| // a previously enqueued event has been processed. |
| void WaitUntilEventIsProcessed() { |
| next_event_ownership_.acquire(); |
| next_event_ownership_.release(); |
| } |
| |
| // For testing only: simulates a timeout event for a client transfer. |
| void SimulateClientTimeout(uint32_t session_id) { |
| SimulateTimeout(EventType::kClientTimeout, session_id); |
| } |
| |
| // For testing only: simulates a timeout event for a server transfer. |
| void SimulateServerTimeout(uint32_t session_id) { |
| SimulateTimeout(EventType::kServerTimeout, session_id); |
| } |
| |
| private: |
| friend class Context; |
| |
| // Maximum amount of time between transfer thread runs. |
| static constexpr chrono::SystemClock::duration kMaxTimeout = |
| std::chrono::seconds(2); |
| |
| // Finds an active server or client transfer. |
| template <typename T> |
| static Context* FindActiveTransfer(const span<T>& transfers, |
| uint32_t context_identifier) { |
| auto transfer = std::find_if( |
| transfers.begin(), transfers.end(), [context_identifier](auto& c) { |
| return c.initialized() && c.id() == context_identifier; |
| }); |
| return transfer != transfers.end() ? &*transfer : nullptr; |
| } |
| |
| void SimulateTimeout(EventType type, uint32_t session_id); |
| |
| // Finds an new server or client transfer. |
| template <typename T> |
| static Context* FindNewTransfer(const span<T>& transfers, |
| uint32_t session_id) { |
| Context* new_transfer = nullptr; |
| |
| for (Context& context : transfers) { |
| if (context.active()) { |
| if (context.session_id() == session_id) { |
| // Restart an already active transfer. |
| return &context; |
| } |
| } else { |
| // Store the inactive context as an option, but keep checking for the |
| // restart case. |
| new_transfer = &context; |
| } |
| } |
| |
| return new_transfer; |
| } |
| |
| const ByteSpan& encode_buffer() const { return encode_buffer_; } |
| |
| void Run() final; |
| |
| void HandleTimeouts(); |
| |
| rpc::Writer& stream_for(TransferStream stream) { |
| switch (stream) { |
| case TransferStream::kClientRead: |
| return client_read_stream_; |
| case TransferStream::kClientWrite: |
| return client_write_stream_; |
| case TransferStream::kServerRead: |
| return server_read_stream_; |
| case TransferStream::kServerWrite: |
| return server_write_stream_; |
| } |
| // An unknown TransferStream value was passed, which means this function |
| // was passed an invalid enum value. |
| PW_ASSERT(false); |
| } |
| |
| // Returns the earliest timeout among all active transfers, up to kMaxTimeout. |
| chrono::SystemClock::time_point GetNextTransferTimeout() const; |
| |
| void StartTransfer(TransferType type, |
| ProtocolVersion version, |
| uint32_t session_id, |
| uint32_t resource_id, |
| ConstByteSpan raw_chunk, |
| stream::Stream* stream, |
| const TransferParameters& max_parameters, |
| Function<void(Status)>&& on_completion, |
| chrono::SystemClock::duration timeout, |
| uint8_t max_retries); |
| |
| void ProcessChunk(EventType type, ConstByteSpan chunk); |
| |
| void EndTransfer(EventType type, |
| uint32_t session_id, |
| Status status, |
| bool send_status_chunk); |
| |
| void SetClientStream(TransferStream type, rpc::RawClientReaderWriter& stream); |
| void SetServerStream(TransferStream type, rpc::RawServerReaderWriter& stream); |
| |
| void TransferHandlerEvent(EventType type, Handler& handler); |
| |
| void HandleEvent(const Event& event); |
| Context* FindContextForEvent(const Event& event) const; |
| |
| void SendStatusChunk(const SendStatusChunkEvent& event); |
| |
| sync::TimedThreadNotification event_notification_; |
| sync::BinarySemaphore next_event_ownership_; |
| |
| Event next_event_; |
| Function<void(Status)> staged_on_completion_; |
| rpc::RawClientReaderWriter staged_client_stream_; |
| rpc::RawServerReaderWriter staged_server_stream_; |
| |
| rpc::RawClientReaderWriter client_read_stream_; |
| rpc::RawClientReaderWriter client_write_stream_; |
| rpc::RawServerReaderWriter server_read_stream_; |
| rpc::RawServerReaderWriter server_write_stream_; |
| |
| span<ClientContext> client_transfers_; |
| span<ServerContext> server_transfers_; |
| |
| // All registered transfer handlers. |
| IntrusiveList<Handler> handlers_; |
| |
| // Buffer in which chunk data is staged for CHUNK events. |
| ByteSpan chunk_buffer_; |
| |
| // Buffer into which responses are encoded. Only ever used from within the |
| // transfer thread, so no locking is required. |
| ByteSpan encode_buffer_; |
| }; |
| |
| } // namespace internal |
| |
| using TransferThread = internal::TransferThread; |
| |
| template <size_t kMaxConcurrentClientTransfers, |
| size_t kMaxConcurrentServerTransfers> |
| class Thread final : public internal::TransferThread { |
| public: |
| Thread(ByteSpan chunk_buffer, ByteSpan encode_buffer) |
| : internal::TransferThread( |
| client_contexts_, server_contexts_, chunk_buffer, encode_buffer) {} |
| |
| private: |
| std::array<internal::ClientContext, kMaxConcurrentClientTransfers> |
| client_contexts_; |
| std::array<internal::ServerContext, kMaxConcurrentServerTransfers> |
| server_contexts_; |
| }; |
| |
| } // namespace pw::transfer |