#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests creating pw_rpc client."""
import unittest
import tempfile
from pathlib import Path
from typing import List, Tuple
from pw_protobuf_compiler import python_protos
from pw_rpc import client, ids, packets
from pw_status import Status
# Source of the first test .proto file (proto3, package pw.call.test1).
# Declares two messages and PublicService, which has one RPC of each kind:
# unary, server streaming, client streaming and bidirectional streaming.
# ClientTest.setUp writes these bytes to a temporary file and compiles them.
TEST_PROTO_1 = b"""\
syntax = "proto3";
package pw.call.test1;
message SomeMessage {
uint32 magic_number = 1;
}
message AnotherMessage {
enum Result {
FAILED = 0;
FAILED_MISERABLY = 1;
I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
}
Result result = 1;
string payload = 2;
}
service PublicService {
rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""
# Source of the second test .proto file (proto2, package pw.call.test2).
# Declares a Request with one optional float field, an empty Response, and
# two services: Alpha (one unary RPC) and Bravo (one bidirectional streaming
# RPC). ClientTest.setUp writes these bytes to a temporary file and compiles
# them alongside TEST_PROTO_1.
TEST_PROTO_2 = b"""\
syntax = "proto2";
package pw.call.test2;
message Request {
optional float magic_number = 1;
}
message Response {
}
service Alpha {
rpc Unary(Request) returns (Response) {}
}
service Bravo {
rpc BidiStreaming(stream Request) returns (stream Response) {}
}
"""
class ClientTest(unittest.TestCase):
    """Tests the pw_rpc Python client.

    Every test gets a fresh client with a single channel (ID 1). The channel's
    output function is ``_handle_request``, which records the outgoing request
    and then feeds the client the response packets that the test queued
    beforehand with ``_enqueue_response``.
    """
    def setUp(self) -> None:
        """Compiles the test protos and builds a client with one channel."""
        self._proto_dir = tempfile.TemporaryDirectory(prefix='proto_test')
        # Register the cleanup right away. unittest does not call tearDown()
        # when setUp() raises, so a failure while writing or compiling the
        # protos below would otherwise leave the temporary directory behind.
        self.addCleanup(self._proto_dir.cleanup)

        protos = []
        for i, contents in enumerate([TEST_PROTO_1, TEST_PROTO_2]):
            protos.append(Path(self._proto_dir.name, f'test_{i}.proto'))
            protos[-1].write_bytes(contents)

        self._protos = python_protos.Library(
            python_protos.compile_and_import(protos))

        self._impl = client.SimpleSynchronousClient()
        self._client = client.Client.from_modules(
            self._impl, [client.Channel(1, self._handle_request)],
            self._protos.modules())

        # Stays None until _handle_request sees the first outgoing packet.
        self._last_request: packets.RpcPacket = None
        # (encoded packet, expected process_packet() result) pairs to deliver
        # in response to the next outgoing request.
        self._next_packets: List[Tuple[bytes, bool]] = []

    def _enqueue_response(self,
                          channel_id: int,
                          service_id: int,
                          method_id: int,
                          status: Status = Status.OK,
                          response=b'',
                          valid=True):
        """Queues a packet to hand to the client when it next sends a request.

        Args:
          channel_id: channel ID to put in the packet
          service_id: service ID to put in the packet
          method_id: method ID to put in the packet
          status: status whose ``value`` is stored in the packet
          response: packet payload; bytes are used as-is, anything else is
              serialized with its ``SerializeToString()`` method
          valid: the value ``process_packet`` is expected to return for this
              packet; checked in ``_handle_request``
        """
        if isinstance(response, bytes):
            payload = response
        else:
            payload = response.SerializeToString()

        self._next_packets.append(
            (packets.RpcPacket(channel_id=channel_id,
                               service_id=service_id,
                               method_id=method_id,
                               status=status.value,
                               payload=payload).SerializeToString(), valid))

    def _handle_request(self, data: bytes) -> None:
        """Channel output function: records the request, then answers it.

        Fails the test if no response was queued. Each queued packet is passed
        to the client and the result of ``process_packet`` is compared with
        the ``valid`` flag the packet was queued with.
        """
        self._last_request = packets.decode(data)

        self.assertTrue(self._next_packets)

        # Hand off and reset the queue before delivering anything, so that
        # responses queued for this request are not replayed for later
        # requests made by the same test.
        responses, self._next_packets = self._next_packets, []
        for packet, valid in responses:
            self.assertEqual(valid, self._client.process_packet(packet))

    def _last_payload(self, message_type):
        """Decodes the payload of the last sent request as message_type."""
        self.assertIsNotNone(self._last_request)
        message = message_type()
        message.ParseFromString(self._last_request.payload)
        return message

    def test_access_service_client_as_attribute_or_index(self):
        # Attribute access, lookup by name and lookup by the ID from
        # ids.calculate must all return the same service client object.
        self.assertIs(
            self._client.channel(1).call.PublicService,
            self._client.channel(1).call['PublicService'])
        self.assertIs(
            self._client.channel(1).call.PublicService,
            self._client.channel(1).call[ids.calculate('PublicService')])

    def test_access_method_client_as_attribute_or_index(self):
        # Same as above, but for method clients within a service client.
        self.assertIs(
            self._client.channel(1).call.Alpha.Unary,
            self._client.channel(1).call['Alpha']['Unary'])
        self.assertIs(
            self._client.channel(1).call.Alpha.Unary,
            self._client.channel(1).call['Alpha'][ids.calculate('Unary')])

    def test_check_for_presence_of_services(self):
        # `in` works with both service names and service IDs.
        self.assertIn('PublicService', self._client.channel(1).call)
        self.assertIn(ids.calculate('PublicService'),
                      self._client.channel(1).call)
        self.assertNotIn('NotAService', self._client.channel(1).call)
        self.assertNotIn(-1213, self._client.channel(1).call)

    def test_check_for_presence_of_methods(self):
        # `in` works with both method names and method IDs. 'Unary' belongs
        # to the Alpha service, so it must not be found in PublicService.
        self.assertIn('SomeUnary', self._client.channel(1).call.PublicService)
        self.assertIn(ids.calculate('SomeUnary'),
                      self._client.channel(1).call.PublicService)
        self.assertNotIn('Unary', self._client.channel(1).call.PublicService)
        self.assertNotIn(12345, self._client.channel(1).call.PublicService)

    def test_invoke_unary_rpc(self):
        rpcs = self._client.channel(1).call
        method = rpcs.PublicService.SomeUnary.method

        # Invoke the same RPC several times; each call gets its own response.
        for _ in range(3):
            self._enqueue_response(1, method.service.id, method.id,
                                   Status.ABORTED,
                                   method.response_type(payload='0_o'))

            status, response = rpcs.PublicService.SomeUnary(magic_number=6)

            # The keyword argument must end up in the request that was sent.
            self.assertEqual(
                6,
                self._last_payload(method.request_type).magic_number)

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_ignore_bad_packets_with_pending_rpc(self):
        rpcs = self._client.channel(1).call
        method = rpcs.PublicService.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, service_id, method.id, valid=False)
        # Bad service
        self._enqueue_response(1, 999, method.id, valid=False)
        # Bad method
        self._enqueue_response(1, service_id, 999, valid=False)
        # For RPC not pending (valid=True because the packet is processed)
        self._enqueue_response(1,
                               service_id,
                               rpcs.PublicService.SomeBidiStreaming.method.id,
                               valid=True)

        # The response that actually matches the pending RPC comes last.
        self._enqueue_response(1, service_id, method.id, valid=True)

        status, response = rpcs.PublicService.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_pass_none_if_payload_fails_to_decode(self):
        rpcs = self._client.channel(1).call
        method = rpcs.PublicService.SomeUnary.method

        # The packet itself is well formed, but its payload is not a valid
        # encoding of the method's response type.
        self._enqueue_response(1,
                               method.service.id,
                               method.id,
                               Status.OK,
                               b'INVALID DATA!!!',
                               valid=True)

        status, response = rpcs.PublicService.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertIsNone(response)

    def test_call_method_with_both_message_and_kwargs(self):
        # Passing a request message and field keyword arguments together is
        # rejected with a TypeError.
        req = self._client.services['Alpha'].methods['Unary'].request_type()

        with self.assertRaisesRegex(TypeError, r'either'):
            self._client.channel(1).call.Alpha.Unary(req, magic_number=1.0)

    def test_call_method_with_wrong_type(self):
        # A non-message argument is rejected; the error names the expected
        # request type.
        with self.assertRaisesRegex(TypeError, r'pw\.call\.test2\.Request'):
            self._client.channel(1).call.Alpha.Unary('This is str!')

    def test_call_method_with_incorrect_message_type(self):
        # A message of the wrong type is rejected; the error names the
        # expected request type.
        msg = self._protos.packages.pw.call.test1.AnotherMessage()
        with self.assertRaisesRegex(TypeError,
                                    r'pw\.call\.test1\.SomeMessage'):
            self._client.channel(1).call.PublicService.SomeUnary(msg)

    def test_process_packet_invalid_proto_data(self):
        self.assertFalse(self._client.process_packet(b'NOT a packet!'))

    def test_process_packet_unrecognized_channel(self):
        # Channel 123 was never added to the client.
        self.assertFalse(
            self._client.process_packet(
                packets.encode((123, 456, 789),
                               self._protos.packages.pw.call.test2.Request())))

    def test_process_packet_unrecognized_service(self):
        # Channel 1 exists, but service ID 456 is not expected to match any
        # service compiled from the test protos.
        self.assertFalse(
            self._client.process_packet(
                packets.encode((1, 456, 789),
                               self._protos.packages.pw.call.test2.Request())))

    def test_process_packet_unrecognized_method(self):
        # Uses the ID of a real service (the first one the client lists)
        # together with a method ID that is not expected to exist in it.
        self.assertFalse(
            self._client.process_packet(
                packets.encode((1, next(iter(self._client.services)).id, 789),
                               self._protos.packages.pw.call.test2.Request())))
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()