blob: 6995ef7fc6b1e01f8e750634ba179d6a77f9368b [file] [log] [blame]
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Defines a callback-based RPC ClientImpl to use with pw_rpc.client.Client.
callback_client.Impl supports invoking RPCs synchronously or asynchronously.
Asynchronous invocations use a callback.
Synchronous invocations look like a function call:
status, response = client.channel(1).call.MyService.MyUnary(some_field=123)
# Streaming calls return an iterable of responses
for reply in client.channel(1).call.MyService.MyServerStreaming(request):
pass
Asynchronous invocations pass a callback in addition to the request. The
callback must be a callable that accepts the PendingRpc, a status, and a
payload; the status and the payload may each be None. The status is only set
when the RPC is completed.
callback = lambda rpc, status, payload: print('Response:', status, payload)
call = client.channel(1).call.MyService.MyUnary.invoke(
callback, some_field=123)
call = client.channel(1).call.MyService.MyServerStreaming.invoke(
callback, request)
When invoking a method, requests may be provided as a message object or as
kwargs for the message fields (but not both).
"""
import inspect
import logging
import queue
from typing import Any, Callable, Optional, Tuple
from pw_status import Status
from pw_rpc import client
from pw_rpc.descriptors import Channel, Method, Service
# Module-level logger; process_response uses it to report exceptions raised by
# RPC callbacks.
_LOG = logging.getLogger(__name__)
# Type of the callbacks passed to invoke()/reinvoke(). A callback is called as
# callback(rpc, status, payload): the PendingRpc the response belongs to, the
# RPC status (None until the RPC completes), and the response payload (may be
# None). Its return value is ignored.
Callback = Callable[[client.PendingRpc, Optional[Status], Any], Any]
class _MethodClient:
    """A method that can be invoked for a particular channel.

    Holds a PendingRpc identifying (channel, service, method) and starts calls
    of that RPC through the shared PendingRpcs object.
    """

    def __init__(self, client_impl: 'Impl', rpcs: client.PendingRpcs,
                 channel: Channel, method: Method):
        self._impl = client_impl
        self._rpcs = rpcs
        self._rpc = client.PendingRpc(channel, method.service, method)

    @property
    def channel(self) -> Channel:
        """The channel this client sends requests on."""
        return self._rpc.channel

    @property
    def method(self) -> Method:
        """The RPC method this client invokes."""
        return self._rpc.method

    @property
    def service(self) -> Service:
        """The service the method belongs to."""
        return self._rpc.service

    def invoke(self, callback: Callback, _request=None, **request_fields):
        """Invokes an RPC with a callback.

        The request is given either as a message object (``_request``) or as
        keyword arguments for the message fields. The request is sent with
        ``override_pending=False``; how an already-pending call of this RPC is
        handled is decided by ``PendingRpcs.send_request``.

        Returns:
          An _AsyncCall handle that can cancel the RPC.
        """
        return self._start(callback,
                           _request,
                           request_fields,
                           override_pending=False)

    def reinvoke(self, callback: Callback, _request=None, **request_fields):
        """Invokes an RPC with a callback, overriding any pending requests.

        Accepts the same arguments and returns the same kind of handle as
        invoke().
        """
        return self._start(callback,
                           _request,
                           request_fields,
                           override_pending=True)

    def _start(self, callback: Callback, request, request_fields: dict, *,
               override_pending: bool) -> '_AsyncCall':
        """Builds the request, sends it, and wraps the call in an _AsyncCall.

        Shared implementation of invoke() and reinvoke(), which differ only in
        ``override_pending``.
        """
        self._rpcs.send_request(self._rpc,
                                self.method.get_request(
                                    request, request_fields),
                                callback,
                                override_pending=override_pending)
        return _AsyncCall(self._rpcs, self._rpc)

    def __repr__(self) -> str:
        return repr(self.method)
class _AsyncCall:
    """Handle for an in-flight RPC that was started with a callback.

    Works as a context manager: leaving the ``with`` block cancels the call.
    """

    # TODO(hepler): Consider alternatives (futures) and/or expand functionality.

    def __init__(self, rpcs: client.PendingRpcs, rpc: client.PendingRpc):
        self._rpcs = rpcs
        self.rpc = rpc

    def cancel(self) -> bool:
        """Requests cancellation via PendingRpcs.send_cancel.

        Returns whatever send_cancel returns (presumably whether the RPC was
        still pending -- confirm against client.PendingRpcs).
        """
        pending_rpcs = self._rpcs
        return pending_rpcs.send_cancel(self.rpc)

    def __enter__(self) -> '_AsyncCall':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Always cancel on exit. Returning None means any exception raised in
        # the with-block is not suppressed.
        self.cancel()
class _StreamingResponses:
    """Used to iterate over a queue.SimpleQueue of (status, payload) tuples.

    Items whose status is None are streamed responses, and their payloads are
    yielded. The first item with a non-None status ends the stream; that
    status is stored in ``self.status`` and its payload is not yielded.
    """

    def __init__(self, responses: queue.SimpleQueue):
        self._queue = responses
        # Final status of the RPC; None until the terminating item is read.
        self.status: Optional[Status] = None

    def get(self, *, block: bool = True, timeout_s: Optional[float] = None):
        """Yields response payloads until the RPC completes.

        Args:
          block: Passed to SimpleQueue.get. If False, queue.Empty is raised
              when no item is immediately available.
          timeout_s: Passed to SimpleQueue.get as the timeout. queue.Empty is
              raised if it elapses before an item arrives.

        Once the final status has been read, later calls yield nothing.
        Previously they waited on the queue again, which blocks forever
        (presumably nothing is queued after the RPC completes) or raises
        queue.Empty when non-blocking or timed out.
        """
        if self.status is not None:
            return

        while True:
            self.status, response = self._queue.get(block, timeout_s)
            if self.status is not None:
                return
            yield response

    def __iter__(self):
        return self.get()
def _method_client_docstring(method: Method) -> str:
    """Returns the docstring for a dynamically created method client class."""
    summary = (f'Class that invokes the {method.full_name} '
               f'{method.type.sentence_name()} RPC.')
    details = ('Calling this directly invokes the RPC synchronously. The RPC '
               'can be invoked\nasynchronously using the invoke or reinvoke '
               'methods.')
    return f'{summary}\n{details}\n'
def _function_docstring(method: Method) -> str:
    """Returns the docstring for a method client's generated __call__."""
    lines = [
        f'Invokes the {method.full_name} {method.type.sentence_name()} RPC.',
        'This function accepts either a request protobuf as a single '
        'positional argument',
        'or the request fields as keyword arguments.',
        '',  # Keeps the trailing newline.
    ]
    return '\n'.join(lines)
