pw_stream: Create ServerSocket class

This commit adds a new ServerSocket class that aims to improve on the
shortcomings of SocketStream (when used as a server). The class:
1) Supports multiple clients
2) Splits the Listen and Accept operations (the old Serve method was a
   source of race conditions)
3) Allows for random port assignment

A user will now create a single ServerSocket, call `Listen` once, and
`Accept` as many times as needed, each call creates a new SocketStream
for communicating with that client.

Due to the improvements in ServerSocket, it and SocketStream can now be
tested reliably, so tests have been added.

Bug: b/271321951
Change-Id: Ic5cd42eb0e36169ed18891cc06d0704867cb0254
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/131611
Commit-Queue: Eli Lipsitz <elipsitz@google.com>
Reviewed-by: Wyatt Hepler <hepler@google.com>
diff --git a/pw_stream/BUILD.bazel b/pw_stream/BUILD.bazel
index c4d910a..edd2f4b 100644
--- a/pw_stream/BUILD.bazel
+++ b/pw_stream/BUILD.bazel
@@ -143,3 +143,12 @@
         "//pw_unit_test",
     ],
 )
+
+pw_cc_test(
+    name = "socket_stream_test",
+    srcs = ["socket_stream_test.cc"],
+    deps = [
+        ":socket_stream",
+        "//pw_unit_test",
+    ],
+)
diff --git a/pw_stream/BUILD.gn b/pw_stream/BUILD.gn
index 1ac95a3..5f9e27b 100644
--- a/pw_stream/BUILD.gn
+++ b/pw_stream/BUILD.gn
@@ -95,6 +95,11 @@
   if (defined(pw_toolchain_SCOPE.is_host_toolchain) &&
       pw_toolchain_SCOPE.is_host_toolchain) {
     tests += [ ":std_file_stream_test" ]
+
+    # socket_stream_test only compiles on Linux.
+    if (host_os == "linux") {
+      tests += [ ":socket_stream_test" ]
+    }
   }
 }
 
@@ -136,3 +141,8 @@
   sources = [ "interval_reader_test.cc" ]
   deps = [ ":interval_reader" ]
 }
+
+pw_test("socket_stream_test") {
+  sources = [ "socket_stream_test.cc" ]
+  deps = [ ":socket_stream" ]
+}
diff --git a/pw_stream/docs.rst b/pw_stream/docs.rst
index dd4fef1..c4484ae 100644
--- a/pw_stream/docs.rst
+++ b/pw_stream/docs.rst
@@ -426,8 +426,14 @@
 
 .. cpp:class:: SocketStream : public NonSeekableReaderWriter
 
-  ``SocketStream`` wraps posix-style sockets with the :cpp:class:`Reader` and
-  :cpp:class:`Writer` interfaces.
+  ``SocketStream`` wraps posix-style TCP sockets with the :cpp:class:`Reader`
+  and :cpp:class:`Writer` interfaces. It can be used to connect to a TCP server,
+  or to communicate with a client via the ``ServerSocket`` class.
+
+.. cpp:class:: ServerSocket
+
+  ``ServerSocket`` wraps a posix server socket, and produces a
+  :cpp:class:`SocketStream` for each accepted client connection.
 
 ------------------
 Why use pw_stream?
diff --git a/pw_stream/public/pw_stream/socket_stream.h b/pw_stream/public/pw_stream/socket_stream.h
index d757822..d6a8f3a 100644
--- a/pw_stream/public/pw_stream/socket_stream.h
+++ b/pw_stream/public/pw_stream/socket_stream.h
@@ -17,6 +17,7 @@
 
 #include <cstdint>
 
+#include "pw_result/result.h"
 #include "pw_span/span.h"
 #include "pw_stream/stream.h"
 
@@ -26,9 +27,33 @@
  public:
   constexpr SocketStream() = default;
 
