blob: db4ceced851c222be781f53ae5dfcc979f0548ac [file] [log] [blame]
// Copyright 2024 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
#include "pw_channel/epoll_channel.h"
#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include "gtest/gtest.h"
#include "pw_assert/check.h"
#include "pw_async2/dispatcher.h"
#include "pw_bytes/array.h"
#include "pw_bytes/suffix.h"
#include "pw_channel/channel.h"
#include "pw_multibuf/simple_allocator_for_test.h"
#include "pw_status/status.h"
#include "pw_thread/sleep.h"
#include "pw_thread/thread.h"
#include "pw_thread_stl/options.h"
namespace {
using namespace std::chrono_literals;
using ::pw::async2::Context;
using ::pw::async2::Dispatcher;
using ::pw::async2::Pending;
using ::pw::async2::Poll;
using ::pw::async2::Ready;
using ::pw::async2::Task;
using ::pw::channel::ByteReader;
using ::pw::channel::ByteWriter;
using ::pw::channel::EpollChannel;
using ::pw::multibuf::MultiBuf;
using ::pw::multibuf::test::SimpleAllocatorForTest;
// Task that reads from `channel` until `num_reads` reads have completed with
// an OK status, or until a read completes with a non-OK status.
//
// Progress is published through the public members so a test can inspect it
// between dispatcher runs.
template <typename ChannelKind>
class ReaderTask : public Task {
 public:
  ReaderTask(ChannelKind& channel, int num_reads)
      : channel_(channel), num_reads_(num_reads) {}

  // Number of times DoPend() has been entered.
  int poll_count = 0;
  // Number of reads that completed with an OK status.
  int read_count = 0;
  // Running total of size() over every buffer returned by an OK read.
  int bytes_read = 0;
  // Status of the most recently completed read; Unknown until one completes.
  pw::Status read_status = pw::Status::Unknown();

 private:
  Poll<> DoPend(Context& cx) final {
    ++poll_count;

    while (read_count < num_reads_) {
      auto pending_read = channel_.PendRead(cx);
      if (pending_read.IsPending()) {
        // Nothing to read yet; stop here until the task is polled again.
        return Pending();
      }

      auto&& read_result = *pending_read;
      read_status = read_result.status();
      if (!read_result.ok()) {
        // The read failed: finish the task, leaving the error in read_status.
        return Ready();
      }

      ++read_count;
      auto&& received = *read_result;
      bytes_read += received.size();
      received.Release();
    }

    return Ready();
  }

  ChannelKind& channel_;
  int num_reads_;
};
// Task that closes `channel` and records the status the close completed with.
template <typename ChannelKind>
class CloseTask : public Task {
 public:
  // `explicit` prevents a channel from being implicitly converted into a
  // CloseTask. Class template argument deduction from direct-initialization
  // (`CloseTask task(channel);`) is unaffected.
  explicit CloseTask(ChannelKind& channel) : channel_(channel) {}

  // Status produced by PendClose(); Unknown until the close has completed.
  pw::Status close_status = pw::Status::Unknown();

 private:
  Poll<> DoPend(Context& cx) final {
    auto result = channel_.PendClose(cx);
    if (result.IsPending()) {
      // The close has not finished; stop here until the task is polled again.
      return Pending();
    }
    close_status = *result;
    return Ready();
  }

  ChannelKind& channel_;
};
// Test fixture that opens a fresh pipe for every test case and exposes both
// ends as raw file descriptors.
class EpollChannelTest : public ::testing::Test {
 protected:
  EpollChannelTest() {
    // pipe() fills index 0 with the read end and index 1 with the write end.
    int pipe_ends[2];
    PW_CHECK_INT_NE(pipe(pipe_ends), -1);
    read_fd_ = pipe_ends[0];
    write_fd_ = pipe_ends[1];
  }

  // Closes both descriptors; the results of close() are ignored.
  // NOTE(review): tests hand one of these descriptors to an EpollChannel,
  // which appears to close it itself (see Destructor_ClosesFileDescriptor), so
  // the matching close() here presumably fails with EBADF -- confirm.
  ~EpollChannelTest() override {
    close(read_fd_);
    close(write_fd_);
  }

  int read_fd_;   // Read end of the pipe.
  int write_fd_;  // Write end of the pipe.
};
// pw::thread::ThreadCore that owns a callable and invokes it from Run().
//
// The tests construct this from a lambda and pass it to pw::thread::Thread to
// execute the lambda on another thread.
template <typename Func>
class FunctionThread : public pw::thread::ThreadCore {
 public:
  // Takes ownership of `func` by moving it into this object.
  explicit FunctionThread(Func&& func) : func_(std::move(func)) {}

