pw_stream: Create ServerSocket class
This commit adds a new ServerSocket class that aims to improve on the
shortcomings of SocketStream (when used as a server). The class:
1) Supports multiple clients
2) Splits the Listen and Accept operations (the old Serve method was a
source of race conditions)
3) Allows for random port assignment
A user will now create a single ServerSocket, call `Listen` once, and
`Accept` as many times as needed, each call creates a new SocketStream
for communicating with that client.
Due to the improvements in ServerSocket, it and SocketStream can now be
tested reliably, so tests have been added.
Bug: b/271321951
Change-Id: Ic5cd42eb0e36169ed18891cc06d0704867cb0254
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/131611
Commit-Queue: Eli Lipsitz <elipsitz@google.com>
Reviewed-by: Wyatt Hepler <hepler@google.com>
diff --git a/pw_stream/BUILD.bazel b/pw_stream/BUILD.bazel
index c4d910a..edd2f4b 100644
--- a/pw_stream/BUILD.bazel
+++ b/pw_stream/BUILD.bazel
@@ -143,3 +143,12 @@
"//pw_unit_test",
],
)
+
+pw_cc_test(
+ name = "socket_stream_test",
+ srcs = ["socket_stream_test.cc"],
+ deps = [
+ ":socket_stream",
+ "//pw_unit_test",
+ ],
+)
diff --git a/pw_stream/BUILD.gn b/pw_stream/BUILD.gn
index 1ac95a3..5f9e27b 100644
--- a/pw_stream/BUILD.gn
+++ b/pw_stream/BUILD.gn
@@ -95,6 +95,11 @@
if (defined(pw_toolchain_SCOPE.is_host_toolchain) &&
pw_toolchain_SCOPE.is_host_toolchain) {
tests += [ ":std_file_stream_test" ]
+
+ # socket_stream_test only compiles on Linux.
+ if (host_os == "linux") {
+ tests += [ ":socket_stream_test" ]
+ }
}
}
@@ -136,3 +141,8 @@
sources = [ "interval_reader_test.cc" ]
deps = [ ":interval_reader" ]
}
+
+pw_test("socket_stream_test") {
+ sources = [ "socket_stream_test.cc" ]
+ deps = [ ":socket_stream" ]
+}
diff --git a/pw_stream/docs.rst b/pw_stream/docs.rst
index dd4fef1..c4484ae 100644
--- a/pw_stream/docs.rst
+++ b/pw_stream/docs.rst
@@ -426,8 +426,14 @@
.. cpp:class:: SocketStream : public NonSeekableReaderWriter
- ``SocketStream`` wraps posix-style sockets with the :cpp:class:`Reader` and
- :cpp:class:`Writer` interfaces.
+ ``SocketStream`` wraps posix-style TCP sockets with the :cpp:class:`Reader`
+ and :cpp:class:`Writer` interfaces. It can be used to connect to a TCP server,
+ or to communicate with a client via the ``ServerSocket`` class.
+
+.. cpp:class:: ServerSocket
+
+ ``ServerSocket`` wraps a posix server socket, and produces a
+ :cpp:class:`SocketStream` for each accepted client connection.
------------------
Why use pw_stream?
diff --git a/pw_stream/public/pw_stream/socket_stream.h b/pw_stream/public/pw_stream/socket_stream.h
index d757822..d6a8f3a 100644
--- a/pw_stream/public/pw_stream/socket_stream.h
+++ b/pw_stream/public/pw_stream/socket_stream.h
@@ -17,6 +17,7 @@
#include <cstdint>
+#include "pw_result/result.h"
#include "pw_span/span.h"
#include "pw_stream/stream.h"
@@ -26,9 +27,33 @@
public:
constexpr SocketStream() = default;
+ // SocketStream objects are moveable but not copyable.
+ SocketStream& operator=(SocketStream&& other) {
+ listen_port_ = other.listen_port_;
+ socket_fd_ = other.socket_fd_;
+ other.socket_fd_ = kInvalidFd;
+ connection_fd_ = other.connection_fd_;
+ other.connection_fd_ = kInvalidFd;
+ sockaddr_client_ = other.sockaddr_client_;
+ return *this;
+ }
+ SocketStream(SocketStream&& other) noexcept
+ : listen_port_(other.listen_port_),
+ socket_fd_(other.socket_fd_),
+ connection_fd_(other.connection_fd_),
+ sockaddr_client_(other.sockaddr_client_) {
+ other.socket_fd_ = kInvalidFd;
+ other.connection_fd_ = kInvalidFd;
+ }
+ SocketStream(const SocketStream&) = delete;
+ SocketStream& operator=(const SocketStream&) = delete;
+
~SocketStream() override { Close(); }
// Listen to the port and return after a client is connected
+ //
+ // DEPRECATED: Use the ServerSocket class instead.
+ // TODO(b/271323032): Remove when this method is no longer used.
Status Serve(uint16_t port);
// Connect to a local or remote endpoint. Host must be an IPv4 address. If
@@ -46,6 +71,8 @@
int connection_fd() { return connection_fd_; }
private:
+ friend class ServerSocket;
+
static constexpr int kInvalidFd = -1;
Status DoWrite(span<const std::byte> data) override;
@@ -58,4 +85,39 @@
struct sockaddr_in sockaddr_client_ = {};
};
+/// `ServerSocket` wraps a POSIX-style server socket, producing a `SocketStream`
+/// for each accepted client connection.
+///
+/// Call `Listen` to create the socket and start listening for connections.
+/// Then call `Accept` any number of times to accept client connections.
+class ServerSocket {
+ public:
+ ServerSocket() = default;
+ ~ServerSocket() { Close(); }
+
+ ServerSocket(const ServerSocket& other) = delete;
+ ServerSocket& operator=(const ServerSocket& other) = delete;
+
+ // Listen for connections on the given port.
+ // If port is 0, a random unused port is chosen and can be retrieved with
+ // port().
+ Status Listen(uint16_t port = 0);
+
+ // Accept a connection. Blocks until after a client is connected.
+ // On success, returns a SocketStream connected to the new client.
+ Result<SocketStream> Accept();
+
+ // Close the server socket, preventing further connections.
+ void Close();
+
+ // Returns the port this socket is listening on.
+ uint16_t port() const { return port_; }
+
+ private:
+ static constexpr int kInvalidFd = -1;
+
+ uint16_t port_ = -1;
+ int socket_fd_ = kInvalidFd;
+};
+
} // namespace pw::stream
diff --git a/pw_stream/socket_stream.cc b/pw_stream/socket_stream.cc
index 564857b..6dfec6c 100644
--- a/pw_stream/socket_stream.cc
+++ b/pw_stream/socket_stream.cc
@@ -24,7 +24,7 @@
namespace pw::stream {
namespace {
-constexpr uint32_t kMaxConcurrentUser = 1;
+constexpr uint32_t kServerBacklogLength = 1;
constexpr const char* kLocalhostAddress = "127.0.0.1";
} // namespace
@@ -65,7 +65,7 @@
return Status::Unknown();
}
- if (listen(socket_fd_, kMaxConcurrentUser) < 0) {
+ if (listen(socket_fd_, kServerBacklogLength) < 0) {
PW_LOG_ERROR("Failed to listen to socket: %s", std::strerror(errno));
return Status::Unknown();
}
@@ -156,4 +156,73 @@
return StatusWithSize(bytes_rcvd);
}
+// Listen for connections on the given port.
+// If port is 0, a random unused port is chosen and can be retrieved with
+// port().
+Status ServerSocket::Listen(uint16_t port) {
+ socket_fd_ = socket(AF_INET, SOCK_STREAM, 0);
+ if (socket_fd_ == kInvalidFd) {
+ return Status::Unknown();
+ }
+
+ if (port != 0) {
+ // When specifying an explicit port, allow binding to an address that may
+ // still be in use by a closed socket.
+ constexpr int value = 1;
+ setsockopt(socket_fd_, SOL_SOCKET, SO_REUSEADDR, &value, sizeof(int));
+
+ struct sockaddr_in addr = {};
+ socklen_t addr_len = sizeof(addr);
+ addr.sin_family = AF_INET;
+ addr.sin_port = htons(port);
+ addr.sin_addr.s_addr = INADDR_ANY;
+ if (bind(socket_fd_, reinterpret_cast<sockaddr*>(&addr), addr_len) < 0) {
+ return Status::Unknown();
+ }
+ }
+
+ if (listen(socket_fd_, kServerBacklogLength) < 0) {
+ return Status::Unknown();
+ }
+
+ // Find out which port the socket is listening on, and fill in port_.
+ struct sockaddr_in addr = {};
+ socklen_t addr_len = sizeof(addr);
+ if (getsockname(socket_fd_, reinterpret_cast<sockaddr*>(&addr), &addr_len) <
+ 0 ||
+ addr_len > sizeof(addr)) {
+ close(socket_fd_);
+ return Status::Unknown();
+ }
+
+ port_ = ntohs(addr.sin_port);
+
+ return OkStatus();
+}
+
+// Accept a connection. Blocks until after a client is connected.
+// On success, returns a SocketStream connected to the new client.
+Result<SocketStream> ServerSocket::Accept() {
+ struct sockaddr_in sockaddr_client_ = {};
+ socklen_t len = sizeof(sockaddr_client_);
+
+ int connection_fd =
+ accept(socket_fd_, reinterpret_cast<sockaddr*>(&sockaddr_client_), &len);
+ if (connection_fd == kInvalidFd) {
+ return Status::Unknown();
+ }
+
+ SocketStream client_stream;
+ client_stream.connection_fd_ = connection_fd;
+ return client_stream;
+}
+
+// Close the server socket, preventing further connections.
+void ServerSocket::Close() {
+ if (socket_fd_ != kInvalidFd) {
+ close(socket_fd_);
+ socket_fd_ = kInvalidFd;
+ }
+}
+
} // namespace pw::stream
diff --git a/pw_stream/socket_stream_test.cc b/pw_stream/socket_stream_test.cc
new file mode 100644
index 0000000..283bae4
--- /dev/null
+++ b/pw_stream/socket_stream_test.cc
@@ -0,0 +1,166 @@
+// Copyright 2023 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+
+#include "pw_stream/socket_stream.h"
+
+#include <thread>
+
+#include "gtest/gtest.h"
+#include "pw_result/result.h"
+#include "pw_status/status.h"
+
+namespace pw::stream {
+namespace {
+
+TEST(SocketStreamTest, Connect) {
+ ServerSocket server;
+ EXPECT_EQ(server.Listen(), OkStatus());
+
+ Result<SocketStream> server_stream = Status::Unavailable();
+ auto accept_thread = std::thread{[&]() { server_stream = server.Accept(); }};
+
+ SocketStream client;
+ EXPECT_EQ(client.Connect("localhost", server.port()), OkStatus());
+
+ accept_thread.join();
+ EXPECT_EQ(server_stream.status(), OkStatus());
+
+ server_stream.value().Close();
+ server.Close();
+ client.Close();
+}
+
+TEST(SocketStreamTest, ConnectSpecificPort) {
+ // We want to test the "listen on a specific port" functionality,
+ // but hard-coding a port number in a test is inherently problematic, as
+ // port numbers are a global resource.
+ //
+ // We use the automatic port assignment initially to get a port assignment,
+ // close that server, and then use that port explicitly in a new server.
+ //
+ // There's still the possibility that the port will get swiped, but it
+ // shouldn't happen by chance.
+ ServerSocket initial_server;
+ EXPECT_EQ(initial_server.Listen(), OkStatus());
+ uint16_t port = initial_server.port();
+ initial_server.Close();
+
+ ServerSocket server;
+ EXPECT_EQ(server.Listen(port), OkStatus());
+ EXPECT_EQ(server.port(), port);
+
+ Result<SocketStream> server_stream = Status::Unavailable();
+ auto accept_thread = std::thread{[&]() { server_stream = server.Accept(); }};
+
+ SocketStream client;
+ EXPECT_EQ(client.Connect("localhost", server.port()), OkStatus());
+
+ accept_thread.join();
+ EXPECT_EQ(server_stream.status(), OkStatus());
+
+ server_stream.value().Close();
+ server.Close();
+ client.Close();
+}
+
+// Helper function to test exchanging data on a pair of sockets.
+void ExchangeData(SocketStream& stream1, SocketStream& stream2) {
+ auto kPayload1 = as_bytes(span("some data"));
+ auto kPayload2 = as_bytes(span("other bytes"));
+ std::array<char, 100> read_buffer{};
+
+ // Write data from stream1 and read it from stream2.
+ auto write_status = Status::Unavailable();
+ auto write_thread =
+ std::thread{[&]() { write_status = stream1.Write(kPayload1); }};
+ Result<ByteSpan> read_result =
+ stream2.Read(as_writable_bytes(span(read_buffer)));
+ EXPECT_EQ(read_result.status(), OkStatus());
+ EXPECT_EQ(read_result.value().size(), kPayload1.size());
+ EXPECT_TRUE(
+ std::equal(kPayload1.begin(), kPayload1.end(), read_result->begin()));
+
+ write_thread.join();
+ EXPECT_EQ(write_status, OkStatus());
+
+ // Read data in the client and write it from the server.
+ auto read_thread = std::thread{[&]() {
+ read_result = stream1.Read(as_writable_bytes(span(read_buffer)));
+ }};
+ EXPECT_EQ(stream2.Write(kPayload2), OkStatus());
+
+ read_thread.join();
+ EXPECT_EQ(read_result.status(), OkStatus());
+ EXPECT_EQ(read_result.value().size(), kPayload2.size());
+ EXPECT_TRUE(
+ std::equal(kPayload2.begin(), kPayload2.end(), read_result->begin()));
+
+ // Close stream1 and attempt to read from stream2.
+ stream1.Close();
+ read_result = stream2.Read(as_writable_bytes(span(read_buffer)));
+ EXPECT_EQ(read_result.status(), Status::OutOfRange());
+
+ stream2.Close();
+}
+
+TEST(SocketStreamTest, ReadWrite) {
+ ServerSocket server;
+ EXPECT_EQ(server.Listen(), OkStatus());
+
+ Result<SocketStream> server_stream = Status::Unavailable();
+ auto accept_thread = std::thread{[&]() { server_stream = server.Accept(); }};
+
+ SocketStream client;
+ EXPECT_EQ(client.Connect("localhost", server.port()), OkStatus());
+
+ accept_thread.join();
+ EXPECT_EQ(server_stream.status(), OkStatus());
+
+ ExchangeData(client, server_stream.value());
+ server.Close();
+}
+
+TEST(SocketStreamTest, MultipleClients) {
+ ServerSocket server;
+ EXPECT_EQ(server.Listen(), OkStatus());
+
+ Result<SocketStream> server_stream1 = Status::Unavailable();
+ Result<SocketStream> server_stream2 = Status::Unavailable();
+ Result<SocketStream> server_stream3 = Status::Unavailable();
+ auto accept_thread = std::thread{[&]() {
+ server_stream1 = server.Accept();
+ server_stream2 = server.Accept();
+ server_stream3 = server.Accept();
+ }};
+
+ SocketStream client1;
+ SocketStream client2;
+ SocketStream client3;
+ EXPECT_EQ(client1.Connect("localhost", server.port()), OkStatus());
+ EXPECT_EQ(client2.Connect("localhost", server.port()), OkStatus());
+ EXPECT_EQ(client3.Connect("localhost", server.port()), OkStatus());
+
+ accept_thread.join();
+ EXPECT_EQ(server_stream1.status(), OkStatus());
+ EXPECT_EQ(server_stream2.status(), OkStatus());
+ EXPECT_EQ(server_stream3.status(), OkStatus());
+
+ ExchangeData(client1, server_stream1.value());
+ ExchangeData(client2, server_stream2.value());
+ ExchangeData(client3, server_stream3.value());
+ server.Close();
+}
+
+} // namespace
+} // namespace pw::stream