blob: 7e71420264f38b11fe8285b84daaf962761e2e7f [file] [log] [blame]
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Defines a callback-based RPC ClientImpl to use with pw_rpc.Client.
callback_client.Impl supports invoking RPCs synchronously or asynchronously.
Asynchronous invocations use a callback.
Synchronous invocations look like a function call:
  status, response = client.channel(1).call.MyService.MyUnary(some_field=123)
  # Streaming calls return an iterable of responses
  for reply in client.channel(1).call.MyService.MyServerStreaming(request):
      pass
Asynchronous invocations pass the request to invoke(), along with optional
callbacks. Each callback is a callable whose first argument is the PendingRpc.
The response callback also receives a response payload. The completion and
error callbacks also receive a Status. An omitted callback is replaced by a
default that logs its arguments. invoke() returns a call object whose cancel()
method cancels the RPC; the object can also be used as a context manager.
  on_response = lambda rpc, payload: print('Response:', rpc, payload)
  call = client.channel(1).call.MyService.MyUnary.invoke(
      request, on_response)
  call = client.channel(1).call.MyService.MyServerStreaming.invoke(
      request, on_response)
When calling a method synchronously, the request may be provided as a message
object or as kwargs for the message fields (but not both).
"""
import enum
import inspect
import logging
import queue
import textwrap
import threading
from typing import Any, Callable, Iterator, NamedTuple, Union, Optional
from pw_protobuf_compiler.python_protos import proto_repr
from pw_status import Status
from pw_rpc import client, descriptors
from pw_rpc.client import PendingRpc, PendingRpcs
from pw_rpc.descriptors import Channel, Method, Service
# Module logger used by the default callbacks and by Impl to report
# exceptions raised from user callbacks.
_LOG = logging.getLogger(__name__)
class UseDefault(enum.Enum):
    """Sentinel for arguments whose "use the default" value cannot be None.

    Timeout arguments accept None to mean "block indefinitely", so None can't
    also stand for "use the configured default"; UseDefault.VALUE does that.
    """

    VALUE = 0
# A timeout argument: seconds as a float, None to block indefinitely, or
# UseDefault.VALUE to fall back to the configured default.
OptionalTimeout = Optional[Union[UseDefault, float]]

# Callback signatures. Every callback receives the PendingRpc first, followed
# by a response payload (ResponseCallback) or a Status (the other two).
ResponseCallback = Callable[[PendingRpc, Any], Any]
CompletionCallback = Callable[[PendingRpc, Status], Any]
ErrorCallback = Callable[[PendingRpc, Status], Any]
class _Callbacks(NamedTuple):
    """The three callbacks registered for a single RPC invocation."""

    response: ResponseCallback  # receives each response payload
    completion: CompletionCallback  # receives the final Status
    error: ErrorCallback  # receives an error Status
def _default_response(rpc: PendingRpc, response: Any) -> None:
    """Fallback response callback: logs the payload at INFO level."""
    _LOG.log(logging.INFO, '%s response: %s', rpc, response)
def _default_completion(rpc: PendingRpc, status: Status) -> None:
    """Fallback completion callback: logs the final status at INFO level."""
    _LOG.log(logging.INFO, '%s finished: %s', rpc, status)
def _default_error(rpc: PendingRpc, status: Status) -> None:
    """Fallback error callback: logs the error status at ERROR level."""
    _LOG.log(logging.ERROR, '%s error: %s', rpc, status)
class _MethodClient:
    """A method that can be invoked for a particular channel.

    Derived classes provide __call__ for synchronous invocation; invoke()
    starts the RPC and reports its progress through callbacks.
    """

    def __init__(self, client_impl: 'Impl', rpcs: PendingRpcs,
                 channel: Channel, method: Method,
                 default_timeout_s: Optional[float]):
        self._impl = client_impl
        self._rpcs = rpcs
        # Identifies this channel/service/method combination to PendingRpcs.
        self._rpc = PendingRpc(channel, method.service, method)
        self.default_timeout_s: Optional[float] = default_timeout_s

    @property
    def channel(self) -> Channel:
        return self._rpc.channel

    @property
    def method(self) -> Method:
        return self._rpc.method

    @property
    def service(self) -> Service:
        return self._rpc.service

    def invoke(self,
               request: Any,
               response: ResponseCallback = _default_response,
               completion: CompletionCallback = _default_completion,
               error: ErrorCallback = _default_error,
               override_pending: bool = True) -> '_AsyncCall':
        """Invokes an RPC with callbacks.

        Args:
          request: the request message to send
          response: called as response(rpc, payload) for each response
          completion: called as completion(rpc, status) with the final status
          error: called as error(rpc, status) with an error status
          override_pending: forwarded to PendingRpcs.send_request

        Returns:
          an _AsyncCall that can cancel the RPC
        """
        self._rpcs.send_request(self._rpc,
                                request,
                                _Callbacks(response, completion, error),
                                override_pending=override_pending)
        return _AsyncCall(self._rpcs, self._rpc)

    def __repr__(self) -> str:
        return self.help()

    def __call__(self):
        raise NotImplementedError('Implemented by derived classes')

    def help(self) -> str:
        """Returns a help message about this RPC.

        The message lists the request fields as arguments, then the __call__
        docstring and the annotated return type when those are available.
        """
        function_call = self.method.full_name + '('
        arg_sep = f',\n{" " * len(function_call)}'
        sections = [
            f'{function_call}'
            f'{arg_sep.join(descriptors.field_help(self.method.request_type))})'
        ]

        # inspect.getdoc may return None; skip the section rather than fail,
        # since __repr__ relies on this method.
        docstring = inspect.getdoc(self.__call__)
        if docstring:
            sections.append(textwrap.indent(docstring, " "))

        annotation = inspect.Signature.from_callable(self).return_annotation
        # An un-annotated __call__ reports Signature.empty (itself a class),
        # which would otherwise be printed as "Returns _empty.".
        if annotation is not inspect.Signature.empty:
            if isinstance(annotation, type):
                annotation = annotation.__name__
            sections.append(f' Returns {annotation}.')

        return '\n\n'.join(sections)
class RpcTimeout(Exception):
    """Raised when an RPC produces no result within the allotted timeout."""

    def __init__(self, rpc: PendingRpc, timeout: Optional[float]):
        message = f'No response received for {rpc.method} after {timeout} s'
        super().__init__(message)
        self.rpc = rpc
        self.timeout = timeout
class RpcError(Exception):
    """Reports an RPC that ended with an error status."""

    def __init__(self, rpc: PendingRpc, status: Status):
        # NOT_FOUND gets an explanatory suffix; other statuses are reported
        # as-is.
        detail = (': the RPC server does not support this RPC'
                  if status is Status.NOT_FOUND else '')
        super().__init__(f'{rpc.method} failed with error {status}{detail}')
        self.rpc = rpc
        self.status = status
class _AsyncCall:
    """Handle to an in-flight callback-based RPC.

    Usable as a context manager: leaving the ``with`` block cancels the RPC.
    """

    # TODO(hepler): Consider alternatives (futures) and/or expand functionality.

    def __init__(self, rpcs: PendingRpcs, rpc: PendingRpc):
        self._rpcs = rpcs
        self._rpc = rpc

    def cancel(self) -> bool:
        """Cancels the RPC; returns whatever PendingRpcs.send_cancel returns."""
        return self._rpcs.send_cancel(self._rpc)

    def __enter__(self) -> '_AsyncCall':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.cancel()
class StreamingResponses:
    """Iterates over stream responses collected in a queue.SimpleQueue.

    The queue holds response payloads, terminated by either a Status (the
    RPC finished) or an Exception instance (re-raised to the consumer).
    """

    def __init__(self, method_client: _MethodClient,
                 responses: queue.SimpleQueue,
                 default_timeout_s: OptionalTimeout):
        self._method_client = method_client
        self._queue = responses
        self.status: Optional[Status] = None
        self.default_timeout_s = (self._method_client.default_timeout_s
                                  if default_timeout_s is UseDefault.VALUE
                                  else default_timeout_s)

    @property
    def method(self) -> Method:
        return self._method_client.method

    def responses(self,
                  *,
                  block: bool = True,
                  timeout_s: OptionalTimeout = UseDefault.VALUE) -> Iterator:
        """Returns an iterator of stream responses.

        Args:
          block: passed to queue.get; when False, an empty queue ends
              iteration with RpcTimeout right away
          timeout_s: timeout in seconds; None blocks indefinitely

        Raises:
          RpcTimeout: no queue item arrived in time
          Exception: any exception instance found in the queue
        """
        wait_s = (self.default_timeout_s
                  if timeout_s is UseDefault.VALUE else timeout_s)

        try:
            while True:
                item = self._queue.get(block, wait_s)
                if isinstance(item, Exception):
                    raise item
                if isinstance(item, Status):
                    # Terminal status: remember it and end the iteration.
                    self.status = item
                    return
                yield item
        except queue.Empty:
            pass

        raise RpcTimeout(self._method_client._rpc, wait_s)  # pylint: disable=protected-access

    def __iter__(self):
        return self.responses()

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f'{class_name}({self.method})'
def _method_client_docstring(method: Method) -> str:
    """Returns the class docstring for a generated method client type."""
    name = method.full_name
    kind = method.type.sentence_name()
    return f'''\
Class that invokes the {name} {kind} RPC.
Calling this directly invokes the RPC synchronously. The RPC can be invoked
asynchronously using the invoke method.
'''
def _function_docstring(method: Method) -> str:
    """Returns the docstring assigned to a generated call function."""
    name = method.full_name
    kind = method.type.sentence_name()
    return f'''\
Invokes the {name} {kind} RPC.
This function accepts either the request protobuf fields as keyword arguments or
a request protobuf as a positional argument.
'''
def _update_function_signature(method: Method, function: Callable) -> None:
    """Updates the name, docstring, and parameters to match a method.

    For better help messages and tab completion, the reported signature
    becomes: the original first ("self") parameter, the method's request
    parameters, and a keyword-only pw_rpc_timeout_s. Only the __signature__
    seen by inspect changes; the function's real parameters are untouched.
    """
    function.__name__ = method.full_name
    function.__doc__ = _function_docstring(method)

    current = inspect.signature(function)
    self_param = next(iter(current.parameters.values()))
    timeout_param = inspect.Parameter('pw_rpc_timeout_s',
                                      inspect.Parameter.KEYWORD_ONLY)
    function.__signature__ = current.replace(  # type: ignore[attr-defined]
        parameters=[self_param, *method.request_parameters(), timeout_param])
class UnaryResponse(NamedTuple):
    """Outcome of a synchronous unary RPC: the final status and the response."""

    status: Status
    response: Any

    def __repr__(self) -> str:
        rendered = proto_repr(self.response)
        return f'({self.status}, {rendered})'
class _UnaryResponseHandler:
    """Collects a unary RPC's outcome so a caller can block until it arrives.

    The on_* methods serve as the RPC's callbacks. wait() returns once a
    completion or an error has been recorded, or raises on timeout.
    """

    def __init__(self, rpc: PendingRpc):
        self._rpc = rpc
        self._response: Any = None
        self._status: Optional[Status] = None
        self._error: Optional[RpcError] = None
        # Signalled by on_completion or on_error.
        self._event = threading.Event()

    def on_response(self, _: PendingRpc, response: Any) -> None:
        self._response = response

    def on_completion(self, _: PendingRpc, status: Status) -> None:
        self._status = status
        self._event.set()

    def on_error(self, _: PendingRpc, status: Status) -> None:
        self._error = RpcError(self._rpc, status)
        self._event.set()

    def wait(self, timeout_s: Optional[float]) -> UnaryResponse:
        """Blocks for the result; raises RpcTimeout or the recorded RpcError."""
        finished = self._event.wait(timeout_s)
        if not finished:
            raise RpcTimeout(self._rpc, timeout_s)

        if self._error is not None:
            raise self._error

        assert self._status is not None
        return UnaryResponse(self._status, self._response)
def _unary_method_client(client_impl: 'Impl', rpcs: PendingRpcs,
                         channel: Channel, method: Method,
                         default_timeout: Optional[float]) -> _MethodClient:
    """Creates an object used to call a unary method."""

    def call(self: _MethodClient,
             _rpc_request_proto=None,
             *,
             pw_rpc_timeout_s=UseDefault.VALUE,
             **request_fields) -> UnaryResponse:
        handler = _UnaryResponseHandler(self._rpc)  # pylint: disable=protected-access
        request = self.method.get_request(_rpc_request_proto, request_fields)
        self.invoke(request, handler.on_response, handler.on_completion,
                    handler.on_error)

        timeout_s = (self.default_timeout_s
                     if pw_rpc_timeout_s is UseDefault.VALUE else
                     pw_rpc_timeout_s)
        return handler.wait(timeout_s)

    _update_function_signature(method, call)

    # Each method gets its own dynamically created class so that its __call__
    # (and therefore its signature and docstring) can differ.
    class_attrs = {
        '__call__': call,
        '__doc__': _method_client_docstring(method),
    }
    client_type = type(f'{method.name}_UnaryMethodClient', (_MethodClient, ),
                       class_attrs)
    return client_type(client_impl, rpcs, channel, method, default_timeout)
def _server_streaming_method_client(client_impl: 'Impl', rpcs: PendingRpcs,
                                    channel: Channel, method: Method,
                                    default_timeout: Optional[float]
                                    ) -> _MethodClient:
    """Creates an object used to call a server streaming method.

    Calling the returned object invokes the RPC and returns a
    StreamingResponses that yields each response as it arrives.
    """

    def call(self: _MethodClient,
             _rpc_request_proto=None,
             *,
             pw_rpc_timeout_s=UseDefault.VALUE,
             **request_fields) -> StreamingResponses:
        # The callbacks only enqueue: responses, then either the final Status
        # or an RpcError. StreamingResponses consumes the queue in order.
        responses: queue.SimpleQueue = queue.SimpleQueue()
        self.invoke(
            self.method.get_request(_rpc_request_proto, request_fields),
            lambda _, response: responses.put(response),
            lambda _, status: responses.put(status),
            lambda rpc, status: responses.put(RpcError(rpc, status)))
        return StreamingResponses(self, responses, pw_rpc_timeout_s)

    _update_function_signature(method, call)

    # The MethodClient class is created dynamically so that the __call__ method
    # can be configured differently for each method.
    method_client_type = type(
        f'{method.name}_ServerStreamingMethodClient', (_MethodClient, ),
        dict(__call__=call, __doc__=_method_client_docstring(method)))
    return method_client_type(client_impl, rpcs, channel, method,
                              default_timeout)
class ClientStreamingMethodClient(_MethodClient):
    """Method client for client streaming RPCs, which are not supported here.

    Both synchronous calls and invoke() raise NotImplementedError.
    """

    def __call__(self):
        """Not implemented: client streaming RPCs cannot be invoked yet."""
        raise NotImplementedError

    def invoke(self,
               request: Any,
               response: ResponseCallback = _default_response,
               completion: CompletionCallback = _default_completion,
               error: ErrorCallback = _default_error,
               override_pending: bool = True) -> _AsyncCall:
        """Not implemented: client streaming RPCs cannot be invoked yet."""
        raise NotImplementedError
class BidirectionalStreamingMethodClient(_MethodClient):
    """Method client for bidirectional streaming RPCs, not supported here.

    Both synchronous calls and invoke() raise NotImplementedError.
    """

    def __call__(self):
        """Not implemented: bidirectional streaming RPCs cannot be invoked yet."""
        raise NotImplementedError

    def invoke(self,
               request: Any,
               response: ResponseCallback = _default_response,
               completion: CompletionCallback = _default_completion,
               error: ErrorCallback = _default_error,
               override_pending: bool = True) -> _AsyncCall:
        """Not implemented: bidirectional streaming RPCs cannot be invoked yet."""
        raise NotImplementedError
class Impl(client.ClientImpl):
    """Callback-based ClientImpl.

    Creates method clients for each method type and dispatches RPC events to
    the callbacks registered with each invocation.
    """

    def __init__(self,
                 default_unary_timeout_s: Optional[float] = 1.0,
                 default_stream_timeout_s: Optional[float] = 1.0):
        super().__init__()
        self._default_unary_timeout_s = default_unary_timeout_s
        self._default_stream_timeout_s = default_stream_timeout_s

    @property
    def default_unary_timeout_s(self) -> Optional[float]:
        """Default timeout for unary and client streaming method clients."""
        return self._default_unary_timeout_s

    @property
    def default_stream_timeout_s(self) -> Optional[float]:
        """Default timeout for server and bidirectional streaming clients."""
        return self._default_stream_timeout_s

    def method_client(self, channel: Channel, method: Method) -> _MethodClient:
        """Returns an object that invokes a method using the given channel."""
        if method.type is Method.Type.UNARY:
            return _unary_method_client(self, self.rpcs, channel, method,
                                        self.default_unary_timeout_s)

        if method.type is Method.Type.SERVER_STREAMING:
            return _server_streaming_method_client(
                self, self.rpcs, channel, method,
                self.default_stream_timeout_s)

        if method.type is Method.Type.CLIENT_STREAMING:
            return ClientStreamingMethodClient(self, self.rpcs, channel,
                                               method,
                                               self.default_unary_timeout_s)

        if method.type is Method.Type.BIDIRECTIONAL_STREAMING:
            return BidirectionalStreamingMethodClient(
                self, self.rpcs, channel, method,
                self.default_stream_timeout_s)

        raise AssertionError(f'Unknown method type {method.type}')

    def handle_response(self,
                        rpc: PendingRpc,
                        context,
                        payload,
                        *,
                        args: tuple = (),
                        kwargs: Optional[dict] = None) -> None:
        """Invokes the callback associated with this RPC.

        Any additional positional and keyword args passed through
        Client.process_packet are forwarded to the callback. If the callback
        raises, the exception is logged and the RPC is cancelled.

        Args:
          context: presumably the _Callbacks registered by invoke(); only its
              ``response`` attribute is used here
        """
        if kwargs is None:
            kwargs = {}

        try:
            context.response(rpc, payload, *args, **kwargs)
        except:  # pylint: disable=bare-except
            # Log before cancelling so the callback failure is recorded even
            # if sending the cancel fails.
            _LOG.exception('Response callback %s for %s raised exception',
                           context.response, rpc)
            self.rpcs.send_cancel(rpc)

    def handle_completion(self,
                          rpc: PendingRpc,
                          context,
                          status: Status,
                          *,
                          args: tuple = (),
                          kwargs: Optional[dict] = None) -> None:
        """Invokes the completion callback with the RPC's final status.

        Extra args/kwargs are forwarded to the callback. Exceptions raised by
        the callback are logged and not propagated.
        """
        if kwargs is None:
            kwargs = {}

        try:
            context.completion(rpc, status, *args, **kwargs)
        except:  # pylint: disable=bare-except
            _LOG.exception('Completion callback %s for %s raised exception',
                           context.completion, rpc)

    def handle_error(self,
                     rpc: PendingRpc,
                     context,
                     status: Status,
                     *,
                     args: tuple = (),
                     kwargs: Optional[dict] = None) -> None:
        """Invokes the error callback with the RPC's error status.

        Extra args/kwargs are forwarded to the callback. Exceptions raised by
        the callback are logged and not propagated.
        """
        if kwargs is None:
            kwargs = {}

        try:
            context.error(rpc, status, *args, **kwargs)
        except:  # pylint: disable=bare-except
            _LOG.exception('Error callback %s for %s raised exception',
                           context.error, rpc)