 private:
  // Thread body: calls the stored callable once.
  void Run() override {
    func_();
  }

  Func func_;
};
// A reader that stalls on an empty pipe completes its read once data has been
// written to the other end of the pipe.
TEST_F(EpollChannelTest, Read_ValidData_Succeeds) {
  // Payload and its length without the terminating NUL. Both are static so the
  // writer lambda below can use them without capturing them.
  static constexpr char kPayload[] = "hello world";
  static constexpr int kPayloadSize = sizeof(kPayload) - 1;

  SimpleAllocatorForTest allocator;
  Dispatcher dispatcher;
  EpollChannel channel(read_fd_, dispatcher, allocator);
  ASSERT_EQ(channel.Register(), pw::OkStatus());

  ReaderTask<ByteReader> reader(channel, 1);
  dispatcher.Post(reader);

  // Nothing has been written to the pipe yet, so the first poll must stall
  // without completing a read.
  EXPECT_EQ(dispatcher.RunUntilStalled(), Pending());
  EXPECT_EQ(reader.poll_count, 1);
  EXPECT_EQ(reader.read_count, 0);
  EXPECT_EQ(reader.bytes_read, 0);

  // Write the payload from another thread after a delay, then wait for that
  // thread to finish before running the dispatcher again.
  FunctionThread delayed_write([this]() {
    pw::this_thread::sleep_for(500ms);
    PW_CHECK_INT_EQ(write(write_fd_, kPayload, kPayloadSize), kPayloadSize);
  });
  pw::thread::Thread writer_thread(pw::thread::stl::Options(), delayed_write);
  writer_thread.join();

  // One more poll is expected to pick up the whole payload.
  dispatcher.RunToCompletion();
  EXPECT_EQ(reader.read_status, pw::OkStatus());
  EXPECT_EQ(reader.poll_count, 2);
  EXPECT_EQ(reader.read_count, 1);
  EXPECT_EQ(reader.bytes_read, kPayloadSize);

  // Closing the channel is expected to complete without stalling.
  CloseTask closer(channel);
  dispatcher.Post(closer);
  EXPECT_EQ(dispatcher.RunUntilStalled(), Ready());
  EXPECT_EQ(closer.close_status, pw::OkStatus());
}
// Reading from a channel on which Register() was never called completes
// immediately with FAILED_PRECONDITION.
TEST_F(EpollChannelTest, Read_Unregistered_ReturnsFailedPrecondition) {
  SimpleAllocatorForTest allocator;
  Dispatcher dispatcher;

  // Deliberately skip channel.Register().
  EpollChannel channel(read_fd_, dispatcher, allocator);

  ReaderTask<ByteReader> reader(channel, 1);
  dispatcher.Post(reader);

  EXPECT_EQ(dispatcher.RunUntilStalled(), Ready());
  EXPECT_EQ(reader.read_status, pw::Status::FailedPrecondition());
}
// Reading from a channel that has already been closed completes immediately
// with FAILED_PRECONDITION.
TEST_F(EpollChannelTest, Read_Closed_ReturnsFailedPrecondition) {
  SimpleAllocatorForTest allocator;
  Dispatcher dispatcher;
  EpollChannel channel(read_fd_, dispatcher, allocator);
  ASSERT_EQ(channel.Register(), pw::OkStatus());

  // Step 1: close the registered channel.
  CloseTask closer(channel);
  dispatcher.Post(closer);
  EXPECT_EQ(dispatcher.RunUntilStalled(), Ready());
  EXPECT_EQ(closer.close_status, pw::OkStatus());

  // Step 2: a read attempted after the close must fail.
  ReaderTask<ByteReader> reader(channel, 1);
  dispatcher.Post(reader);
  EXPECT_EQ(dispatcher.RunUntilStalled(), Ready());
  EXPECT_EQ(reader.read_status, pw::Status::FailedPrecondition());
}
// Task that repeatedly writes `data_to_write` to `channel` and flushes after
// each write, until `max_writes` write attempts have been made or the channel
// reports a non-OK status from PendReadyToWrite().
//
// Progress is published through the public members so a test can inspect (and,
// for max_writes, adjust) it between dispatcher runs.
template <typename ChannelKind>
class WriterTask : public Task {
 public:
  WriterTask(ChannelKind& channel,
             int num_writes,
             pw::ConstByteSpan data_to_write)
      : max_writes(num_writes),
        channel_(channel),
        data_to_write_(data_to_write) {}

