blob: e4fe726ced44b8d9864c263800bc1cf223b51bde [file] [log] [blame]
// Copyright 2023 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
#include "pw_data_link/socket_data_link.h"
#if defined(_WIN32) && _WIN32
// TODO(cachinchilla): add support for windows.
#error Windows not supported yet!
#else
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#endif // defined(_WIN32) && _WIN32
#include <chrono>
#include <optional>
#include "pw_allocator/allocator.h"
#include "pw_assert/check.h"
#include "pw_bytes/span.h"
#include "pw_log/log.h"
#include "pw_multibuf/header_chunk_region_tracker.h"
#include "pw_status/status.h"
#include "pw_status/status_with_size.h"
namespace pw::data_link {
namespace {
// Printable names of each LinkState, used by set_link_state() when logging
// state transitions. constexpr makes these immutable compile-time constants
// rather than reassignable pointer variables.
constexpr const char* kLinkStateNameOpen = "Open";
constexpr const char* kLinkStateNameOpenRequest = "Open Request";
constexpr const char* kLinkStateNameClosed = "Closed";
constexpr const char* kLinkStateNameWaitingForOpen = "Waiting For Open";
constexpr const char* kLinkStateNameUnknown = "Unknown";
// Upper bound on how long a single epoll_wait() call may block.
constexpr auto kEpollTimeout = std::chrono::seconds(1);
// Configures the file descriptor as non-blocking while preserving every other
// file status flag already set on it. Returns true when successful; on failure
// the reason is logged.
bool MakeSocketNonBlocking(int fd) {
  const int flags = fcntl(fd, F_GETFL, 0);
  if (flags == -1) {
    PW_LOG_ERROR("Failed to read socket flags: %s", std::strerror(errno));
    return false;
  }
  // OR the flag in. Adding it arithmetically would corrupt the flag word if
  // O_NONBLOCK happened to be set already.
  if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) {
    PW_LOG_ERROR("Failed to make socket non-blocking: %s",
                 std::strerror(errno));
    return false;
  }
  return true;
}
} // namespace
SocketDataLink::SocketDataLink(const char* host, uint16_t port) : host_(host) {
  // A host is mandatory. No connection is made here; that happens only after
  // Open() is called and the request is processed.
  PW_CHECK(host_ != nullptr);
  // Store the port as text, the form getaddrinfo() expects as its service.
  const auto port_conversion = ToString(port, port_);
  PW_CHECK(port_conversion.ok());
}
SocketDataLink::SocketDataLink(int connection_fd,
                               EventHandlerCallback&& event_handler,
                               allocator::Allocator& write_buffer_allocator)
    : connection_fd_(connection_fd),
      write_buffer_allocator_(&write_buffer_allocator) {
  // Adopts an existing socket descriptor and treats the link as open at once.
  PW_DCHECK(connection_fd > 0);
  PW_CHECK(MakeSocketNonBlocking(connection_fd_));
  PW_CHECK(ConfigureEpoll());
  event_handler_ = std::move(event_handler);
  read_state_ = ReadState::kIdle;
  write_state_ = WriteState::kIdle;
  set_link_state(LinkState::kOpen);
  // Report the open link to the user right away.
  event_handler_(DataLink::Event::kOpen, StatusWithSize());
}
SocketDataLink::~SocketDataLink() {
  lock_.lock();
  // Nothing to tear down if the link is already closed.
  if (link_state_ == LinkState::kClosed) {
    lock_.unlock();
    return;
  }
  // DoClose() releases lock_ itself before notifying the event handler.
  DoClose(/*notify_closed=*/true);
}
void SocketDataLink::set_link_state(LinkState new_state) {
  // Maps a LinkState to the printable name used in the transition log.
  const auto state_name = [](LinkState state) -> const char* {
    switch (state) {
      case LinkState::kOpen:
        return kLinkStateNameOpen;
      case LinkState::kOpenRequest:
        return kLinkStateNameOpenRequest;
      case LinkState::kClosed:
        return kLinkStateNameClosed;
      case LinkState::kWaitingForOpen:
        return kLinkStateNameWaitingForOpen;
      default:
        return kLinkStateNameUnknown;
    }
  };
  PW_LOG_DEBUG("Transitioning from %s to %s",
               state_name(link_state_),
               state_name(new_state));
  link_state_ = new_state;
}
void SocketDataLink::Open(EventHandlerCallback&& event_handler,
                          allocator::Allocator& write_buffer_allocator) {
  std::lock_guard guard(lock_);
  // Opening is only permitted from the closed state.
  PW_CHECK(link_state_ == LinkState::kClosed);
  event_handler_ = std::move(event_handler);
  write_buffer_allocator_ = &write_buffer_allocator;
  // Only record the request; WaitAndConsumeEvents() performs the connection.
  set_link_state(LinkState::kOpenRequest);
}
void SocketDataLink::WaitAndConsumeEvents() {
  // Blocks for up to kEpollTimeout waiting for one event on `epoll_fd`.
  // Returns true only when epoll_wait() actually reported an event. On a
  // timeout (0) or an error (-1, e.g. EINTR) `event` is not filled in and must
  // not be inspected.
  const auto wait_for_event = [](int epoll_fd, epoll_event& event) {
    const int count = epoll_wait(
        epoll_fd,
        &event,
        /*maxevents=*/1,
        static_cast<int>(std::chrono::milliseconds(kEpollTimeout).count()));
    if (count < 0 && errno != EINTR) {
      PW_LOG_WARN("epoll_wait failed: %s", std::strerror(errno));
    }
    return count > 0;
  };

  // Manually lock and unlock, since some functions may perform unlocking before
  // calling the user's event callback, when locks cannot be held.
  lock_.lock();
  switch (link_state_) {
    case LinkState::kOpen:
      // Handled after the switch.
      break;
    case LinkState::kWaitingForOpen: {
      // A non-blocking connect() is in progress; wait for it to finish.
      // Copy the epoll file descriptor to reduce the critical section.
      const int fd = epoll_fd_;
      lock_.unlock();
      epoll_event event{};
      if (!wait_for_event(fd, event)) {
        return;
      }
      if (event.events & EPOLLERR || event.events & EPOLLHUP) {
        lock_.lock();
        // Check if link was closed on another thread.
        if (link_state_ != LinkState::kClosed) {
          DoClose(/*notify_closed=*/false);
        } else {
          lock_.unlock();
        }
        event_handler_(DataLink::Event::kOpen, StatusWithSize::Unknown());
        return;
      }
      // The socket is registered (level-triggered) for EPOLLIN as well, so
      // data from the peer may be reported together with EPOLLOUT. Test the
      // EPOLLOUT bit rather than comparing the whole mask; otherwise such an
      // event is never recognized and the link never leaves this state.
      if (event.events & EPOLLOUT) {
        {
          std::lock_guard lock(lock_);
          set_link_state(LinkState::kOpen);
        }
        {
          std::lock_guard lock(write_lock_);
          write_state_ = WriteState::kIdle;
        }
        {
          std::lock_guard lock(read_lock_);
          read_state_ = ReadState::kIdle;
        }
        event_handler_(DataLink::Event::kOpen, StatusWithSize());
        return;
      }
      PW_LOG_ERROR("Unhandled event %d while waiting to open link",
                   static_cast<int>(event.events));
      return;
    }
    case LinkState::kClosed:
      lock_.unlock();
      return;
    case LinkState::kOpenRequest:
      // DoOpen() releases lock_ on every path.
      DoOpen();
      return;
  }
  // Copy the epoll file descriptor to reduce the critical section.
  const int fd = epoll_fd_;
  lock_.unlock();
  epoll_event event{};
  if (!wait_for_event(fd, event)) {
    return;
  }
  if (event.events & EPOLLERR || event.events & EPOLLHUP) {
    lock_.lock();
    // Check if link was closed on another thread.
    if (link_state_ != LinkState::kClosed) {
      DoClose(/*notify_closed=*/true);
    } else {
      lock_.unlock();
    }
    return;
  }
  if (event.events & EPOLLIN) {
    // Can read! DoRead() releases read_lock_ itself.
    read_lock_.lock();
    if (read_state_ == ReadState::kReadRequested) {
      DoRead();
    } else if (read_state_ == ReadState::kIdle) {
      read_lock_.unlock();
      event_handler_(DataLink::Event::kDataReceived, StatusWithSize());
    } else {
      read_lock_.unlock();
    }
  }
  if (event.events & EPOLLOUT) {
    // Can write! DoWrite() releases write_lock_ itself.
    write_lock_.lock();
    if (write_state_ == WriteState::kPending) {
      DoWrite();
    } else {
      write_lock_.unlock();
    }
  }
  if ((event.events & EPOLLIN) == 0 && (event.events & EPOLLOUT) == 0) {
    PW_LOG_WARN("Unhandled event %d", static_cast<int>(event.events));
  }
}
void SocketDataLink::DoOpen() {
  // Called with lock_ held (from WaitAndConsumeEvents()). Every path below
  // releases lock_, directly or via HandleOpenFailure(), before the user's
  // event handler is invoked.
  addrinfo hints = {};
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags = AI_NUMERICSERV;
  addrinfo* res = nullptr;
  const int gai_result = getaddrinfo(host_, port_, &hints, &res);
  if (gai_result != 0) {
    // getaddrinfo() reports failures through its return value, not errno, so
    // log the resolver's own description of the problem.
    PW_LOG_ERROR("Failed to configure connection address for socket: %s",
                 gai_strerror(gai_result));
    HandleOpenFailure(/*info=*/nullptr);
    return;
  }
  // Use the first resolved address for which a socket can be created.
  addrinfo* rp;
  for (rp = res; rp != nullptr; rp = rp->ai_next) {
    PW_LOG_DEBUG("Opening socket");
    connection_fd_ = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
    if (connection_fd_ != kInvalidFd) {
      break;
    }
  }
  if (connection_fd_ == kInvalidFd) {
    HandleOpenFailure(res);
    return;
  }
  // Set necessary options on a socket file descriptor.
  PW_LOG_DEBUG("Configuring socket");
  if (!MakeSocketNonBlocking(connection_fd_)) {
    HandleOpenFailure(res);
    return;
  }
  PW_LOG_DEBUG("Connecting socket");
  if (connect(connection_fd_, rp->ai_addr, rp->ai_addrlen) == -1) {
    // EINPROGRESS is expected on a non-blocking socket: completion is
    // detected later through epoll in WaitAndConsumeEvents().
    if (errno != EINPROGRESS) {
      HandleOpenFailure(res);
      return;
    }
    set_link_state(LinkState::kWaitingForOpen);
  } else {
    set_link_state(LinkState::kOpen);
  }
  PW_LOG_DEBUG("Configuring Epoll");
  if (!ConfigureEpoll()) {
    HandleOpenFailure(res);
    return;
  }
  const bool open_completed = link_state_ == LinkState::kOpen;
  lock_.unlock();
  freeaddrinfo(res);
  if (open_completed) {
    {
      std::lock_guard lock(write_lock_);
      write_state_ = WriteState::kIdle;
    }
    {
      std::lock_guard lock(read_lock_);
      read_state_ = ReadState::kIdle;
    }
    event_handler_(DataLink::Event::kOpen, StatusWithSize());
  }
}
void SocketDataLink::HandleOpenFailure(addrinfo* info) {
  // Called with lock_ held: releases any descriptors created so far, marks the
  // link closed, releases lock_, frees `info` (if any), and reports the failed
  // open to the user.
  //
  // Capture errno first. The close() calls and the logging below may
  // overwrite it, which would make the message report the wrong cause.
  const int saved_errno = errno;
  if (connection_fd_ != kInvalidFd) {
    close(connection_fd_);
    connection_fd_ = kInvalidFd;
  }
  if (epoll_fd_ != kInvalidFd) {
    close(epoll_fd_);
    epoll_fd_ = kInvalidFd;
  }
  set_link_state(LinkState::kClosed);
  lock_.unlock();
  PW_LOG_ERROR("Failed to connect to %s:%s: %s",
               host_,
               port_,
               std::strerror(saved_errno));
  if (info != nullptr) {
    freeaddrinfo(info);
  }
  event_handler_(DataLink::Event::kOpen, StatusWithSize::Unknown());
}
bool SocketDataLink::ConfigureEpoll() {
  // Creates the epoll instance and registers connection_fd_ with it. Returns
  // false on failure; epoll_fd_ may be left open in that case, and callers
  // are responsible for closing it.
  epoll_fd_ = epoll_create1(0);
  if (epoll_fd_ == kInvalidFd) {
    return false;
  }
  // Watch the connection for both readability and writability.
  epoll_event_.data.fd = connection_fd_;
  epoll_event_.events = EPOLLIN | EPOLLOUT;
  const int registration =
      epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, connection_fd_, &epoll_event_);
  return registration != -1;
}
void SocketDataLink::Close() {
  lock_.lock();
  // Closing an already-closed link is a caller error.
  PW_DCHECK(link_state_ != LinkState::kClosed);
  if (link_state_ == LinkState::kClosed) {
    // DoRead()/DoWrite() call Close() after invoking the user's handler, which
    // may already have closed the link. When PW_DCHECK is compiled out, do
    // nothing rather than tearing down twice and emitting a second kClosed.
    lock_.unlock();
    return;
  }
  // DoClose() releases lock_.
  DoClose(/*notify_closed=*/true);
}
void SocketDataLink::DoClose(bool notify_closed) {
  // Entered with lock_ held. Detach the descriptors from the object while
  // still holding the lock, then release it so the rest of the teardown runs
  // outside the critical section.
  const int connection_fd = connection_fd_;
  const int epoll_fd = epoll_fd_;
  connection_fd_ = kInvalidFd;
  epoll_fd_ = kInvalidFd;
  set_link_state(LinkState::kClosed);
  lock_.unlock();

  {
    std::lock_guard guard(write_lock_);
    write_state_ = WriteState::kClosed;
  }
  {
    std::lock_guard guard(read_lock_);
    read_state_ = ReadState::kClosed;
  }

  // Close whichever descriptors were actually open, connection first.
  const int descriptors[] = {connection_fd, epoll_fd};
  for (const int descriptor : descriptors) {
    if (descriptor != kInvalidFd) {
      close(descriptor);
    }
  }

  if (notify_closed) {
    event_handler_(DataLink::Event::kClosed, StatusWithSize());
  }
}
std::optional<multibuf::MultiBuf> SocketDataLink::GetWriteBuffer(size_t size) {
if (size == 0) {
return multibuf::MultiBuf();
}
{
std::lock_guard lock(write_lock_);
if (write_state_ != WriteState::kIdle) {
return std::nullopt;
}
}
multibuf::MultiBuf buffer;
std::optional<multibuf::OwnedChunk> chunk =
multibuf::HeaderChunkRegionTracker::AllocateRegionAsChunk(
write_buffer_allocator_, size);
if (!chunk) {
return std::nullopt;
}
buffer.PushFrontChunk(std::move(*chunk));
return buffer;
}
Status SocketDataLink::Write(multibuf::MultiBuf&& buffer) {
  // Empty payloads are rejected outright.
  if (buffer.size() == 0) {
    return Status::InvalidArgument();
  }
  std::lock_guard guard(write_lock_);
  // Only one write may be staged at a time, and never on a closed link.
  if (write_state_ != WriteState::kIdle) {
    return Status::FailedPrecondition();
  }
  // Stage the payload; DoWrite() drains it as the socket becomes writable.
  num_bytes_sent_ = 0;
  tx_multibuf_ = std::move(buffer);
  num_bytes_to_send_ = tx_multibuf_.size();
  write_state_ = WriteState::kPending;
  return OkStatus();
}
void SocketDataLink::DoWrite() {
  // Called with write_lock_ held and a pending, non-empty tx_multibuf_. Sends
  // (at most) the first chunk. Every path releases write_lock_ before
  // returning and before invoking the user's event handler.
  int send_flags = 0;
#if defined(__linux__)
  // Use MSG_NOSIGNAL to avoid getting a SIGPIPE signal when the remote
  // peer drops the connection. This is supported on Linux only.
  send_flags |= MSG_NOSIGNAL;
#endif  // defined(__linux__)
  auto [chunk_iter, chunk] =
      tx_multibuf_.TakeChunk(tx_multibuf_.Chunks().begin());
  ssize_t bytes_sent;
  int send_error = 0;
  {
    std::lock_guard lock(lock_);
    // TODO: Send one chunk at a time.
    bytes_sent = send(connection_fd_,
                      reinterpret_cast<const char*>(chunk.data()),
                      chunk.size(),
                      send_flags);
    // Capture errno right after send(), before anything else can change it.
    send_error = errno;
  }
  // Check for errors.
  if (bytes_sent < 0) {
    if (send_error == EAGAIN || send_error == EWOULDBLOCK) {
      // The non-blocking socket has no room right now. Put the chunk back
      // untouched and retry on the next EPOLLOUT event.
      tx_multibuf_.PushFrontChunk(std::move(chunk));
      write_lock_.unlock();
      return;
    }
    if (send_error == EPIPE) {
      // An EPIPE indicates that the connection is closed.
      tx_multibuf_.Release();
      write_lock_.unlock();
      event_handler_(DataLink::Event::kDataSent, StatusWithSize::OutOfRange());
      Close();
      return;
    }
    // Any other error aborts this write. Return to idle so the user can issue
    // a new Write(). Leaving the state pending with an empty buffer would make
    // the next EPOLLOUT call TakeChunk() on an empty MultiBuf.
    tx_multibuf_.Release();
    write_state_ = WriteState::kIdle;
    write_lock_.unlock();
    event_handler_(DataLink::Event::kDataSent, StatusWithSize::Unknown());
    return;
  }
  num_bytes_sent_ += bytes_sent;
  // Resize chunk if it was a partial write.
  if (static_cast<size_t>(bytes_sent) < chunk.size()) {
    chunk->DiscardPrefix(static_cast<size_t>(bytes_sent));
    tx_multibuf_.PushFrontChunk(std::move(chunk));
    write_lock_.unlock();
    return;
  }
  // Check if we are done with the MultiBuf.
  if (num_bytes_sent_ >= num_bytes_to_send_) {
    write_state_ = WriteState::kIdle;
    tx_multibuf_.Release();
    write_lock_.unlock();
    event_handler_(DataLink::Event::kDataSent,
                   StatusWithSize(num_bytes_to_send_));
    return;
  }
  // More chunks remain; they are sent on subsequent EPOLLOUT events.
  write_lock_.unlock();
}
Status SocketDataLink::Read(ByteSpan buffer) {
  PW_DCHECK(buffer.size() > 0);
  std::lock_guard guard(read_lock_);
  // Accept the request only when no read is outstanding and the link is not
  // closed, i.e. when the read state is idle.
  if (read_state_ == ReadState::kIdle) {
    rx_buffer_ = buffer;
    read_state_ = ReadState::kReadRequested;
    return OkStatus();
  }
  return Status::FailedPrecondition();
}
void SocketDataLink::DoRead() {
ssize_t bytes_rcvd;
{
std::lock_guard lock(lock_);
bytes_rcvd = recv(connection_fd_,
reinterpret_cast<char*>(rx_buffer_.data()),
rx_buffer_.size_bytes(),
0);
}
if (bytes_rcvd == 0) {
// Remote peer has closed the connection.
read_state_ = ReadState::kClosed;
read_lock_.unlock();
event_handler_(DataLink::Event::kDataRead, StatusWithSize::Internal());
Close();
return;
} else if (bytes_rcvd < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// Socket timed out when trying to read.
// This should only occur if SO_RCVTIMEO was configured to be nonzero, or
// if the socket was opened with the O_NONBLOCK flag to prevent any
// blocking when performing reads or writes.
read_state_ = ReadState::kIdle;
read_lock_.unlock();
event_handler_(DataLink::Event::kDataRead,
StatusWithSize::ResourceExhausted());
return;
}
read_state_ = ReadState::kIdle;
read_lock_.unlock();
event_handler_(DataLink::Event::kDataRead, StatusWithSize::Unknown());
return;
}
read_state_ = ReadState::kIdle;
read_lock_.unlock();
event_handler_(DataLink::Event::kDataRead, StatusWithSize(bytes_rcvd));
}
} // namespace pw::data_link