| # Copyright 2024 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Tests for util.ir_data_fields.""" |
| |
| import dataclasses |
| import enum |
| import sys |
| from typing import Optional |
| import unittest |
| |
| from compiler.util import ir_data |
| from compiler.util import ir_data_fields |
| |
| |
class TestEnum(enum.Enum):
    """Small Python enum used to exercise enum-typed IR data fields."""

    # Values are spelled out explicitly; UNKNOWN takes the zero value.
    UNKNOWN = 0
    VALUE_1 = 1
    VALUE_2 = 2
| |
| |
@dataclasses.dataclass
class Opaque(ir_data.Message):
    """Used for testing data field helpers.

    Declares no fields of its own. It serves as the message-typed member
    (`opaque`) of `ClassWithUnion` and `ClassWithTwoUnions`.
    """
| |
| |
@dataclasses.dataclass
class ClassWithUnion(ir_data.Message):
    """Used for testing data field helpers.

    Declares a single oneof group, "type", over members of four different
    kinds (a message, an int, a bool and an enum), plus one ordinary field.
    """

    # Members of the "type" oneof group.
    opaque: Optional[Opaque] = ir_data_fields.oneof_field("type")
    integer: Optional[int] = ir_data_fields.oneof_field("type")
    boolean: Optional[bool] = ir_data_fields.oneof_field("type")
    enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type")
    # Ordinary field that belongs to no oneof group.
    non_union_field: int = 0
| |
| |
@dataclasses.dataclass
class ClassWithTwoUnions(ir_data.Message):
    """Used for testing data field helpers.

    Declares two independent oneof groups ("type_1" and "type_2"), an
    ordinary scalar field, and a list field built with `list_field`.
    """

    # Members of the "type_1" oneof group.
    opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1")
    integer: Optional[int] = ir_data_fields.oneof_field("type_1")
    # Members of the "type_2" oneof group.
    boolean: Optional[bool] = ir_data_fields.oneof_field("type_2")
    enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type_2")
    # Ordinary field that belongs to no oneof group.
    non_union_field: int = 0
    # Repeated field of ints, declared via the list_field helper.
    seq_field: list[int] = ir_data_fields.list_field(int)
| |
| |
@dataclasses.dataclass
class NestedClass(ir_data.Message):
    """Used for testing data field helpers.

    A message whose fields are themselves messages containing oneof groups.
    `IrDataFieldsTest.test_copy` uses it to check copying of nested data.
    """

    # Both nested messages default to None (unset).
    one_union_class: Optional[ClassWithUnion] = None
    two_union_class: Optional[ClassWithTwoUnions] = None
| |
| |
@dataclasses.dataclass
class ListCopyTestClass(ir_data.Message):
    """Used to test behavior of extending a sequence.

    Serves as the element type for the `CopyValuesList` tests, and has a
    list field of its own for the constructor-conversion test.
    """

    non_union_field: int = 0
    # Repeated field of ints, declared via the list_field helper.
    seq_field: list[int] = ir_data_fields.list_field(int)
| |
| |
@dataclasses.dataclass
class OneofFieldTest(ir_data.Message):
    """Basic test class for oneof fields.

    Two int members share the "type_1" oneof group; `normal_field` is an
    ordinary bool field that defaults to True.
    """

    # Members of the "type_1" oneof group.
    int_field_1: Optional[int] = ir_data_fields.oneof_field("type_1")
    int_field_2: Optional[int] = ir_data_fields.oneof_field("type_1")
    # Ordinary field that belongs to no oneof group.
    normal_field: bool = True
