blob: fdae73211c2cf6f60891579f5a64b0d18d5abc78 [file] [log] [blame]
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Types representing the basic pw_rpc concepts: channel, service, method."""
from dataclasses import dataclass
import enum
from inspect import Parameter
from typing import (Any, Callable, Collection, Dict, Generic, Iterable,
Iterator, Optional, Tuple, TypeVar, Union)
from google.protobuf import descriptor_pb2, message_factory
from google.protobuf.descriptor import (FieldDescriptor, MethodDescriptor,
ServiceDescriptor)
from google.protobuf.message import Message
from pw_protobuf_compiler import python_protos
from pw_rpc import ids
@dataclass(frozen=True)
class Channel:
    """An RPC channel: a numeric ID plus a callable that is handed bytes."""

    id: int  # Channel ID; the only field shown by repr().
    output: Callable[[bytes], Any]  # Receives outgoing bytes for this channel.

    def __repr__(self) -> str:
        return 'Channel({})'.format(self.id)
@dataclass(frozen=True, eq=False)
class Service:
    """Describes an RPC service.

    Wraps a protobuf ServiceDescriptor together with the service's numeric ID
    and the collection of Method descriptors it contains.
    """

    _descriptor: ServiceDescriptor
    id: int
    methods: 'Methods'

    @property
    def name(self) -> str:
        """The service's short name, taken from its descriptor."""
        return self._descriptor.name

    @property
    def full_name(self) -> str:
        """The service's fully qualified name, taken from its descriptor."""
        return self._descriptor.full_name

    @property
    def package(self) -> str:
        """The package of the .proto file that declares this service."""
        return self._descriptor.file.package

    @classmethod
    def from_descriptor(cls, descriptor: ServiceDescriptor) -> 'Service':
        """Builds a Service, and all of its Methods, from a ServiceDescriptor."""
        service_id = ids.calculate(descriptor.full_name)

        # Each Method keeps a reference back to its Service, so the Service is
        # created first with a placeholder and `methods` is filled in after.
        # object.__setattr__ is needed because the dataclass is frozen.
        service = cls(descriptor, service_id, None)  # type: ignore[arg-type]
        method_list = [
            Method.from_descriptor(method_descriptor, service)
            for method_descriptor in descriptor.methods
        ]
        object.__setattr__(service, 'methods', Methods(method_list))
        return service

    def __repr__(self) -> str:
        return 'Service({!r})'.format(self.full_name)

    def __str__(self) -> str:
        return self.full_name
def _streaming_attributes(method) -> Tuple[bool, bool]:
    """Returns (server_streaming, client_streaming) for a MethodDescriptor.

    Reads the flags from the method's entry in its file's serialized
    FileDescriptorProto, located by the service and method indices.
    """
    # TODO(hepler): Investigate adding server_streaming and client_streaming
    # attributes to the generated protobuf code. As a workaround,
    # deserialize the FileDescriptorProto to get that information.
    owner = method.containing_service

    file_proto = descriptor_pb2.FileDescriptorProto()
    file_proto.MergeFromString(owner.file.serialized_pb)

    service_proto = file_proto.service[owner.index]  # pylint: disable=no-member
    method_proto = service_proto.method[method.index]

    return method_proto.server_streaming, method_proto.client_streaming
# Maps protobuf scalar and enum field types to the Python types used to
# describe them in help text (field_help) and in Parameter annotations
# (_field_type_annotation). Enum fields are described as int here.
_PROTO_FIELD_TYPES = {
    FieldDescriptor.TYPE_BOOL: bool,
    FieldDescriptor.TYPE_BYTES: bytes,
    FieldDescriptor.TYPE_DOUBLE: float,
    FieldDescriptor.TYPE_ENUM: int,
    FieldDescriptor.TYPE_FIXED32: int,
    FieldDescriptor.TYPE_FIXED64: int,
    FieldDescriptor.TYPE_FLOAT: float,
    FieldDescriptor.TYPE_INT32: int,
    FieldDescriptor.TYPE_INT64: int,
    FieldDescriptor.TYPE_SFIXED32: int,
    FieldDescriptor.TYPE_SFIXED64: int,
    FieldDescriptor.TYPE_SINT32: int,
    FieldDescriptor.TYPE_SINT64: int,
    FieldDescriptor.TYPE_STRING: str,
    FieldDescriptor.TYPE_UINT32: int,
    FieldDescriptor.TYPE_UINT64: int,
    # These types are not annotated:
    # FieldDescriptor.TYPE_GROUP = 10
    # FieldDescriptor.TYPE_MESSAGE = 11
}
def _field_type_annotation(field: FieldDescriptor):
    """Creates a field type annotation to use in the help message only.

    Message fields are annotated with their generated message class. Other
    fields map through _PROTO_FIELD_TYPES, falling back to Parameter.empty for
    unlisted types. Repeated fields are wrapped in Iterable[...].
    """
    if field.type != FieldDescriptor.TYPE_MESSAGE:
        element_type = _PROTO_FIELD_TYPES.get(field.type, Parameter.empty)
    else:
        factory = message_factory.MessageFactory(field.message_type.file.pool)
        element_type = factory.GetPrototype(field.message_type)

    if field.label != FieldDescriptor.LABEL_REPEATED:
        return element_type

    return Iterable[element_type]  # type: ignore[valid-type]