+  // SocketStream objects are moveable but not copyable.
+  SocketStream& operator=(SocketStream&& other) {
+    listen_port_ = other.listen_port_;
+    socket_fd_ = other.socket_fd_;
+    other.socket_fd_ = kInvalidFd;
+    connection_fd_ = other.connection_fd_;
+    other.connection_fd_ = kInvalidFd;
+    sockaddr_client_ = other.sockaddr_client_;
+    return *this;
+  }
+  SocketStream(SocketStream&& other) noexcept
+      : listen_port_(other.listen_port_),
+        socket_fd_(other.socket_fd_),
+        connection_fd_(other.connection_fd_),
+        sockaddr_client_(other.sockaddr_client_) {
+    other.socket_fd_ = kInvalidFd;
+    other.connection_fd_ = kInvalidFd;
+  }
+  SocketStream(const SocketStream&) = delete;
+  SocketStream& operator=(const SocketStream&) = delete;
+
   ~SocketStream() override { Close(); }
 
   // Listen to the port and return after a client is connected
+  //
+  // DEPRECATED: Use the ServerSocket class instead.
+  // TODO(b/271323032): Remove when this method is no longer used.
   Status Serve(uint16_t port);
 
   // Connect to a local or remote endpoint. Host must be an IPv4 address. If
@@ -46,6 +71,8 @@
   int connection_fd() { return connection_fd_; }
 
  private:
+  friend class ServerSocket;
+
   static constexpr int kInvalidFd = -1;
 
   Status DoWrite(span<const std::byte> data) override;
@@ -58,4 +85,39 @@
   struct sockaddr_in sockaddr_client_ = {};
 };
 
+/// `ServerSocket` wraps a POSIX-style server socket, producing a `SocketStream`
+/// for each accepted client connection.
+///
+/// Call `Listen` to create the socket and start listening for connections.
+/// Then call `Accept` any number of times to accept client connections.
+class ServerSocket {
+ public:
+  ServerSocket() = default;
+  ~ServerSocket() { Close(); }
+
+  ServerSocket(const ServerSocket& other) = delete;
+  ServerSocket& operator=(const ServerSocket& other) = delete;
+
+  // Listen for connections on the given port.
+  // If port is 0, a random unused port is chosen and can be retrieved with
+  // port().
+  Status Listen(uint16_t port = 0);
+
+  // Accept a connection. Blocks until after a client is connected.
+  // On success, returns a SocketStream connected to the new client.
+  Result<SocketStream> Accept();
+
+  // Close the server socket, preventing further connections.
+  void Close();
+
+  // Returns the port this socket is listening on.
+  uint16_t port() const { return port_; }
+
+ private:
+  static constexpr int kInvalidFd = -1;
+
+  uint16_t port_ = -1;
+  int socket_fd_ = kInvalidFd;
+};
+
 }  // namespace pw::stream
diff --git a/pw_stream/socket_stream.cc b/pw_stream/socket_stream.cc
index 564857b..6dfec6c 100644
--- a/pw_stream/socket_stream.cc
+++ b/pw_stream/socket_stream.cc
@@ -24,7 +24,7 @@
 namespace pw::stream {
 namespace {
 
-constexpr uint32_t kMaxConcurrentUser = 1;
+constexpr uint32_t kServerBacklogLength = 1;
 constexpr const char* kLocalhostAddress = "127.0.0.1";
 
 }  // namespace
@@ -65,7 +65,7 @@
     return Status::Unknown();
   }
 
-  if (listen(socket_fd_, kMaxConcurrentUser) < 0) {
+  if (listen(socket_fd_, kServerBacklogLength) < 0) {
     PW_LOG_ERROR("Failed to listen to socket: %s", std::strerror(errno));
     return Status::Unknown();
   }
@@ -156,4 +156,73 @@
   return StatusWithSize(bytes_rcvd);
 }
 