| |
| |
class OneOfTest(unittest.TestCase):
    """Tests for the various oneof field helpers."""

    def test_field_attribute(self):
        """Tests the `oneof_field` helper."""
        test_field = ir_data_fields.oneof_field("type_1")
        self.assertIsNotNone(test_field)
        self.assertTrue(test_field.init)
        # The field's default is a OneOfField object, not a plain value.
        self.assertIsInstance(test_field.default, ir_data_fields.OneOfField)
        # The oneof group name is recorded in the field metadata.
        self.assertEqual(test_field.metadata.get("oneof"), "type_1")

    def test_init_default(self):
        """Tests creating an instance with default fields."""
        one_of_field_test = OneofFieldTest()
        self.assertIsNone(one_of_field_test.int_field_1)
        self.assertIsNone(one_of_field_test.int_field_2)
        self.assertTrue(one_of_field_test.normal_field)

    def test_init(self):
        """Tests creating an instance with non-default fields."""
        one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False)
        self.assertEqual(one_of_field_test.int_field_1, 10)
        self.assertIsNone(one_of_field_test.int_field_2)
        self.assertFalse(one_of_field_test.normal_field)

    def test_set_oneof_field(self):
        """Tests that setting a oneof field unsets the others in its group."""
        one_of_field_test = OneofFieldTest()
        one_of_field_test.int_field_1 = 10
        self.assertEqual(one_of_field_test.int_field_1, 10)
        self.assertIsNone(one_of_field_test.int_field_2)
        one_of_field_test.int_field_2 = 20
        self.assertIsNone(one_of_field_test.int_field_1)
        self.assertEqual(one_of_field_test.int_field_2, 20)

        # Do it again to make sure switching back and forth keeps working.
        one_of_field_test.int_field_1 = 10
        self.assertEqual(one_of_field_test.int_field_1, 10)
        self.assertIsNone(one_of_field_test.int_field_2)
        one_of_field_test.int_field_2 = 20
        self.assertIsNone(one_of_field_test.int_field_1)
        self.assertEqual(one_of_field_test.int_field_2, 20)

        # Now create a new instance and make sure changes to it are not
        # reflected on the original object.
        one_of_field_test_2 = OneofFieldTest()
        one_of_field_test_2.int_field_1 = 1000
        self.assertEqual(one_of_field_test_2.int_field_1, 1000)
        self.assertIsNone(one_of_field_test_2.int_field_2)
        self.assertIsNone(one_of_field_test.int_field_1)
        self.assertEqual(one_of_field_test.int_field_2, 20)

    def test_set_to_none(self):
        """Tests explicitly setting a oneof field to None."""
        one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False)
        self.assertEqual(one_of_field_test.int_field_1, 10)
        self.assertIsNone(one_of_field_test.int_field_2)
        self.assertFalse(one_of_field_test.normal_field)

        # Clear the set field.
        one_of_field_test.int_field_1 = None
        self.assertIsNone(one_of_field_test.int_field_1)
        self.assertIsNone(one_of_field_test.int_field_2)
        self.assertFalse(one_of_field_test.normal_field)

        # Set another field in the same group.
        one_of_field_test.int_field_2 = 200
        self.assertIsNone(one_of_field_test.int_field_1)
        self.assertEqual(one_of_field_test.int_field_2, 200)
        self.assertFalse(one_of_field_test.normal_field)

        # Clearing an already-unset field must not disturb the set one.
        one_of_field_test.int_field_1 = None
        self.assertIsNone(one_of_field_test.int_field_1)
        self.assertEqual(one_of_field_test.int_field_2, 200)
        self.assertFalse(one_of_field_test.normal_field)

    def test_oneof_specs(self):
        """Tests the `oneof_field_specs` filter."""
        expected = {
            "int_field_1": ir_data_fields.make_field_spec(
                "int_field_1", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1"
            ),
            "int_field_2": ir_data_fields.make_field_spec(
                "int_field_2", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1"
            ),
        }
        actual = ir_data_fields.IrDataclassSpecs.get_specs(
            OneofFieldTest
        ).oneof_field_specs
        self.assertDictEqual(actual, expected)

    def test_oneof_mappings(self):
        """Tests the `oneof_mappings` attribute of the cached specs."""
        # Pairs of (field name, oneof group name), in declaration order.
        expected = (("int_field_1", "type_1"), ("int_field_2", "type_1"))
        actual = ir_data_fields.IrDataclassSpecs.get_specs(
            OneofFieldTest
        ).oneof_mappings
        self.assertTupleEqual(actual, expected)