def field_help(proto_message, *, annotations: bool = False) -> Iterator[str]:
    """Yields argument strings for proto fields for use in a help message.

    Each string has the form ``name=default``, or ``name: type = default``
    when ``annotations`` is true. For a singular enum field, the default is
    rendered as the enum value's name, prefixed with the scope that encloses
    the enum type. Message and group fields are labeled with the message
    type's full name.

    Args:
      proto_message: anything with a protobuf ``DESCRIPTOR`` (e.g. a generated
        message class or instance); only its field descriptors are read.
      annotations: whether to include each field's type in the string.
    """
    for field in proto_message.DESCRIPTOR.fields:
        repeated = field.label == FieldDescriptor.LABEL_REPEATED

        if field.type == FieldDescriptor.TYPE_ENUM:
            type_name = field.enum_type.full_name
            if repeated:
                # Protobuf gives repeated fields a list default ([]), which is
                # not a valid values_by_number key, so just show it as-is.
                value = repr(field.default_value)
            else:
                values = field.enum_type.values_by_number
                value_name = values[field.default_value].name
                value = f'{type_name.rsplit(".", 1)[0]}.{value_name}'
        elif field.type in _PROTO_FIELD_TYPES:
            type_name = _PROTO_FIELD_TYPES[field.type].__name__
            value = repr(field.default_value)
        else:
            # Message and group fields are not in _PROTO_FIELD_TYPES; label
            # them with the name of their message type instead of failing.
            type_name = field.message_type.full_name
            value = repr(field.default_value)

        if annotations:
            yield f'{field.name}: {type_name} = {value}'
        else:
            yield f'{field.name}={value}'
def _message_is_type(proto, expected_type) -> bool:
    """Returns true if the protobuf instance is the expected type.

    Getting protobuf classes from google.protobuf.message_factory may create a
    new, unique generated proto class. Any generated classes for a particular
    proto message share the same MessageDescriptor instance and are
    interchangeable, so the descriptors are compared when isinstance fails.
    """
    if isinstance(proto, expected_type):
        return True

    if not isinstance(proto, Message):
        return False

    return proto.DESCRIPTOR is expected_type.DESCRIPTOR
