# blob: 36285db7112762932c3ab373beaeb55acf984b4e [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests the pw_rpc.console_tools.console module."""
import types
from typing import Optional
import unittest

import pw_status
from pw_protobuf_compiler import python_protos
import pw_rpc
from pw_rpc import callback_client
from pw_rpc.console_tools.console import (CommandHelper, Context, ClientInfo,
                                          alias_deprecated_command)
class TestCommandHelper(unittest.TestCase):
    """Tests console_tools.console.CommandHelper."""

    def setUp(self) -> None:
        """Creates a helper with two commands, two variables, header, footer."""
        self._commands = {'command_a': 'A', 'command_B': 'B'}
        self._variables = {'hello': 1, 'world': 2}
        self._helper = CommandHelper(self._commands, self._variables,
                                     'The header', 'The footer')

    def test_help_contents(self) -> None:
        """help() leads with the header and mentions the footer and all names."""
        help_contents = self._helper.help()

        self.assertTrue(help_contents.startswith('The header'))
        self.assertIn('The footer', help_contents)

        # subTest keeps checking the remaining names after one is missing, so
        # a single run reports every absent variable or command.
        for var_name in self._variables:
            with self.subTest(variable=var_name):
                self.assertIn(var_name, help_contents)

        for cmd_name in self._commands:
            with self.subTest(command=cmd_name):
                self.assertIn(cmd_name, help_contents)

    def test_repr_is_help(self) -> None:
        """repr() of the helper is the same text that help() returns."""
        self.assertEqual(repr(self._helper), self._helper.help())
_PROTO = """\
syntax = "proto3";
package the.pkg;
message SomeMessage {
uint32 magic_number = 1;
message AnotherMessage {
string payload = 1;
}
}
service Service {
rpc Unary(SomeMessage) returns (SomeMessage.AnotherMessage);
}
"""
class TestConsoleContext(unittest.TestCase):
    """Tests console_tools.console.Context."""

    def setUp(self) -> None:
        """Compiles _PROTO and builds a ClientInfo with channels 1 and 2."""
        self._protos = python_protos.Library.from_strings(_PROTO)

        self._info = ClientInfo(
            'the_client', object(),
            pw_rpc.Client.from_modules(callback_client.Impl(), [
                pw_rpc.Channel(1, lambda _: None),
                pw_rpc.Channel(2, lambda _: None),
            ], self._protos.modules()))

    def test_sets_expected_variables(self) -> None:
        """variables() exposes set_target, help, Status and the named client."""
        variables = Context([self._info],
                            default_client=self._info.client,
                            protos=self._protos).variables()

        self.assertIn('set_target', variables)

        self.assertIsInstance(variables['help'], CommandHelper)
        self.assertIs(variables['python_help'], help)
        self.assertIs(pw_status.Status, variables['Status'])
        self.assertIs(self._info.client, variables['the_client'])

    def test_set_target_switches_between_clients(self) -> None:
        """set_target() re-points the RPC variables at the selected client."""
        client_1_channel = self._info.rpc_client.channel(1).channel

        client_2_channel = pw_rpc.Channel(99, lambda _: None)
        info_2 = ClientInfo(
            'other_client', object(),
            pw_rpc.Client.from_modules(callback_client.Impl(),
                                       [client_2_channel],
                                       self._protos.modules()))

        context = Context([self._info, info_2],
                          default_client=self._info.client,
                          protos=self._protos)

        # Make sure the RPC service switches from one client to the other.
        self.assertIs(context.variables()['the'].pkg.Service.Unary.channel,
                      client_1_channel)

        context.set_target(info_2.client)

        self.assertIs(context.variables()['the'].pkg.Service.Unary.channel,
                      client_2_channel)

    def test_default_client_must_be_in_clients(self) -> None:
        """A default_client that matches no ClientInfo raises ValueError."""
        with self.assertRaises(ValueError):
            Context([self._info],
                    default_client='something else',
                    protos=self._protos)

    def test_set_target_invalid_channel(self) -> None:
        """Channel ID 100 was never added in setUp, so KeyError is expected."""
        context = Context([self._info],
                          default_client=self._info.client,
                          protos=self._protos)

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_non_default_channel(self) -> None:
        """set_target() with an explicit channel ID selects that channel."""
        channel_1 = self._info.rpc_client.channel(1).channel
        channel_2 = self._info.rpc_client.channel(2).channel

        context = Context([self._info],
                          default_client=self._info.client,
                          protos=self._protos)
        # Fetched once on purpose: the assertions below expect this same
        # mapping to reflect the target chosen by the later set_target() call.
        variables = context.variables()

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_1)

        context.set_target(self._info.client, 2)

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_2)

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_requires_client_object(self) -> None:
        """set_target() takes ClientInfo.client, not ClientInfo.rpc_client."""
        context = Context([self._info],
                          default_client=self._info.client,
                          protos=self._protos)

        with self.assertRaises(ValueError):
            context.set_target(self._info.rpc_client)

        # The registered client object itself must be accepted without error.
        context.set_target(self._info.client)

    def test_derived_context(self) -> None:
        """The 'set_target' variable calls a subclass's set_target override."""
        called_derived_set_target = False

        class DerivedContext(Context):
            def set_target(self,
                           unused_selected_client,
                           unused_channel_id: Optional[int] = None) -> None:
                # Only record that the override ran; the arguments are unused.
                nonlocal called_derived_set_target
                called_derived_set_target = True

        variables = DerivedContext(client_info=[self._info],
                                   default_client=self._info.client,
                                   protos=self._protos).variables()
        variables['set_target'](self._info.client)
        self.assertTrue(called_derived_set_target)
class TestAliasDeprecatedCommand(unittest.TestCase):
    """Tests console_tools.console.alias_deprecated_command."""

    def test_wraps_command_to_new_package(self) -> None:
        """Aliasing under a top-level name that does not exist yet works."""
        new_package = types.SimpleNamespace(command=lambda: 123)
        console_vars = {'abc': new_package}

        alias_deprecated_command(console_vars, 'xyz.one.two.three',
                                 'abc.command')

        old_command = console_vars['xyz'].one.two.three
        self.assertEqual(old_command(), 123)

    def test_wraps_command_to_existing_package(self) -> None:
        """Aliasing under an already-present top-level name works."""
        new_package = types.SimpleNamespace(NewCmd=lambda: 456)
        existing_package = types.SimpleNamespace()
        console_vars = {'abc': new_package, 'one': existing_package}

        alias_deprecated_command(console_vars, 'one.two.OldCmd', 'abc.NewCmd')

        old_command = console_vars['one'].two.OldCmd
        self.assertEqual(old_command(), 456)

    def test_error_if_new_command_does_not_exist(self) -> None:
        """Aliasing to a command that is not defined raises AttributeError."""
        console_vars = {
            'abc': types.SimpleNamespace(),
            'one': types.SimpleNamespace(),
        }

        self.assertRaises(AttributeError, alias_deprecated_command,
                          console_vars, 'one.two.OldCmd', 'abc.NewCmd')
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()