def _update_function_signature(method: Method, function: Callable) -> None:
    """Makes ``function`` present itself as ``method`` when inspected.

    Sets the function's __name__ and __doc__. It then sets __signature__ so
    that the displayed parameters are the function's own first parameter
    (``self``) followed by ``method.request_parameters()``.
    """
    function.__name__ = method.full_name
    function.__doc__ = _function_docstring(method)

    # Only __signature__ is replaced, which is what inspect.signature() and
    # therefore help() and tab completion report. The real calling convention
    # of `function` is unchanged.
    original_sig = inspect.signature(function)
    self_param = next(iter(original_sig.parameters.values()))
    displayed = [self_param, *method.request_parameters()]
    function.__signature__ = original_sig.replace(  # type: ignore[attr-defined]
        parameters=displayed)
def unary_method_client(client_impl: 'Impl', rpcs: client.PendingRpcs,
                        channel: Channel, method: Method) -> _MethodClient:
    """Creates an object used to call a unary method.

    A new _MethodClient subclass is built for each method, so that the class's
    __call__ can carry a method-specific name, docstring, and signature.
    Calling the returned object invokes the RPC synchronously and returns the
    first (status, payload) pair delivered to its callback.
    """

    def call(self,
             _rpc_request_proto=None,
             **request_fields) -> Tuple[Status, Any]:
        result_queue: queue.SimpleQueue = queue.SimpleQueue()

        def on_response(_, status, payload):
            result_queue.put((status, payload))

        self.reinvoke(on_response, _rpc_request_proto, **request_fields)
        # No timeout: waits until the callback has delivered a result.
        return result_queue.get()

    _update_function_signature(method, call)

    class_name = f'{method.name}_UnaryMethodClient'
    namespace = {
        '__call__': call,
        '__doc__': _method_client_docstring(method),
    }
    method_client_type = type(class_name, (_MethodClient, ), namespace)
    return method_client_type(client_impl, rpcs, channel, method)
def server_streaming_method_client(client_impl: 'Impl',
                                   rpcs: client.PendingRpcs, channel: Channel,
                                   method: Method) -> _MethodClient:
    """Creates an object used to call a server streaming method.

    Calling the returned object starts the RPC via reinvoke(), which overrides
    any pending call of it. The call returns a _StreamingResponses that yields
    response payloads as they arrive and records the final status.
    """

    def call(self,
             _rpc_request_proto=None,
             **request_fields) -> _StreamingResponses:
        responses: queue.SimpleQueue = queue.SimpleQueue()
        # Each callback invocation (rpc, status, payload) is queued as a
        # (status, payload) pair. _StreamingResponses stops at the first
        # non-None status.
        self.reinvoke(
            lambda _, status, payload: responses.put((status, payload)),
            _rpc_request_proto, **request_fields)
        return _StreamingResponses(responses)

    _update_function_signature(method, call)

    # The MethodClient class is created dynamically so that the __call__ method
    # can be configured differently for each method.
    method_client_type = type(
        f'{method.name}_ServerStreamingMethodClient', (_MethodClient, ),
        dict(__call__=call, __doc__=_method_client_docstring(method)))
    return method_client_type(client_impl, rpcs, channel, method)