@dataclass(frozen=True, eq=False)
class Method:
    """Describes a method in a service.

    Attributes:
      service: the Service that contains this method.
      id: numeric method ID, calculated from the method's short name.
      server_streaming: whether the server side streams.
      client_streaming: whether the client side streams.
      request_type: generated protobuf class for the request message.
      response_type: generated protobuf class for the response message.
    """

    _descriptor: MethodDescriptor
    service: Service
    id: int
    server_streaming: bool
    client_streaming: bool
    request_type: Any
    response_type: Any

    @classmethod
    def from_descriptor(cls, descriptor: MethodDescriptor,
                        service: Service) -> 'Method':
        """Creates a Method from its protobuf MethodDescriptor.

        Args:
          descriptor: descriptor of the method to wrap.
          service: the Service that contains the method.
        """
        input_factory = message_factory.MessageFactory(
            descriptor.input_type.file.pool)
        output_factory = message_factory.MessageFactory(
            descriptor.output_type.file.pool)

        # Construct via cls (not a hard-coded Method) so that subclasses get
        # instances of their own type, as Service.from_descriptor does.
        return cls(
            descriptor,
            service,
            ids.calculate(descriptor.name),
            *_streaming_attributes(descriptor),
            input_factory.GetPrototype(descriptor.input_type),
            output_factory.GetPrototype(descriptor.output_type),
        )

    class Type(enum.Enum):
        """Kinds of RPC method, by which sides of the call stream."""
        UNARY = 0
        SERVER_STREAMING = 1
        CLIENT_STREAMING = 2
        BIDIRECTIONAL_STREAMING = 3

        def sentence_name(self) -> str:
            """Returns the member name in lowercase words, e.g. 'unary'."""
            return self.name.lower().replace('_', ' ')  # pylint: disable=no-member

    @property
    def name(self) -> str:
        """The method's short name, taken from its descriptor."""
        return self._descriptor.name

    @property
    def full_name(self) -> str:
        """The method's fully qualified name, taken from its descriptor."""
        return self._descriptor.full_name

    @property
    def package(self) -> str:
        """The package of the .proto file that declares this method."""
        return self._descriptor.containing_service.file.package

    @property
    def type(self) -> 'Method.Type':
        """Classifies the method by its streaming flags."""
        if self.server_streaming and self.client_streaming:
            return self.Type.BIDIRECTIONAL_STREAMING

        if self.server_streaming:
            return self.Type.SERVER_STREAMING

        if self.client_streaming:
            return self.Type.CLIENT_STREAMING

        return self.Type.UNARY

    def get_request(self, proto: Optional[Message],
                    proto_kwargs: Optional[Dict[str, Any]]) -> Message:
        """Returns a request_type protobuf message.

        The client implementation may use this to support providing a request
        as either a message object or as keyword arguments for the message's
        fields (but not both).

        Args:
          proto: a request message, or None to build one from proto_kwargs.
          proto_kwargs: field values for a new request; None is treated as {}.

        Raises:
          TypeError: both a message and keyword args were given, or the
            message is not of request_type.
        """
        if proto_kwargs is None:
            proto_kwargs = {}

        # The question is whether a message was provided at all, so compare
        # against None rather than relying on the object's truthiness.
        if proto is not None and proto_kwargs:
            proto_str = repr(proto).strip() or "''"
            raise TypeError(
                'Requests must be provided either as a message object or a '
                'series of keyword args, but both were provided '
                f"({proto_str} and {proto_kwargs!r})")

        if proto is None:
            return self.request_type(**proto_kwargs)

        if not _message_is_type(proto, self.request_type):
            try:
                bad_type = proto.DESCRIPTOR.full_name
            except AttributeError:
                bad_type = type(proto).__name__

            raise TypeError(f'Expected a message of type '
                            f'{self.request_type.DESCRIPTOR.full_name}, '
                            f'got {bad_type}')

        return proto

    def request_parameters(self) -> Iterator[Parameter]:
        """Yields inspect.Parameters corresponding to the request's fields.

        This can be used to make function signatures match the request proto.
        Every parameter is keyword-only and defaults to the field's default.
        """
        for field in self.request_type.DESCRIPTOR.fields:
            yield Parameter(field.name,
                            Parameter.KEYWORD_ONLY,
                            annotation=_field_type_annotation(field),
                            default=field.default_value)

    def __repr__(self) -> str:
        req = self._method_parameter(self.request_type, self.client_streaming)
        res = self._method_parameter(self.response_type, self.server_streaming)
        return f'<{self.full_name}({req}) returns ({res})>'

    def _method_parameter(self, proto, streaming: bool) -> str:
        """Returns a description of the method's request or response type.

        Types from the service's own package are shown by short name; others
        by full name. Streaming types are prefixed with 'stream '.
        """
        stream = 'stream ' if streaming else ''

        if proto.DESCRIPTOR.file.package == self.service.package:
            return stream + proto.DESCRIPTOR.name

        return stream + proto.DESCRIPTOR.full_name

    def __str__(self) -> str:
        return self.full_name
# Member type parameter for the ServiceAccessor and _AccessByName generics.
T = TypeVar('T')
def _name(item: Union[Service, Method]) -> str:
    """Returns a Service's full name, or the short name of anything else."""
    if not isinstance(item, Service):
        return item.name
    return item.full_name
class _AccessByName(Generic[T]):
    """Wrapper for accessing types by name within a proto package structure.

    Each instance holds exactly one item, stored as an attribute whose name is
    supplied at construction time.
    """

    def __init__(self, name: str, item: T):
        # The attribute name is only known at runtime, hence setattr.
        attribute_name, value = name, item
        setattr(self, attribute_name, value)