+// Listen for connections on the given port.
+// If port is 0, a random unused port is chosen and can be retrieved with
+// port().
+Status ServerSocket::Listen(uint16_t port) {
+  socket_fd_ = socket(AF_INET, SOCK_STREAM, 0);
+  if (socket_fd_ == kInvalidFd) {
+    return Status::Unknown();
+  }
+
+  if (port != 0) {
+    // When specifying an explicit port, allow binding to an address that may
+    // still be in use by a closed socket.
+    constexpr int value = 1;
+    setsockopt(socket_fd_, SOL_SOCKET, SO_REUSEADDR, &value, sizeof(int));
+
+    struct sockaddr_in addr = {};
+    socklen_t addr_len = sizeof(addr);
+    addr.sin_family = AF_INET;
+    addr.sin_port = htons(port);
+    addr.sin_addr.s_addr = INADDR_ANY;
+    if (bind(socket_fd_, reinterpret_cast<sockaddr*>(&addr), addr_len) < 0) {
+      return Status::Unknown();
+    }
+  }
+
+  if (listen(socket_fd_, kServerBacklogLength) < 0) {
+    return Status::Unknown();
+  }
+
+  // Find out which port the socket is listening on, and fill in port_.
+  struct sockaddr_in addr = {};
+  socklen_t addr_len = sizeof(addr);
+  if (getsockname(socket_fd_, reinterpret_cast<sockaddr*>(&addr), &addr_len) <
+          0 ||
+      addr_len > sizeof(addr)) {
+    close(socket_fd_);
+    return Status::Unknown();
+  }
+
+  port_ = ntohs(addr.sin_port);
+
+  return OkStatus();
+}
+
+// Accept a connection. Blocks until after a client is connected.
+// On success, returns a SocketStream connected to the new client.
+Result<SocketStream> ServerSocket::Accept() {
+  struct sockaddr_in sockaddr_client_ = {};
+  socklen_t len = sizeof(sockaddr_client_);
+
+  int connection_fd =
+      accept(socket_fd_, reinterpret_cast<sockaddr*>(&sockaddr_client_), &len);
+  if (connection_fd == kInvalidFd) {
+    return Status::Unknown();
+  }
+
+  SocketStream client_stream;
+  client_stream.connection_fd_ = connection_fd;
+  return client_stream;
+}
+
+// Close the server socket, preventing further connections.
+void ServerSocket::Close() {
+  if (socket_fd_ != kInvalidFd) {
+    close(socket_fd_);
+    socket_fd_ = kInvalidFd;
+  }
+}
+
 }  // namespace pw::stream