| |
| |
class IrDataFieldsTest(unittest.TestCase):
    """Tests misc methods in ir_data_fields."""

    def test_copy(self):
        """Tests copying a data class works as expected."""
        union = ClassWithTwoUnions(
            opaque=Opaque(), boolean=True, non_union_field=10, seq_field=[1, 2, 3]
        )
        nested_class = NestedClass(two_union_class=union)
        nested_class_copy = ir_data_fields.copy(nested_class)
        self.assertIsNotNone(nested_class_copy)
        self.assertIsNot(nested_class, nested_class_copy)
        self.assertEqual(nested_class_copy, nested_class)

        # Copying None is accepted and yields None.
        empty_copy = ir_data_fields.copy(None)
        self.assertIsNone(empty_copy)

    def test_copy_values_list(self):
        """Tests that CopyValuesList copies values added via `extend`."""
        data_list = ir_data_fields.CopyValuesList(ListCopyTestClass)
        self.assertEqual(len(data_list), 0)

        list_test = ListCopyTestClass(non_union_field=2, seq_field=[5, 6, 7])
        list_tests = [ir_data_fields.copy(list_test) for _ in range(4)]
        data_list.extend(list_tests)
        self.assertEqual(len(data_list), 4)
        for stored, source in zip(data_list, list_tests):
            # Every stored value must compare equal to the original...
            self.assertEqual(stored, list_test)
            # ...but be a distinct object; otherwise nothing was copied.
            self.assertIsNot(stored, source)

    def test_list_param_is_copied(self):
        """Tests that lists passed to constructors become CopyValuesList."""
        seq_field = [5, 6, 7]
        list_test = ListCopyTestClass(non_union_field=2, seq_field=seq_field)
        self.assertEqual(len(list_test.seq_field), len(seq_field))
        self.assertIsNot(list_test.seq_field, seq_field)
        self.assertEqual(list_test.seq_field, seq_field)
        self.assertIsInstance(list_test.seq_field, ir_data_fields.CopyValuesList)

    def test_copy_oneof(self):
        """Tests copying an IR data class that has oneof fields."""
        oneof_test = OneofFieldTest()
        oneof_test.int_field_1 = 10
        oneof_test.normal_field = False
        self.assertEqual(oneof_test.int_field_1, 10)
        self.assertEqual(oneof_test.normal_field, False)

        oneof_copy = ir_data_fields.copy(oneof_test)
        self.assertIsNotNone(oneof_copy)
        self.assertIsNot(oneof_copy, oneof_test)
        self.assertEqual(oneof_copy.int_field_1, 10)
        self.assertIsNone(oneof_copy.int_field_2)
        self.assertEqual(oneof_copy.normal_field, False)

        # Switching the oneof member on the copy must not leak back into the
        # original: neither the cleared member nor the newly set one.
        oneof_copy.int_field_2 = 100
        self.assertEqual(oneof_copy.int_field_2, 100)
        self.assertIsNone(oneof_copy.int_field_1)
        self.assertEqual(oneof_test.int_field_1, 10)
        self.assertIsNone(oneof_test.int_field_2)
        self.assertEqual(oneof_test.normal_field, False)
| |
| |
# Runs at import time, before unittest.main(). It is handed this module
# object and the ir_data.Message base class. NOTE(review): presumably this
# pre-computes the field specs for the Message subclasses defined above,
# so IrDataclassSpecs.get_specs() can find them — confirm against
# ir_data_fields.
ir_data_fields.cache_message_specs(sys.modules[__name__], ir_data.Message)


if __name__ == "__main__":
    unittest.main()