pw_console: Improve SocketClient addressing

Update SocketClient's addressing support to handle both ipv6 and ipv4
in addition to unix sockets.

Test: Successfully connected to localhost and unix socket.

Change-Id: I0330394dcb998db9822cd2a0fd654bc7d60cd6a4
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/178921
Reviewed-by: Carlos Chinchilla <cachinchilla@google.com>
Commit-Queue: Carlos Chinchilla <cachinchilla@google.com>
Pigweed-Auto-Submit: Carlos Chinchilla <cachinchilla@google.com>
Reviewed-by: Taylor Cramer <cramertj@google.com>
Reviewed-by: Tom Craig <tommycraig@gmail.com>
diff --git a/pw_console/embedding.rst b/pw_console/embedding.rst
index a5d896c..b2263af 100644
--- a/pw_console/embedding.rst
+++ b/pw_console/embedding.rst
@@ -113,7 +113,16 @@
 
    from pw_console.socket_client import SocketClientWithLogging
 
+   # Name resolution with explicit port
    serial_device = SocketClientWithLogging('localhost:1234')
+   # Name resolution with default port.
+   serial_device = SocketClientWithLogging('pigweed.dev')
+   # Link-local IPv6 address with explicit port.
+   serial_device = SocketClientWithLogging('[fe80::100%enp1s0]:1234')
+   # Link-local IPv6 address with default port.
+   serial_device = SocketClientWithLogging('[fe80::100%enp1s0]')
+   # IPv4 address with port.
+   serial_device = SocketClientWithLogging('1.2.3.4:5678')
 
 .. tip::
    The ``SocketClient`` takes an optional callback called when a disconnect is