  // Number of times DoPend() has been entered.
  int poll_count = 0;
  // Number of times PendReadyToWrite() returned Pending.
  int write_pending_count = 0;
  // Number of write attempts started; incremented before each Write() call.
  int write_count = 0;
  // The task finishes once write_count reaches this value.
  int max_writes = 0;
  // Most recent status from PendReadyToWrite() or Write(), whichever ran last.
  pw::Status last_write_status = pw::Status::Unknown();
  // Token produced by the most recent completed PendFlush().
  pw::channel::WriteToken flushed_write_token;

 private:
  Poll<> DoPend(Context& cx) final {
    ++poll_count;

    while (write_count < max_writes) {
      // Wait until the channel can accept a write.
      auto writable = channel_.PendReadyToWrite(cx);
      if (writable.IsPending()) {
        ++write_pending_count;
        return Pending();
      }

      last_write_status = *writable;
      if (!writable->ok()) {
        // The channel reported an error: finish the task, leaving the error
        // in last_write_status.
        return Ready();
      }

      ++write_count;

      // Copy the payload into a freshly allocated buffer and hand it to the
      // channel.
      std::optional<pw::multibuf::MultiBuf> outgoing =
          channel_.GetWriteAllocator().Allocate(data_to_write_.size());
      PW_CHECK(outgoing.has_value());
      std::copy(
          data_to_write_.begin(), data_to_write_.end(), outgoing->begin());
      last_write_status = channel_.Write(std::move(*outgoing)).status();

      // Flush and remember the token the flush reports.
      auto flushed = channel_.PendFlush(cx);
      if (flushed.IsPending()) {
        return Pending();
      }
      PW_CHECK_OK(flushed->status());
      flushed_write_token = **flushed;
    }

    return Ready();
  }