diff --git a/pw_stream/socket_stream_test.cc b/pw_stream/socket_stream_test.cc
new file mode 100644
index 0000000..283bae4
--- /dev/null
+++ b/pw_stream/socket_stream_test.cc
@@ -0,0 +1,166 @@
+// Copyright 2023 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+
+#include "pw_stream/socket_stream.h"
+
+#include <thread>
+
+#include "gtest/gtest.h"
+#include "pw_result/result.h"
+#include "pw_status/status.h"
+
+namespace pw::stream {
+namespace {
+
+TEST(SocketStreamTest, Connect) {
+  ServerSocket server;
+  EXPECT_EQ(server.Listen(), OkStatus());
+
+  Result<SocketStream> server_stream = Status::Unavailable();
+  auto accept_thread = std::thread{[&]() { server_stream = server.Accept(); }};
+
+  SocketStream client;
+  EXPECT_EQ(client.Connect("localhost", server.port()), OkStatus());
+
+  accept_thread.join();
+  EXPECT_EQ(server_stream.status(), OkStatus());
+
+  server_stream.value().Close();
+  server.Close();
+  client.Close();
+}
+
+TEST(SocketStreamTest, ConnectSpecificPort) {
+  // We want to test the "listen on a specific port" functionality,
+  // but hard-coding a port number in a test is inherently problematic, as
+  // port numbers are a global resource.
+  //
+  // We use the automatic port assignment initially to get a port assignment,
+  // close that server, and then use that port explicitly in a new server.
+  //
+  // There's still the possibility that the port will get swiped, but it
+  // shouldn't happen by chance.
+  ServerSocket initial_server;
+  EXPECT_EQ(initial_server.Listen(), OkStatus());
+  uint16_t port = initial_server.port();
+  initial_server.Close();
+
+  ServerSocket server;
+  EXPECT_EQ(server.Listen(port), OkStatus());
+  EXPECT_EQ(server.port(), port);
+
+  Result<SocketStream> server_stream = Status::Unavailable();
+  auto accept_thread = std::thread{[&]() { server_stream = server.Accept(); }};
+
+  SocketStream client;
+  EXPECT_EQ(client.Connect("localhost", server.port()), OkStatus());
+
+  accept_thread.join();
+  EXPECT_EQ(server_stream.status(), OkStatus());
+
+  server_stream.value().Close();
+  server.Close();
+  client.Close();
+}
+
+// Helper function to test exchanging data on a pair of sockets.
+void ExchangeData(SocketStream& stream1, SocketStream& stream2) {
+  auto kPayload1 = as_bytes(span("some data"));
+  auto kPayload2 = as_bytes(span("other bytes"));
+  std::array<char, 100> read_buffer{};
+
+  // Write data from stream1 and read it from stream2.
+  auto write_status = Status::Unavailable();
+  auto write_thread =
+      std::thread{[&]() { write_status = stream1.Write(kPayload1); }};
+  Result<ByteSpan> read_result =
+      stream2.Read(as_writable_bytes(span(read_buffer)));
+  EXPECT_EQ(read_result.status(), OkStatus());
+  EXPECT_EQ(read_result.value().size(), kPayload1.size());
+  EXPECT_TRUE(
+      std::equal(kPayload1.begin(), kPayload1.end(), read_result->begin()));
+
+  write_thread.join();
+  EXPECT_EQ(write_status, OkStatus());
+
+  // Read data in the client and write it from the server.
+  auto read_thread = std::thread{[&]() {
+    read_result = stream1.Read(as_writable_bytes(span(read_buffer)));
+  }};
+  EXPECT_EQ(stream2.Write(kPayload2), OkStatus());
+
+  read_thread.join();
+  EXPECT_EQ(read_result.status(), OkStatus());
+  EXPECT_EQ(read_result.value().size(), kPayload2.size());
+  EXPECT_TRUE(
+      std::equal(kPayload2.begin(), kPayload2.end(), read_result->begin()));
+
+  // Close stream1 and attempt to read from stream2.
+  stream1.Close();
+  read_result = stream2.Read(as_writable_bytes(span(read_buffer)));
+  EXPECT_EQ(read_result.status(), Status::OutOfRange());
+
+  stream2.Close();
+}
+
+TEST(SocketStreamTest, ReadWrite) {
+  ServerSocket server;
+  EXPECT_EQ(server.Listen(), OkStatus());
+
+  Result<SocketStream> server_stream = Status::Unavailable();
+  auto accept_thread = std::thread{[&]() { server_stream = server.Accept(); }};
+
+  SocketStream client;
+  EXPECT_EQ(client.Connect("localhost", server.port()), OkStatus());
+
+  accept_thread.join();
+  EXPECT_EQ(server_stream.status(), OkStatus());
+
+  ExchangeData(client, server_stream.value());
+  server.Close();
+}
+
+TEST(SocketStreamTest, MultipleClients) {
+  ServerSocket server;
+  EXPECT_EQ(server.Listen(), OkStatus());
+
+  Result<SocketStream> server_stream1 = Status::Unavailable();
+  Result<SocketStream> server_stream2 = Status::Unavailable();
+  Result<SocketStream> server_stream3 = Status::Unavailable();
+  auto accept_thread = std::thread{[&]() {
+    server_stream1 = server.Accept();
+    server_stream2 = server.Accept();
+    server_stream3 = server.Accept();
+  }};
+
+  SocketStream client1;
+  SocketStream client2;
+  SocketStream client3;
+  EXPECT_EQ(client1.Connect("localhost", server.port()), OkStatus());
+  EXPECT_EQ(client2.Connect("localhost", server.port()), OkStatus());
+  EXPECT_EQ(client3.Connect("localhost", server.port()), OkStatus());
+
+  accept_thread.join();
+  EXPECT_EQ(server_stream1.status(), OkStatus());
+  EXPECT_EQ(server_stream2.status(), OkStatus());
+  EXPECT_EQ(server_stream3.status(), OkStatus());
+
+  ExchangeData(client1, server_stream1.value());
+  ExchangeData(client2, server_stream2.value());
+  ExchangeData(client3, server_stream3.value());
+  server.Close();
+}
+
+}  // namespace
+}  // namespace pw::stream