diff --git a/pw_console/py/BUILD.bazel b/pw_console/py/BUILD.bazel
index 798b335..9fdeebd 100644
--- a/pw_console/py/BUILD.bazel
+++ b/pw_console/py/BUILD.bazel
@@ -201,6 +201,17 @@
 )
 
 py_test(
+    name = "socket_client_test",
+    size = "small",
+    srcs = [
+        "socket_client_test.py",
+    ],
+    deps = [
+        ":pw_console",
+    ],
+)
+
+py_test(
     name = "repl_pane_test",
     size = "small",
     srcs = [
diff --git a/pw_console/py/BUILD.gn b/pw_console/py/BUILD.gn
index 2cc6d18..3683004 100644
--- a/pw_console/py/BUILD.gn
+++ b/pw_console/py/BUILD.gn
@@ -87,6 +87,7 @@
     "log_store_test.py",
     "log_view_test.py",
     "repl_pane_test.py",
+    "socket_client_test.py",
     "table_test.py",
     "text_formatting_test.py",
     "window_manager_test.py",
diff --git a/pw_console/py/pw_console/socket_client.py b/pw_console/py/pw_console/socket_client.py
index 41a1c66..5344c31 100644
--- a/pw_console/py/pw_console/socket_client.py
+++ b/pw_console/py/pw_console/socket_client.py
@@ -17,6 +17,7 @@
 from typing import Callable, Optional, TYPE_CHECKING, Tuple, Union
 
 import errno
+import re
 import socket
 
 from pw_console.plugins.bandwidth_toolbar import SerialBandwidthTracker
@@ -33,43 +34,125 @@
     DEFAULT_SOCKET_PORT = 33000
     PW_RPC_MAX_PACKET_SIZE = 256
 
+    _InitArgsType = Tuple[
+        socket.AddressFamily, int  # pylint: disable=no-member
+    ]
+    # Can be a string, (address, port) for AF_INET or (address, port, flowinfo,
+    # scope_id) AF_INET6.
+    _AddressType = Union[str, Tuple[str, int], Tuple[str, int, int, int]]
+
     def __init__(
         self,
         config: str,
         on_disconnect: Optional[Callable[[SocketClient], None]] = None,
     ):
-        self._connection_type: int
-        self._interface: Union[str, Tuple[str, int]]
-        if config == 'default':
-            self._connection_type = socket.AF_INET6
-            self._interface = (
-                self.DEFAULT_SOCKET_SERVER,
-                self.DEFAULT_SOCKET_PORT,
-            )
-        else:
-            socket_server, socket_port_or_file = config.split(':')
-            if socket_server == self.FILE_SOCKET_SERVER:
-                # Unix socket support is available on Windows 10 since April
-                # 2018. However, there is no Python support on Windows yet.
-                # See https://bugs.python.org/issue33408 for more information.
-                if not hasattr(socket, 'AF_UNIX'):
-                    raise TypeError(
-                        'Unix sockets are not supported in this environment.'
-                    )
-                self._connection_type = (
-                    socket.AF_UNIX  # pylint: disable=no-member
-                )
-                self._interface = socket_port_or_file
-            else:
-                self._connection_type = socket.AF_INET6
-                self._interface = (socket_server, int(socket_port_or_file))
+        """Creates a socket connection.
 
+        Args:
+          config: The socket configuration. Accepted values and formats are:
+            'default' - uses the default configuration (localhost:33000)
+            'address:port' - An IPv4 address and port.
+            'address' - An IPv4 address. Uses default port 33000.
+            '[address]:port' - An IPv6 address and port.
+            '[address]' - An IPv6 address. Uses default port 33000.
+            'file:path_to_file' - A Unix socket at ``path_to_file``.
+            In the formats above,``address`` can be an actual address or a name
+            that resolves to an address through name-resolution.
+          on_disconnect: An optional callback called when the socket
+            disconnects.
+
+        Raises:
+          TypeError: The type of socket is not supported.
+          ValueError: The socket configuration is invalid.
+        """
+        self.socket: socket.socket
+        (
+            self._socket_init_args,
+            self._address,
+        ) = SocketClient._parse_socket_config(config)
         self._on_disconnect = on_disconnect
         self._connected = False
         self.connect()
 
+    @staticmethod
+    def _parse_socket_config(
+        config: str,
+    ) -> Tuple[SocketClient._InitArgsType, SocketClient._AddressType]:
+        """Sets the variables used to create a socket given a config string.
+
+        Raises:
+          TypeError: The type of socket is not supported.
+          ValueError: The socket configuration is invalid.
+        """
+        init_args: SocketClient._InitArgsType
+        address: SocketClient._AddressType
+
+        # Check if this is using the default settings.
+        if config == 'default':
+            init_args = socket.AF_INET6, socket.SOCK_STREAM
+            address = (
+                SocketClient.DEFAULT_SOCKET_SERVER,
+                SocketClient.DEFAULT_SOCKET_PORT,
+            )
+            return init_args, address
+
+        # Check if this is a UNIX socket.
+        unix_socket_file_setting = f'{SocketClient.FILE_SOCKET_SERVER}:'
+        if config.startswith(unix_socket_file_setting):
+            # Unix socket support is available on Windows 10 since April
+            # 2018. However, there is no Python support on Windows yet.
+            # See https://bugs.python.org/issue33408 for more information.
+            if not hasattr(socket, 'AF_UNIX'):
+                raise TypeError(
+                    'Unix sockets are not supported in this environment.'
+                )
+            init_args = (
+                socket.AF_UNIX,  # pylint: disable=no-member
+                socket.SOCK_STREAM,
+            )
+            address = config[len(unix_socket_file_setting) :]
+            return init_args, address
+
+        # Search for IPv4 or IPv6 address or name and port.
+        # First, try to capture an IPv6 address as anything inside []. If there
+        # are no [] capture the IPv4 address. Lastly, capture the port as the
+        # numbers after :, if any.
+        match = re.match(
+            r'(\[(?P<ipv6_addr>.+)\]:?|(?P<ipv4_addr>[a-zA-Z0-9\._\/]+):?)'
+            r'(?P<port>[0-9]+)?',
+            config,
+        )
+        invalid_config_message = (
+            f'Invalid socket configuration "{config}"'
+            'Accepted values are "default", "file:<file_path>", '
+            '"<name_or_ipv4_address>" with optional ":<port>", and '
+            '"[<name_or_ipv6_address>]" with optional ":<port>".'
+        )
+        if match is None:
+            raise ValueError(invalid_config_message)
+
+        info = match.groupdict()
+        if info['port']:
+            port = int(info['port'])
+        else:
+            port = SocketClient.DEFAULT_SOCKET_PORT
+
+        if info['ipv4_addr']:
+            ip_addr = info['ipv4_addr']
+        elif info['ipv6_addr']:
+            ip_addr = info['ipv6_addr']
+        else:
+            raise ValueError(invalid_config_message)
+
+        sock_family, sock_type, _, _, address = socket.getaddrinfo(
+            ip_addr, port, type=socket.SOCK_STREAM
+        )[0]
+        init_args = sock_family, sock_type
+        return init_args, address
+
     def __del__(self):
-        self.socket.close()
+        if self._connected:
+            self.socket.close()
 
     def write(self, data: ReadableBuffer) -> None:
         """Writes data and detects disconnects."""
@@ -96,13 +179,13 @@
 
     def connect(self) -> None:
         """Connects to socket."""
-        self.socket = socket.socket(self._connection_type, socket.SOCK_STREAM)
+        self.socket = socket.socket(*self._socket_init_args)
 
         # Enable reusing address and port for reconnections.
         self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
         if hasattr(socket, 'SO_REUSEPORT'):
             self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
-        self.socket.connect(self._interface)
+        self.socket.connect(self._address)
         self._connected = True
 
     def _handle_disconnect(self):
diff --git a/pw_console/py/socket_client_test.py b/pw_console/py/socket_client_test.py
new file mode 100644
index 0000000..f4e5a9f
--- /dev/null
+++ b/pw_console/py/socket_client_test.py
@@ -0,0 +1,181 @@
+# Copyright 2023 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Tests for pw_console.socket_client"""
+
+import socket
+import unittest
+
+
+from pw_console import socket_client
+
+
+class TestSocketClient(unittest.TestCase):
+    """Tests for SocketClient."""
+
+    def test_parse_config_default(self) -> None:
+        config = "default"
+        with unittest.mock.patch.object(
+            socket_client.SocketClient, 'connect', return_value=None
+        ):
+            client = socket_client.SocketClient(config)
+            self.assertEqual(
+                client._socket_init_args,  # pylint: disable=protected-access
+                (socket.AF_INET6, socket.SOCK_STREAM),
+            )
+            self.assertEqual(
+                client._address,  # pylint: disable=protected-access
+                (
+                    socket_client.SocketClient.DEFAULT_SOCKET_SERVER,
+                    socket_client.SocketClient.DEFAULT_SOCKET_PORT,
+                ),
+            )
+
+    def test_parse_config_unix_file(self) -> None:
+        # Skip test if UNIX sockets are not supported.
+        if not hasattr(socket, 'AF_UNIX'):
+            return
+
+        config = 'file:fake_file_path'
+        with unittest.mock.patch.object(
+            socket_client.SocketClient, 'connect', return_value=None
+        ):
+            client = socket_client.SocketClient(config)
+            self.assertEqual(
+                client._socket_init_args,  # pylint: disable=protected-access
+                (
+                    socket.AF_UNIX,  # pylint: disable=no-member
+                    socket.SOCK_STREAM,
+                ),
+            )
+            self.assertEqual(
+                client._address,  # pylint: disable=protected-access
+                'fake_file_path',
+            )
+
+    def _check_config_parsing(
+        self, config: str, expected_address: str, expected_port: int
+    ) -> None:
+        with unittest.mock.patch.object(
+            socket_client.SocketClient, 'connect', return_value=None
+        ):
+            fake_getaddrinfo_return_value = [
+                (socket.AF_INET6, socket.SOCK_STREAM, 0, None, None)
+            ]
+            with unittest.mock.patch.object(
+                socket,
+                'getaddrinfo',
+                return_value=fake_getaddrinfo_return_value,
+            ) as mock_getaddrinfo:
+                client = socket_client.SocketClient(config)
+                mock_getaddrinfo.assert_called_with(
+                    expected_address, expected_port, type=socket.SOCK_STREAM
+                )
+                # Assert the init args are what is returned by ``getaddrinfo``
+                # not necessarily the correct ones, since this test should not
+                # perform any network action.
+                self.assertEqual(
+                    client._socket_init_args,  # pylint: disable=protected-access
+                    (
+                        socket.AF_INET6,
+                        socket.SOCK_STREAM,
+                    ),
+                )
+
+    def test_parse_config_ipv4_domain(self) -> None:
+        self._check_config_parsing(
+            config='file.com/some_long/path:80',
+            expected_address='file.com/some_long/path',
+            expected_port=80,
+        )
+
+    def test_parse_config_ipv4_domain_no_port(self) -> None:
+        self._check_config_parsing(
+            config='file.com/some/path',
+            expected_address='file.com/some/path',
+            expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
+        )
+
+    def test_parse_config_ipv4_address(self) -> None:
+        self._check_config_parsing(
+            config='8.8.8.8:8080',
+            expected_address='8.8.8.8',
+            expected_port=8080,
+        )
+
+    def test_parse_config_ipv4_address_no_port(self) -> None:
+        self._check_config_parsing(
+            config='8.8.8.8',
+            expected_address='8.8.8.8',
+            expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
+        )
+
+    def test_parse_config_ipv6_domain(self) -> None:
+        self._check_config_parsing(
+            config='[file.com/some_long/path]:80',
+            expected_address='file.com/some_long/path',
+            expected_port=80,
+        )
+
+    def test_parse_config_ipv6_domain_no_port(self) -> None:
+        self._check_config_parsing(
+            config='[file.com/some/path]',
+            expected_address='file.com/some/path',
+            expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
+        )
+
+    def test_parse_config_ipv6_address(self) -> None:
+        self._check_config_parsing(
+            config='[2001:4860:4860::8888:8080]:666',
+            expected_address='2001:4860:4860::8888:8080',
+            expected_port=666,
+        )
+
+    def test_parse_config_ipv6_address_no_port(self) -> None:
+        self._check_config_parsing(
+            config='[2001:4860:4860::8844]',
+            expected_address='2001:4860:4860::8844',
+            expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
+        )
+
+    def test_parse_config_ipv6_local(self) -> None:
+        self._check_config_parsing(
+            config='[fe80::100%eth0]:80',
+            expected_address='fe80::100%eth0',
+            expected_port=80,
+        )
+
+    def test_parse_config_ipv6_local_no_port(self) -> None:
+        self._check_config_parsing(
+            config='[fe80::100%eth0]',
+            expected_address='fe80::100%eth0',
+            expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
+        )
+
+    def test_parse_config_ipv6_local_windows(self) -> None:
+        self._check_config_parsing(
+            config='[fe80::100%4]:80',
+            expected_address='fe80::100%4',
+            expected_port=80,
+        )
+
+    def test_parse_config_ipv6_local_no_port_windows(self) -> None:
+        self._check_config_parsing(
+            config='[fe80::100%4]',
+            expected_address='fe80::100%4',
+            expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()