  ChannelKind& channel_;
  pw::ConstByteSpan data_to_write_;
};
// Data written through the channel arrives unmodified at the read end of the
// pipe.
TEST_F(EpollChannelTest, Write_ValidData_Succeeds) {
  SimpleAllocatorForTest alloc;
  Dispatcher dispatcher;
  EpollChannel channel(write_fd_, dispatcher, alloc);
  ASSERT_EQ(channel.Register(), pw::OkStatus());

  // 32 bytes, each set to 0x3f.
  constexpr auto kData = pw::bytes::Initialized<32>(0x3f);
  WriterTask<ByteWriter> write_task(channel, 1, kData);
  dispatcher.Post(write_task);
  dispatcher.RunToCompletion();
  EXPECT_EQ(write_task.last_write_status, pw::OkStatus());

  // Read the bytes back directly from the pipe. The buffer is
  // value-initialized and the byte count is ASSERTed so that a failed or short
  // read() can never lead to memcmp() inspecting indeterminate bytes.
  std::array<std::byte, 64> buffer{};
  ASSERT_EQ(read(read_fd_, buffer.data(), buffer.size()),
            static_cast<int>(kData.size()));
  EXPECT_EQ(std::memcmp(buffer.data(), kData.data(), kData.size()), 0);

  // Closing the channel is expected to complete without stalling.
  CloseTask close_task(channel);
  dispatcher.Post(close_task);
  EXPECT_EQ(dispatcher.RunUntilStalled(), Ready());
  EXPECT_EQ(close_task.close_status, pw::OkStatus());
}
// Writing an empty payload through a registered channel reports OK.
TEST_F(EpollChannelTest, Write_EmptyData_Succeeds) {
  SimpleAllocatorForTest allocator;
  Dispatcher dispatcher;
  EpollChannel channel(write_fd_, dispatcher, allocator);
  ASSERT_EQ(channel.Register(), pw::OkStatus());

  // A default-constructed span carries no bytes.
  WriterTask<ByteWriter> writer(channel, 1, pw::ConstByteSpan());
  dispatcher.Post(writer);
  dispatcher.RunToCompletion();
  EXPECT_EQ(writer.last_write_status, pw::OkStatus());

  // Closing the channel is expected to complete without stalling.
  CloseTask closer(channel);
  dispatcher.Post(closer);
  EXPECT_EQ(dispatcher.RunUntilStalled(), Ready());
  EXPECT_EQ(closer.close_status, pw::OkStatus());
}
// Writing to a channel on which Register() was never called reports
// FAILED_PRECONDITION.
TEST_F(EpollChannelTest, Write_Unregistered_ReturnsFailedPrecondition) {
  SimpleAllocatorForTest allocator;
  Dispatcher dispatcher;

  // Deliberately skip channel.Register().
  EpollChannel channel(write_fd_, dispatcher, allocator);

  WriterTask<ByteWriter> writer(channel, 1, pw::ConstByteSpan());
  dispatcher.Post(writer);
  dispatcher.RunToCompletion();

  EXPECT_EQ(writer.last_write_status, pw::Status::FailedPrecondition());
}
// Writing to a channel that has already been closed reports
// FAILED_PRECONDITION.
TEST_F(EpollChannelTest, Write_Closed_ReturnsFailedPrecondition) {
  SimpleAllocatorForTest allocator;
  Dispatcher dispatcher;
  EpollChannel channel(write_fd_, dispatcher, allocator);
  ASSERT_EQ(channel.Register(), pw::OkStatus());

  // Step 1: close the registered channel.
  CloseTask closer(channel);
  dispatcher.Post(closer);
  EXPECT_EQ(dispatcher.RunUntilStalled(), Ready());
  EXPECT_EQ(closer.close_status, pw::OkStatus());

  // Step 2: a write attempted after the close must fail.
  WriterTask<ByteWriter> writer(channel, 1, pw::ConstByteSpan());
  dispatcher.Post(writer);
  dispatcher.RunToCompletion();
  EXPECT_EQ(writer.last_write_status, pw::Status::FailedPrecondition());
}
// Destroying a registered EpollChannel closes the file descriptor it was
// constructed with, so a later write() on that descriptor fails with EBADF.
TEST_F(EpollChannelTest, Destructor_ClosesFileDescriptor) {
  SimpleAllocatorForTest alloc;
  Dispatcher dispatcher;
  {
    // The channel lives only for this scope; its destructor runs at the
    // closing brace.
    EpollChannel channel(write_fd_, dispatcher, alloc);
    ASSERT_EQ(channel.Register(), pw::OkStatus());
  }

  const char kArbitraryByte = 'b';

  // errno is only meaningful immediately after the failing call, so capture it
  // before any assertion macro gets a chance to run code that might modify it.
  const ssize_t write_result = write(write_fd_, &kArbitraryByte, 1);
  const int write_errno = errno;

  EXPECT_EQ(write_result, -1);
  EXPECT_EQ(write_errno, EBADF);
}
// Fills the pipe until the writer task stalls in PendReadyToWrite(), then
// drains the pipe from another thread and checks that the task resumes and
// completes one more successful write.
TEST_F(EpollChannelTest, PendReadyToWrite_BlocksWhenUnavailable) {
SimpleAllocatorForTest alloc;
Dispatcher dispatcher;
EpollChannel channel(write_fd_, dispatcher, alloc);
ASSERT_EQ(channel.Register(), pw::OkStatus());
// Payload sized by the allocator's data_size_bytes(), every byte set to 'c'.
constexpr auto kData =
pw::bytes::Initialized<decltype(alloc)::data_size_bytes()>('c');
WriterTask<ByteWriter> write_task(
channel,
100, // Max writes set to some high number so the task fills the pipe.
pw::ConstByteSpan(kData));
dispatcher.Post(write_task);
// Try to write a bunch of data, eventually filling the pipe and blocking the
// task.
EXPECT_EQ(dispatcher.RunUntilStalled(), Pending());
// The task was polled once, saw PendReadyToWrite() return Pending once, and
// the last Write() it attempted reported UNAVAILABLE.
EXPECT_EQ(write_task.poll_count, 1);
EXPECT_EQ(write_task.write_pending_count, 1);
EXPECT_EQ(write_task.last_write_status, pw::Status::Unavailable());
// WriterTask increments write_count before each Write() call, so this is the
// number of write attempts made so far.
const int writes_to_drain = write_task.write_count;
// End the task on the next successful write.
write_task.max_writes = write_task.write_count + 1;
// Drain the pipe to make it writable again after a delay.
// `decltype(alloc)` is an unevaluated operand, so the lambda does not need to
// capture `alloc`.
FunctionThread delayed_read([this, writes_to_drain]() {
pw::this_thread::sleep_for(500ms);
for (int i = 0; i < writes_to_drain; ++i) {
std::array<std::byte, decltype(alloc)::data_size_bytes()> buffer;
// Every read must return at least one byte.
PW_CHECK_INT_GT(read(read_fd_, buffer.data(), buffer.size()), 0);
}
});
pw::thread::Thread work_thread(pw::thread::stl::Options(), delayed_read);
// Run the dispatcher on this thread while the worker drains the pipe; the
// worker is joined only after the writer task has completed.
dispatcher.RunToCompletion();
work_thread.join();
// Exactly one additional poll was needed, and the final write succeeded.
EXPECT_EQ(write_task.poll_count, 2);
EXPECT_EQ(write_task.last_write_status, pw::OkStatus());
}
} // namespace