blob: f4e5a9f5d73435230eccd346769ac1f3fe165ae8 [file] [log] [blame]
# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for pw_console.socket_client"""
import socket
import unittest
import unittest.mock

from pw_console import socket_client
class TestSocketClient(unittest.TestCase):
    """Tests for how SocketClient parses its config string.

    Every test patches ``SocketClient.connect`` so that constructing a client
    cannot open a real connection. Only the parsed socket arguments and
    address are inspected.
    """

    def _make_client(self, config: str) -> 'socket_client.SocketClient':
        """Builds a SocketClient from ``config`` with ``connect`` patched out.

        Args:
            config: The config string handed to the SocketClient constructor.

        Returns:
            The constructed client.
        """
        with unittest.mock.patch.object(
            socket_client.SocketClient, 'connect', return_value=None
        ):
            return socket_client.SocketClient(config)

    def _check_config_parsing(
        self, config: str, expected_address: str, expected_port: int
    ) -> None:
        """Checks the host and port that ``config`` is resolved with.

        ``socket.getaddrinfo`` is replaced by a mock, so no name resolution
        takes place.

        Args:
            config: The config string handed to the SocketClient constructor.
            expected_address: Host expected to be passed to ``getaddrinfo``.
            expected_port: Port expected to be passed to ``getaddrinfo``.
        """
        fake_getaddrinfo_return_value = [
            (socket.AF_INET6, socket.SOCK_STREAM, 0, None, None)
        ]
        with unittest.mock.patch.object(
            socket,
            'getaddrinfo',
            return_value=fake_getaddrinfo_return_value,
        ) as mock_getaddrinfo:
            client = self._make_client(config)

        # The mock keeps its call record after the patch is undone.
        mock_getaddrinfo.assert_called_with(
            expected_address, expected_port, type=socket.SOCK_STREAM
        )
        # Assert the init args are what is returned by ``getaddrinfo``
        # not necessarily the correct ones, since this test should not
        # perform any network action.
        self.assertEqual(
            client._socket_init_args,  # pylint: disable=protected-access
            (
                socket.AF_INET6,
                socket.SOCK_STREAM,
            ),
        )

    def test_parse_config_default(self) -> None:
        """The "default" config selects the default server and port."""
        client = self._make_client("default")
        self.assertEqual(
            client._socket_init_args,  # pylint: disable=protected-access
            (socket.AF_INET6, socket.SOCK_STREAM),
        )
        self.assertEqual(
            client._address,  # pylint: disable=protected-access
            (
                socket_client.SocketClient.DEFAULT_SOCKET_SERVER,
                socket_client.SocketClient.DEFAULT_SOCKET_PORT,
            ),
        )

    # Report the test as skipped, rather than silently passing, on platforms
    # where the socket module does not provide AF_UNIX.
    @unittest.skipUnless(
        hasattr(socket, 'AF_UNIX'), 'UNIX sockets are not supported'
    )
    def test_parse_config_unix_file(self) -> None:
        """A ``file:`` config selects a UNIX socket at the given path."""
        client = self._make_client('file:fake_file_path')
        self.assertEqual(
            client._socket_init_args,  # pylint: disable=protected-access
            (
                socket.AF_UNIX,  # pylint: disable=no-member
                socket.SOCK_STREAM,
            ),
        )
        self.assertEqual(
            client._address,  # pylint: disable=protected-access
            'fake_file_path',
        )

    def test_parse_config_ipv4_domain(self) -> None:
        """An unbracketed domain with an explicit port."""
        self._check_config_parsing(
            config='file.com/some_long/path:80',
            expected_address='file.com/some_long/path',
            expected_port=80,
        )

    def test_parse_config_ipv4_domain_no_port(self) -> None:
        """An unbracketed domain without a port uses the default port."""
        self._check_config_parsing(
            config='file.com/some/path',
            expected_address='file.com/some/path',
            expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
        )

    def test_parse_config_ipv4_address(self) -> None:
        """An IPv4 address with an explicit port."""
        self._check_config_parsing(
            config='8.8.8.8:8080',
            expected_address='8.8.8.8',
            expected_port=8080,
        )

    def test_parse_config_ipv4_address_no_port(self) -> None:
        """An IPv4 address without a port uses the default port."""
        self._check_config_parsing(
            config='8.8.8.8',
            expected_address='8.8.8.8',
            expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
        )

    def test_parse_config_ipv6_domain(self) -> None:
        """A bracketed domain with an explicit port."""
        self._check_config_parsing(
            config='[file.com/some_long/path]:80',
            expected_address='file.com/some_long/path',
            expected_port=80,
        )

    def test_parse_config_ipv6_domain_no_port(self) -> None:
        """A bracketed domain without a port uses the default port."""
        self._check_config_parsing(
            config='[file.com/some/path]',
            expected_address='file.com/some/path',
            expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
        )

    def test_parse_config_ipv6_address(self) -> None:
        """A bracketed IPv6 address with an explicit port."""
        self._check_config_parsing(
            config='[2001:4860:4860::8888:8080]:666',
            expected_address='2001:4860:4860::8888:8080',
            expected_port=666,
        )

    def test_parse_config_ipv6_address_no_port(self) -> None:
        """A bracketed IPv6 address without a port uses the default port."""
        self._check_config_parsing(
            config='[2001:4860:4860::8844]',
            expected_address='2001:4860:4860::8844',
            expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
        )

    def test_parse_config_ipv6_local(self) -> None:
        """A bracketed IPv6 address with a ``%eth0`` suffix and a port."""
        self._check_config_parsing(
            config='[fe80::100%eth0]:80',
            expected_address='fe80::100%eth0',
            expected_port=80,
        )

    def test_parse_config_ipv6_local_no_port(self) -> None:
        """A bracketed IPv6 address with a ``%eth0`` suffix and no port."""
        self._check_config_parsing(
            config='[fe80::100%eth0]',
            expected_address='fe80::100%eth0',
            expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
        )

    def test_parse_config_ipv6_local_windows(self) -> None:
        """A bracketed IPv6 address with a numeric ``%4`` suffix and a port."""
        self._check_config_parsing(
            config='[fe80::100%4]:80',
            expected_address='fe80::100%4',
            expected_port=80,
        )

    def test_parse_config_ipv6_local_no_port_windows(self) -> None:
        """A bracketed IPv6 address with a numeric ``%4`` suffix and no port."""
        self._check_config_parsing(
            config='[fe80::100%4]',
            expected_address='fe80::100%4',
            expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
        )
# Run every test in this module when the file is executed directly as a
# script; importing the module does not start the test runner.
if __name__ == '__main__':
    unittest.main()