class ClientStreamingMethodClient(_MethodClient):
    """Placeholder client for client streaming RPCs, which are unsupported.

    Every way of starting the RPC (calling the object, invoke(), or
    reinvoke()) raises NotImplementedError.
    """

    def __call__(self, *args, **kwargs):
        # Accept any arguments so callers get NotImplementedError, not a
        # TypeError about the argument count.
        raise NotImplementedError(
            'Client streaming RPCs are not yet supported')

    def invoke(self, callback: Callback, _request=None, **request_fields):
        raise NotImplementedError(
            'Client streaming RPCs are not yet supported')

    def reinvoke(self, callback: Callback, _request=None, **request_fields):
        # Without this override, the inherited _MethodClient.reinvoke would
        # send a request for an RPC type this client does not support.
        raise NotImplementedError(
            'Client streaming RPCs are not yet supported')
class BidirectionalStreamingMethodClient(_MethodClient):
    """Placeholder client for bidirectional streaming RPCs (unsupported).

    Every way of starting the RPC (calling the object, invoke(), or
    reinvoke()) raises NotImplementedError.
    """

    def __call__(self, *args, **kwargs):
        # Accept any arguments so callers get NotImplementedError, not a
        # TypeError about the argument count.
        raise NotImplementedError(
            'Bidirectional streaming RPCs are not yet supported')

    def invoke(self, callback: Callback, _request=None, **request_fields):
        raise NotImplementedError(
            'Bidirectional streaming RPCs are not yet supported')

    def reinvoke(self, callback: Callback, _request=None, **request_fields):
        # Without this override, the inherited _MethodClient.reinvoke would
        # send a request for an RPC type this client does not support.
        raise NotImplementedError(
            'Bidirectional streaming RPCs are not yet supported')
class Impl(client.ClientImpl):
    """Callback-based client.ClientImpl."""

    def method_client(self, rpcs: client.PendingRpcs, channel: Channel,
                      method: Method) -> _MethodClient:
        """Returns an object that invokes a method using the given channel.

        Raises:
          AssertionError: If method.type is not one of the four known
              Method.Type values.
        """
        if method.type is Method.Type.UNARY:
            return unary_method_client(self, rpcs, channel, method)

        if method.type is Method.Type.SERVER_STREAMING:
            return server_streaming_method_client(self, rpcs, channel, method)

        if method.type is Method.Type.CLIENT_STREAMING:
            return ClientStreamingMethodClient(self, rpcs, channel, method)

        if method.type is Method.Type.BIDIRECTIONAL_STREAMING:
            return BidirectionalStreamingMethodClient(self, rpcs, channel,
                                                      method)

        raise AssertionError(f'Unknown method type {method.type}')

    def process_response(self,
                         rpcs: client.PendingRpcs,
                         rpc: client.PendingRpc,
                         context,
                         status: Optional[Status],
                         payload,
                         *,
                         args: tuple = (),
                         kwargs: Optional[dict] = None) -> None:
        """Invokes the callback associated with this RPC.

        Any additional positional and keyword args passed through
        Client.process_packet are forwarded to the callback.

        Args:
          rpcs: Used to cancel the RPC if the callback raises.
          rpc: The RPC this response belongs to.
          context: The callable to invoke, called as
              ``context(rpc, status, payload, *args, **kwargs)``. Presumably
              this is the callback given to invoke()/reinvoke().
          status: The RPC status; only set when the RPC has completed.
          payload: The response payload; may be None.
          args: Extra positional arguments for the callback.
          kwargs: Extra keyword arguments for the callback.

        If the callback raises an Exception, the exception is logged and the
        RPC is cancelled; the exception is not propagated.
        """
        if kwargs is None:
            kwargs = {}

        try:
            context(rpc, status, payload, *args, **kwargs)
        except Exception:  # pylint: disable=broad-except
            # Catch Exception rather than everything, so KeyboardInterrupt and
            # SystemExit still propagate. Log before cancelling so the
            # callback failure is recorded even if send_cancel itself raises.
            _LOG.exception('Callback %s for %s raised exception', context, rpc)
            rpcs.send_cancel(rpc)