#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests using the callback client for pw_rpc."""
import unittest
from unittest import mock
from typing import List, Optional, Tuple

from pw_protobuf_compiler import python_protos
from pw_status import Status

from pw_rpc import callback_client, client, packets
from pw_rpc.internal import packet_pb2

# Proto source for the pw.test1 package used by the tests in this file. It
# declares two message types and PublicService, which has one unary, one
# server streaming, one client streaming, and one bidirectional streaming
# method.
TEST_PROTO_1 = """\
syntax = "proto3";
package pw.test1;
message SomeMessage {
uint32 magic_number = 1;
}
message AnotherMessage {
enum Result {
FAILED = 0;
FAILED_MISERABLY = 1;
I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
}
Result result = 1;
string payload = 2;
}
service PublicService {
rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""
def _rpc(method_stub):
    """Returns a client.PendingRpc built from a method stub.

    The PendingRpc is constructed from the stub's channel, the service that
    owns the stub's method, and the method itself.
    """
    method = method_stub.method
    return client.PendingRpc(method_stub.channel, method.service, method)


def _message_bytes(msg) -> bytes:
return msg if isinstance(msg, bytes) else msg.SerializeToString()
class CallbackClientImplTest(unittest.TestCase):
    """Tests the callback_client as used within a pw_rpc Client.

    The client under test is created with a single channel (ID 1) whose output
    function is _handle_request. Tests queue packets with the _enqueue_*
    helpers; the queued packets are fed to the client the next time it sends
    a packet on the channel.
    """
    def setUp(self):
        """Compiles the test protos and creates the client under test."""
        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
        self._request = self._protos.packages.pw.test1.SomeMessage
        self._client = client.Client.from_modules(
            callback_client.Impl(), [client.Channel(1, self._handle_request)],
            self._protos.modules())
        self._service = self._client.channel(1).rpcs.pw.test1.PublicService

        # Most recently sent packet, decoded; None until something is sent.
        self._last_request: packet_pb2.RpcPacket = None
        # (serialized packet, expected process_packet() status) pairs to
        # process when the client next sends a packet.
        self._next_packets: List[Tuple[bytes, Status]] = []
        # Cleared while queued packets are being processed (see
        # _handle_request).
        self._send_responses_on_request = True

    def _enqueue_packet(self, packet: packet_pb2.RpcPacket,
                        process_status: Status) -> None:
        """Queues a packet and the status expected from processing it.

        Args:
          packet: the packet to serialize and queue
          process_status: the value process_packet() is expected to return
              when the packet is handed to the client
        """
        self._next_packets.append(
            (packet.SerializeToString(), process_status))

    def _enqueue_response(self,
                          channel_id: int,
                          method=None,
                          status: Status = Status.OK,
                          payload=b'',
                          *,
                          ids: Optional[Tuple[int, int]] = None,
                          process_status=Status.OK):
        """Queues a RESPONSE packet.

        Exactly one of method and ids must be provided.

        Args:
          channel_id: channel ID to put in the packet
          method: method whose service ID and method ID are used
          status: status to put in the packet
          payload: bytes, or a message that is serialized with
              SerializeToString()
          ids: explicit (service ID, method ID) pair, used instead of method
          process_status: the value process_packet() is expected to return
        """
        if method:
            assert ids is None
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None and method is None
            service_id, method_id = ids

        self._enqueue_packet(
            packet_pb2.RpcPacket(type=packet_pb2.PacketType.RESPONSE,
                                 channel_id=channel_id,
                                 service_id=service_id,
                                 method_id=method_id,
                                 status=status.value,
                                 payload=_message_bytes(payload)),
            process_status)

    def _enqueue_server_stream(self,
                               channel_id: int,
                               method,
                               response,
                               process_status=Status.OK):
        """Queues a SERVER_STREAM packet carrying response for method."""
        self._enqueue_packet(
            packet_pb2.RpcPacket(type=packet_pb2.PacketType.SERVER_STREAM,
                                 channel_id=channel_id,
                                 service_id=method.service.id,
                                 method_id=method.id,
                                 payload=_message_bytes(response)),
            process_status)

    def _enqueue_error(self,
                       channel_id: int,
                       method,
                       status: Status,
                       process_status=Status.OK):
        """Queues a SERVER_ERROR packet with the given status for method."""
        self._enqueue_packet(
            packet_pb2.RpcPacket(type=packet_pb2.PacketType.SERVER_ERROR,
                                 channel_id=channel_id,
                                 service_id=method.service.id,
                                 method_id=method.id,
                                 status=status.value), process_status)

    def _handle_request(self, data: bytes):
        """Channel output function: records the packet, then replays replies.

        The sent packet is decoded into self._last_request. Every queued
        packet is then passed to the client, and the status returned by
        process_packet() is checked against the expected one. The queue is
        emptied afterwards.
        """
        # Disable this method to prevent infinite recursion if processing the
        # packet happens to send another packet.
        if not self._send_responses_on_request:
            return

        self._send_responses_on_request = False

        try:
            self._last_request = packets.decode(data)

            for packet, status in self._next_packets:
                self.assertIs(status, self._client.process_packet(packet))

            self._next_packets.clear()
        finally:
            # Always re-enable the handler so that a failure above does not
            # leave it permanently disabled.
            self._send_responses_on_request = True

    def _sent_payload(self, message_type):
        """Returns the payload of the last sent packet parsed as message_type.

        Fails the test if no packet has been recorded.
        """
        self.assertIsNotNone(self._last_request)
        message = message_type()
        message.ParseFromString(self._last_request.payload)
        return message

    def test_invoke_unary_rpc(self):
        """A blocking unary call returns the queued status and response."""
        method = self._service.SomeUnary.method

        for _ in range(3):
            self._enqueue_response(1, method, Status.ABORTED,
                                   method.response_type(payload='0_o'))

            status, response = self._service.SomeUnary(
                method.request_type(magic_number=6))

            self.assertEqual(
                6,
                self._sent_payload(method.request_type).magic_number)

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_invoke_unary_rpc_keep_open(self) -> None:
        """With keep_open=True, a later RESPONSE still reaches the callbacks."""
        method = self._service.SomeUnary.method

        payload_1 = method.response_type(payload='-_-')
        payload_2 = method.response_type(payload='0_o')

        self._enqueue_response(1, method, Status.ABORTED, payload_1)

        replies: list = []

        # Shared by both callback arguments: records whatever is delivered
        # (a response message or a status) and ignores the first argument.
        def enqueue_replies(_, reply):
            replies.append(reply)

        self._service.SomeUnary.invoke(method.request_type(magic_number=6),
                                       enqueue_replies,
                                       enqueue_replies,
                                       keep_open=True)

        self.assertEqual([payload_1, Status.ABORTED], replies)

        # Send another packet and make sure it is processed even though the RPC
        # terminated.
        self._client.process_packet(
            packet_pb2.RpcPacket(
                type=packet_pb2.PacketType.RESPONSE,
                channel_id=1,
                service_id=method.service.id,
                method_id=method.id,
                status=Status.OK.value,
                payload=payload_2.SerializeToString()).SerializeToString())

        self.assertEqual([payload_1, Status.ABORTED, payload_2, Status.OK],
                         replies)

    def test_invoke_unary_rpc_with_callback(self):
        """invoke() passes the response, then the status, to the callbacks."""
        method = self._service.SomeUnary.method

        for _ in range(3):
            self._enqueue_response(1, method, Status.ABORTED,
                                   method.response_type(payload='0_o'))

            callback = mock.Mock()
            self._service.SomeUnary.invoke(self._request(magic_number=5),
                                           callback, callback)

            callback.assert_has_calls([
                mock.call(_rpc(self._service.SomeUnary),
                          method.response_type(payload='0_o')),
                mock.call(_rpc(self._service.SomeUnary), Status.ABORTED)
            ])

            self.assertEqual(
                5,
                self._sent_payload(method.request_type).magic_number)

    def test_unary_rpc_server_error(self):
        """A SERVER_ERROR packet makes a blocking unary call raise RpcError."""
        method = self._service.SomeUnary.method

        for _ in range(3):
            self._enqueue_error(1, method, Status.NOT_FOUND)

            with self.assertRaises(callback_client.RpcError) as context:
                self._service.SomeUnary(method.request_type(magic_number=6))

            self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_invoke_unary_rpc_callback_exceptions_suppressed(self):
        """An exception raised by a callback is logged, not propagated."""
        stub = self._service.SomeUnary

        self._enqueue_response(1, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        with self.assertLogs(callback_client.__name__, 'ERROR') as logs:
            stub.invoke(self._request(),
                        mock.Mock(side_effect=Exception(exception_msg)))

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(1, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_invoke_unary_rpc_with_callback_cancel(self):
        """A pending unary call can be cancelled once and sends no packet."""
        callback = mock.Mock()

        for _ in range(3):
            call = self._service.SomeUnary.invoke(
                self._request(magic_number=55), callback)

            self.assertIsNotNone(self._last_request)
            self._last_request = None

            # Try to invoke the RPC again before cancelling, without overriding
            # pending RPCs.
            with self.assertRaises(client.Error):
                self._service.SomeUnary.invoke(self._request(magic_number=56),
                                               callback,
                                               override_pending=False)

            self.assertTrue(call.cancel())
            self.assertFalse(call.cancel())  # Already cancelled, returns False

            # Unary RPCs do not send a cancel request to the server.
            self.assertIsNone(self._last_request)

        callback.assert_not_called()

    def test_reinvoke_unary_rpc(self):
        """Invoking with override_pending=True sends a REQUEST each time."""
        for _ in range(3):
            self._last_request = None
            self._service.SomeUnary.invoke(self._request(magic_number=55),
                                           override_pending=True)
            self.assertEqual(self._last_request.type,
                             packet_pb2.PacketType.REQUEST)

    def test_invoke_server_streaming(self):
        """Iterating a server streaming call yields the streamed responses."""
        method = self._service.SomeServerStreaming.method

        rep1 = method.response_type(payload='!!!')
        rep2 = method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, method, rep1)
            self._enqueue_server_stream(1, method, rep2)
            self._enqueue_response(1, method, Status.ABORTED)

            self.assertEqual(
                [rep1, rep2],
                list(self._service.SomeServerStreaming(magic_number=4)))

            self.assertEqual(
                4,
                self._sent_payload(method.request_type).magic_number)

    def test_invoke_server_streaming_with_deprecated_packet_format(self):
        """RESPONSE packets ended by DEPRECATED_SERVER_STREAM_END still work."""
        method = self._service.SomeServerStreaming.method

        rep1 = method.response_type(payload='!!!')
        rep2 = method.response_type(payload='?')

        for _ in range(3):
            # The original packet format used RESPONSE packets for the server
            # stream and a SERVER_STREAM_END packet as the last packet. These
            # are converted to SERVER_STREAM packets followed by a RESPONSE.
            self._enqueue_response(1, method, payload=rep1)
            self._enqueue_response(1, method, payload=rep2)
            self._enqueue_packet(
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.DEPRECATED_SERVER_STREAM_END,
                    channel_id=1,
                    service_id=method.service.id,
                    method_id=method.id,
                    status=Status.INVALID_ARGUMENT.value), Status.OK)

            call = self._service.SomeServerStreaming(magic_number=4)
            self.assertEqual([rep1, rep2], list(call))
            self.assertIs(call.status, Status.INVALID_ARGUMENT)

            self.assertEqual(
                4,
                self._sent_payload(method.request_type).magic_number)

    def test_invoke_server_streaming_with_callbacks(self):
        """invoke() passes each streamed response, then the final status."""
        method = self._service.SomeServerStreaming.method

        rep1 = method.response_type(payload='!!!')
        rep2 = method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, method, rep1)
            self._enqueue_server_stream(1, method, rep2)
            self._enqueue_response(1, method, Status.ABORTED)

            callback = mock.Mock()
            self._service.SomeServerStreaming.invoke(
                self._request(magic_number=3), callback, callback)

            rpc = _rpc(self._service.SomeServerStreaming)
            callback.assert_has_calls([
                mock.call(rpc, method.response_type(payload='!!!')),
                mock.call(rpc, method.response_type(payload='?')),
                mock.call(rpc, Status.ABORTED),
            ])

            self.assertEqual(
                3,
                self._sent_payload(method.request_type).magic_number)

    def test_invoke_server_streaming_with_callback_cancel(self):
        """Cancelling a streaming call sends CANCEL; the RPC can be reused."""
        stub = self._service.SomeServerStreaming

        resp = stub.method.response_type(payload='!!!')
        self._enqueue_server_stream(1, stub.method, resp)

        callback = mock.Mock()
        call = stub.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            _rpc(stub), stub.method.response_type(payload='!!!'))

        callback.reset_mock()

        call.cancel()

        self.assertEqual(self._last_request.type,
                         packet_pb2.PacketType.CANCEL)

        # Ensure the RPC can be called after being cancelled.
        self._enqueue_server_stream(1, stub.method, resp)
        self._enqueue_response(1, stub.method, Status.OK)

        stub.invoke(self._request(magic_number=3), callback, callback)

        callback.assert_has_calls([
            mock.call(_rpc(stub), stub.method.response_type(payload='!!!')),
            mock.call(_rpc(stub), Status.OK),
        ])

    def test_ignore_bad_packets_with_pending_rpc(self):
        """Packets for other channels, services, or methods are ignored."""
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(1,
                               ids=(999, method.id),
                               process_status=Status.OK)
        # Bad method
        self._enqueue_response(1,
                               ids=(service_id, 999),
                               process_status=Status.OK)
        # For RPC not pending (is Status.OK because the packet is processed)
        self._enqueue_response(1,
                               ids=(service_id,
                                    self._service.SomeBidiStreaming.method.id),
                               process_status=Status.OK)

        self._enqueue_response(1, method, process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_pass_none_if_payload_fails_to_decode(self):
        """A payload that cannot be decoded is reported as a None response."""
        method = self._service.SomeUnary.method

        self._enqueue_response(1,
                               method,
                               Status.OK,
                               b'INVALID DATA!!!',
                               process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)

        self.assertIs(status, Status.OK)
        self.assertIsNone(response)

    def test_rpc_help_contains_method_name(self):
        """help() on an RPC stub mentions the method's full name."""
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self):
        """Impl stores the unary and stream default timeouts it is given."""
        impl = callback_client.Impl(None, 1.5)

        self.assertIsNone(impl.default_unary_timeout_s)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self):
        """The Impl's default timeouts are applied to the RPC stubs."""
        rpc_client = client.Client.from_modules(
            callback_client.Impl(99, 100),
            [client.Channel(1, lambda *a, **b: None)],
            self._protos.modules())
        rpcs = rpc_client.channel(1).rpcs

        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100)

    def test_timeout_unary(self):
        """A unary call with no queued response raises RpcTimeout."""
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)

    def test_timeout_unary_set_default(self):
        """The stub's default_timeout_s applies when no timeout is passed."""
        self._service.SomeUnary.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary()

    def test_timeout_server_streaming_iteration(self):
        """Iterating a streaming call with no responses raises RpcTimeout."""
        responses = self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in responses:
                pass

    def test_timeout_server_streaming_responses(self):
        """responses(timeout_s=...) raises RpcTimeout when nothing arrives."""
        responses = self._service.SomeServerStreaming()
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in responses.responses(timeout_s=0.0001):
                pass


# Run every test in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()