class ServiceAccessor(Collection[T]):
    """Navigates RPC services by name or ID.

    Values are stored by the numeric ID of their key. They can be looked up
    with ``accessor[name_or_id]``, where a string name is converted to an ID
    with ids.calculate. Iteration yields the stored values.
    """

    def __init__(self, members, as_attrs: str = ''):
        """Creates accessor from an {item: value} dict or [values] iterable.

        Args:
          members: a dict mapping Service/Method objects to the values to
            expose, or an iterable of Service/Method objects that map to
            themselves.
          as_attrs: '' for no attribute access; 'members' to set each value
            as an attribute named after its key (full name for services,
            short name otherwise); 'packages' to expose values through
            attributes that mirror the proto package hierarchy.

        Raises:
          ValueError: as_attrs is not one of the values above.
        """
        # If the members arg was passed as a [values] iterable, convert it to
        # an equivalent dictionary.
        if not isinstance(members, dict):
            members = {m: m for m in members}

        self._by_id = {k.id: v for k, v in members.items()}

        if as_attrs == 'members':
            by_name = {_name(k): v for k, v in members.items()}
            for name, member in by_name.items():
                setattr(self, name, member)
        elif as_attrs == 'packages':
            for package in python_protos.as_packages(
                (m.package, _AccessByName(m.name, members[m]))
                    for m in members).packages:
                setattr(self, str(package), package)
        elif as_attrs:
            raise ValueError(f'Unexpected value {as_attrs!r} for as_attrs')

    def __getitem__(self, name_or_id: Union[str, int]):
        """Accesses a service/method by the string name or ID.

        Raises:
          KeyError: no member has the given name or ID.
        """
        member_id = _id(name_or_id)
        try:
            return self._by_id[member_id]
        except KeyError:
            name = f' ("{name_or_id}")' if isinstance(name_or_id, str) else ''
            raise KeyError(f'Unknown ID {member_id}{name}') from None

    def __iter__(self) -> Iterator[T]:
        return iter(self._by_id.values())

    def __len__(self) -> int:
        return len(self._by_id)

    def __contains__(self, name_or_id) -> bool:
        """True for a known name or ID, or for any value yielded by iteration.

        Accepting the member objects themselves keeps ``x in accessor``
        consistent with ``for x in accessor``, as the Collection protocol
        expects.
        """
        if _id(name_or_id) in self._by_id:
            return True

        # Strings and ints are treated purely as names/IDs.
        if isinstance(name_or_id, (str, int)):
            return False

        return name_or_id in self._by_id.values()

    def __repr__(self) -> str:
        members = ', '.join(repr(m) for m in self._by_id.values())
        return f'{self.__class__.__name__}({members})'
def _id(handle: Union[str, int]) -> int:
    """Converts a name to its numeric ID; non-string handles pass through."""
    if not isinstance(handle, str):
        return handle
    return ids.calculate(handle)
class Methods(ServiceAccessor[Method]):
    """A collection of Method descriptors in a Service."""

    def __init__(self, method: Iterable[Method]):
        """Indexes the given Methods by ID, without attribute access."""
        super().__init__(members=method)
class Services(ServiceAccessor[Service]):
    """A collection of Service descriptors."""

    def __init__(self, services: Iterable[Service]):
        """Indexes the given Services by ID, without attribute access."""
        super().__init__(members=services)
def get_method(service_accessor: ServiceAccessor, name: str):
    """Returns a method matching the given full name in a ServiceAccessor.

    Args:
      service_accessor: accessor in which to look up the service. If the
        lookup yields a Service, the method is taken from its ``methods``;
        otherwise the returned object is indexed by the method name directly.
      name: name as package.Service/Method or package.Service.Method.

    Raises:
      ValueError: the method name is not properly formatted
      KeyError: the method is not present
    """
    try:
        if '/' in name:
            service_name, method_name = name.split('/')
        else:
            service_name, method_name = name.rsplit('.', 1)
    except ValueError:
        # Replace the cryptic tuple-unpacking error with one naming the input.
        raise ValueError(
            f'Invalid method name {name!r}; expected '
            'package.Service/Method or package.Service.Method') from None

    service = service_accessor[service_name]
    if isinstance(service, Service):
        service = service.methods

    return service[method_name]