blob: 5928dabb6ef991157be3d2d8552f9ad0025f9031 [file] [log] [blame]
# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Wrapers for socket clients to log read and write data."""
from __future__ import annotations
from typing import Callable, TYPE_CHECKING
import errno
import re
import socket
from pw_console.plugins.bandwidth_toolbar import SerialBandwidthTracker
if TYPE_CHECKING:
from _typeshed import ReadableBuffer
class SocketClient:
"""Socket transport implementation."""
FILE_SOCKET_SERVER = 'file'
DEFAULT_SOCKET_SERVER = 'localhost'
DEFAULT_SOCKET_PORT = 33000
PW_RPC_MAX_PACKET_SIZE = 256
DEFAULT_TIMEOUT = 0.5
_InitArgsType = tuple[
socket.AddressFamily, int # pylint: disable=no-member
]
# Can be a string, (address, port) for AF_INET or (address, port, flowinfo,
# scope_id) AF_INET6.
_AddressType = str | tuple[str, int] | tuple[str, int, int, int]
def __init__(
self,
config: str,
on_disconnect: Callable[[SocketClient], None] | None = None,
timeout: float | None = None,
):
"""Creates a socket connection.
Args:
config: The socket configuration. Accepted values and formats are:
'default' - uses the default configuration (localhost:33000)
'address:port' - An IPv4 address and port.
'address' - An IPv4 address. Uses default port 33000.
'[address]:port' - An IPv6 address and port.
'[address]' - An IPv6 address. Uses default port 33000.
'file:path_to_file' - A Unix socket at ``path_to_file``.
In the formats above,``address`` can be an actual address or a name
that resolves to an address through name-resolution.
on_disconnect: An optional callback called when the socket
disconnects.
Raises:
TypeError: The type of socket is not supported.
ValueError: The socket configuration is invalid.
"""
self.socket: socket.socket
(
self._socket_init_args,
self._address,
) = SocketClient._parse_socket_config(config)
self._on_disconnect = on_disconnect
self._timeout = SocketClient.DEFAULT_TIMEOUT
if timeout:
self._timeout = timeout
self._connected = False
self.connect()
@staticmethod
def _parse_socket_config(
config: str,
) -> tuple[SocketClient._InitArgsType, SocketClient._AddressType]:
"""Sets the variables used to create a socket given a config string.
Raises:
TypeError: The type of socket is not supported.
ValueError: The socket configuration is invalid.
"""
init_args: SocketClient._InitArgsType
address: SocketClient._AddressType
# Check if this is using the default settings.
if config == 'default':
init_args = socket.AF_INET6, socket.SOCK_STREAM
address = (
SocketClient.DEFAULT_SOCKET_SERVER,
SocketClient.DEFAULT_SOCKET_PORT,
)
return init_args, address
# Check if this is a UNIX socket.
unix_socket_file_setting = f'{SocketClient.FILE_SOCKET_SERVER}:'
if config.startswith(unix_socket_file_setting):
# Unix socket support is available on Windows 10 since April
# 2018. However, there is no Python support on Windows yet.
# See https://bugs.python.org/issue33408 for more information.
if not hasattr(socket, 'AF_UNIX'):
raise TypeError(
'Unix sockets are not supported in this environment.'
)
init_args = (
socket.AF_UNIX, # pylint: disable=no-member
socket.SOCK_STREAM,
)
address = config[len(unix_socket_file_setting) :]
return init_args, address
# Search for IPv4 or IPv6 address or name and port.
# First, try to capture an IPv6 address as anything inside []. If there
# are no [] capture the IPv4 address. Lastly, capture the port as the
# numbers after :, if any.
match = re.match(
r'(\[(?P<ipv6_addr>.+)\]:?|(?P<ipv4_addr>[a-zA-Z0-9\._\/]+):?)'
r'(?P<port>[0-9]+)?',
config,
)
invalid_config_message = (
f'Invalid socket configuration "{config}"'
'Accepted values are "default", "file:<file_path>", '
'"<name_or_ipv4_address>" with optional ":<port>", and '
'"[<name_or_ipv6_address>]" with optional ":<port>".'
)
if match is None:
raise ValueError(invalid_config_message)
info = match.groupdict()
if info['port']:
port = int(info['port'])
else:
port = SocketClient.DEFAULT_SOCKET_PORT
if info['ipv4_addr']:
ip_addr = info['ipv4_addr']
elif info['ipv6_addr']:
ip_addr = info['ipv6_addr']
else:
raise ValueError(invalid_config_message)
sock_family, sock_type, _, _, address = socket.getaddrinfo(
ip_addr, port, type=socket.SOCK_STREAM
)[0]
init_args = sock_family, sock_type
return init_args, address
def __del__(self):
if self._connected:
self.socket.close()
def write(self, data: ReadableBuffer) -> None:
"""Writes data and detects disconnects."""
if not self._connected:
raise Exception('Socket is not connected.')
try:
self.socket.sendall(data)
except socket.error as exc:
if isinstance(exc.args, tuple) and exc.args[0] == errno.EPIPE:
self._handle_disconnect()
else:
raise exc
def read(self, num_bytes: int = PW_RPC_MAX_PACKET_SIZE) -> bytes:
"""Blocks until data is ready and reads up to num_bytes."""
if not self._connected:
raise Exception('Socket is not connected.')
data = self.socket.recv(num_bytes)
# Since this is a blocking read, no data returned means the socket is
# closed.
if not data:
self._handle_disconnect()
return data
def connect(self) -> None:
"""Connects to socket."""
self.socket = socket.socket(*self._socket_init_args)
# Enable reusing address and port for reconnections.
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(socket, 'SO_REUSEPORT'):
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
self.socket.settimeout(self._timeout)
self.socket.connect(self._address)
self._connected = True
def _handle_disconnect(self):
"""Escalates a socket disconnect to the user."""
self.socket.close()
self._connected = False
if self._on_disconnect:
self._on_disconnect(self)
def fileno(self) -> int:
return self.socket.fileno()
class SocketClientWithLogging(SocketClient):
    """SocketClient that reports all traffic to a bandwidth tracker."""

    def __init__(self, *args, **kwargs):
        """Forwards all arguments to SocketClient and creates the tracker."""
        super().__init__(*args, **kwargs)
        self._bandwidth_tracker = SerialBandwidthTracker()

    def read(
        self, num_bytes: int = SocketClient.PW_RPC_MAX_PACKET_SIZE
    ) -> bytes:
        """Reads up to num_bytes and records the received bytes."""
        received = super().read(num_bytes)
        tracker = self._bandwidth_tracker
        tracker.track_read_data(received)
        return received

    def write(self, data: ReadableBuffer) -> None:
        """Records the outgoing bytes, then writes them to the socket."""
        tracker = self._bandwidth_tracker
        tracker.track_write_data(data)
        super().write(data)