blob: 504eb88e7edff649b27302d4128c7303baa5a595 [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for util.traverse_ir."""
import collections
import unittest
from compiler.util import ir_data
from compiler.util import ir_data_utils
from compiler.util import traverse_ir
# Shared IR fixture for every test below, deserialized from JSON via
# ir_data_utils.IrDataSerializer.from_json into an ir_data.EmbossIr.
#
# Module "t.emb" contains:
#   * struct Foo with fields field1 (start 0, size 8, UInt) and field2
#     (start 8, size 16, array of 8 UInt), and a nested subtype struct Bar with
#     bar_field1 (start 24, size 32, UInt) and bar_field2 (start 32, size 320,
#     array with automatic size of arrays of 16 UInt);
#   * enum Bar with ONE = 1 and TWO = 1 + 1 (an ADDITION function node).
# Module "" (empty source_file_name) defines the external type UInt with
# attributes statically_sized (a boolean constant, not a NumericConstant) and
# size_in_bits = 64.
_EXAMPLE_IR = ir_data_utils.IrDataSerializer.from_json(
ir_data.EmbossIr,
"""{
"module": [
{
"type": [
{
"structure": {
"field": [
{
"location": {
"start": { "constant": { "value": "0" } },
"size": { "constant": { "value": "8" } }
},
"type": {
"atomic_type": {
"reference": {
"canonical_name": {
"module_file": "",
"object_path": ["UInt"]
}
}
}
},
"name": { "name": { "text": "field1" } }
},
{
"location": {
"start": { "constant": { "value": "8" } },
"size": { "constant": { "value": "16" } }
},
"type": {
"array_type": {
"base_type": {
"atomic_type": {
"reference": {
"canonical_name": {
"module_file": "",
"object_path": ["UInt"]
}
}
}
},
"element_count": { "constant": { "value": "8" } }
}
},
"name": { "name": { "text": "field2" } }
}
]
},
"name": { "name": { "text": "Foo" } },
"subtype": [
{
"structure": {
"field": [
{
"location": {
"start": { "constant": { "value": "24" } },
"size": { "constant": { "value": "32" } }
},
"type": {
"atomic_type": {
"reference": {
"canonical_name": {
"module_file": "",
"object_path": ["UInt"]
}
}
}
},
"name": { "name": { "text": "bar_field1" } }
},
{
"location": {
"start": { "constant": { "value": "32" } },
"size": { "constant": { "value": "320" } }
},
"type": {
"array_type": {
"base_type": {
"array_type": {
"base_type": {
"atomic_type": {
"reference": {
"canonical_name": {
"module_file": "",
"object_path": ["UInt"]
}
}
}
},
"element_count": { "constant": { "value": "16" } }
}
},
"automatic": { }
}
},
"name": { "name": { "text": "bar_field2" } }
}
]
},
"name": { "name": { "text": "Bar" } }
}
]
},
{
"enumeration": {
"value": [
{
"name": { "name": { "text": "ONE" } },
"value": { "constant": { "value": "1" } }
},
{
"name": { "name": { "text": "TWO" } },
"value": {
"function": {
"function": "ADDITION",
"args": [
{ "constant": { "value": "1" } },
{ "constant": { "value": "1" } }
],
"function_name": { "text": "+" }
}
}
}
]
},
"name": { "name": { "text": "Bar" } }
}
],
"source_file_name": "t.emb"
},
{
"type": [
{
"external": { },
"name": {
"name": { "text": "UInt" },
"canonical_name": { "module_file": "", "object_path": ["UInt"] }
},
"attribute": [
{
"name": { "text": "statically_sized" },
"value": { "expression": { "boolean_constant": { "value": true } } }
},
{
"name": { "text": "size_in_bits" },
"value": { "expression": { "constant": { "value": "64" } } }
}
]
}
],
"source_file_name": ""
}
]
}""",
)
def _count_entries(sequence):
    """Returns a multiset of the entries in `sequence`.

    Used to compare traversal results without depending on visit order: two
    sequences produce equal results exactly when they contain the same
    entries with the same multiplicities.

    Args:
        sequence: An iterable of hashable entries.

    Returns:
        A collections.Counter mapping each entry to the number of times it
        occurs in `sequence`.
    """
    # collections.Counter already counts an iterable; no manual loop needed.
    return collections.Counter(sequence)
def _record_constant(constant, constant_list):
    """Traversal action: appends `int(constant.value)` to `constant_list`."""
    as_int = int(constant.value)
    constant_list.append(as_int)
def _record_field_name_and_constant(constant, constant_list, field):
    """Traversal action: records `(field name text, constant value)` pairs.

    `field` is the Field node passed in by the traversal; its
    `name.name.text` is paired with `int(constant.value)`.
    """
    field_name = field.name.name.text
    constant_list.append((field_name, int(constant.value)))
def _record_file_name_and_constant(constant, constant_list, source_file_name):
    """Traversal action: records `(source_file_name, constant value)` pairs."""
    entry = (source_file_name, int(constant.value))
    constant_list.append(entry)
def _record_location_parameter_and_constant(constant, constant_list, location=None):
    """Traversal action: records `(location, constant value)` pairs.

    `location` is whatever value is supplied for that parameter (None when
    it is not passed).
    """
    value = int(constant.value)
    constant_list.append((location, value))
def _record_kind_and_constant(constant, constant_list, type_definition):
    """Traversal action: records `(kind, constant value)` pairs.

    The kind is the name of the first of `enumeration`, `structure`, or
    `external` (checked in that order) for which
    `type_definition.HasField(kind)` is true.

    Args:
        constant: Node whose `value` is converted with `int()`.
        constant_list: List that receives the `(kind, value)` tuple.
        type_definition: Node queried with `HasField`.

    Raises:
        AssertionError: If none of the three fields is set. This is raised
            explicitly rather than via `assert False`, so the check still
            fires when Python runs with -O.
    """
    for kind in ("enumeration", "structure", "external"):
        if type_definition.HasField(kind):
            constant_list.append((kind, int(constant.value)))
            return
    raise AssertionError("Shouldn't be here.")
class TraverseIrTest(unittest.TestCase):
    """Tests traverse_ir.fast_traverse_ir_top_down against _EXAMPLE_IR."""

    def test_filter_on_type(self):
        """A one-element pattern visits every NumericConstant in the IR."""
        constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR,
            [ir_data.NumericConstant],
            _record_constant,
            parameters={"constant_list": constants},
        )
        self.assertEqual(
            _count_entries([0, 8, 8, 8, 16, 24, 32, 16, 32, 320, 1, 1, 1, 64]),
            _count_entries(constants),
        )

    def test_filter_on_type_in_type(self):
        """A Function -> Expression -> NumericConstant pattern matches only the
        two `1` arguments of the `1 + 1` value of enum entry TWO."""
        constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR,
            [ir_data.Function, ir_data.Expression, ir_data.NumericConstant],
            _record_constant,
            parameters={"constant_list": constants},
        )
        self.assertEqual([1, 1], constants)

    def test_filter_on_type_star_type(self):
        """Pattern elements need not be direct parent/child: constants at any
        depth below a Structure (or an Enum) are matched."""
        struct_constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR,
            [ir_data.Structure, ir_data.NumericConstant],
            _record_constant,
            parameters={"constant_list": struct_constants},
        )
        self.assertEqual(
            _count_entries([0, 8, 8, 8, 16, 24, 32, 16, 32, 320]),
            _count_entries(struct_constants),
        )
        enum_constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR,
            [ir_data.Enum, ir_data.NumericConstant],
            _record_constant,
            parameters={"constant_list": enum_constants},
        )
        self.assertEqual(_count_entries([1, 1, 1]), _count_entries(enum_constants))

    def test_filter_on_not_type(self):
        """skip_descendants_of prunes Structure subtrees, leaving only the enum
        constants and the external type's size_in_bits constant."""
        notstruct_constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR,
            [ir_data.NumericConstant],
            _record_constant,
            skip_descendants_of=(ir_data.Structure,),
            parameters={"constant_list": notstruct_constants},
        )
        self.assertEqual(
            _count_entries([1, 1, 1, 64]), _count_entries(notstruct_constants)
        )

    def test_field_is_populated(self):
        """Actions under a Field receive that Field as the `field` argument."""
        constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR,
            [ir_data.Field, ir_data.NumericConstant],
            _record_field_name_and_constant,
            parameters={"constant_list": constants},
        )
        self.assertEqual(
            _count_entries(
                [
                    ("field1", 0),
                    ("field1", 8),
                    ("field2", 8),
                    ("field2", 8),
                    ("field2", 16),
                    ("bar_field1", 24),
                    ("bar_field1", 32),
                    ("bar_field2", 16),
                    ("bar_field2", 32),
                    ("bar_field2", 320),
                ]
            ),
            _count_entries(constants),
        )

    def test_file_name_is_populated(self):
        """Actions receive the enclosing module's `source_file_name`."""
        constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR,
            [ir_data.NumericConstant],
            _record_file_name_and_constant,
            parameters={"constant_list": constants},
        )
        self.assertEqual(
            _count_entries(
                [
                    ("t.emb", 0),
                    ("t.emb", 8),
                    ("t.emb", 8),
                    ("t.emb", 8),
                    ("t.emb", 16),
                    ("t.emb", 24),
                    ("t.emb", 32),
                    ("t.emb", 16),
                    ("t.emb", 32),
                    ("t.emb", 320),
                    ("t.emb", 1),
                    ("t.emb", 1),
                    ("t.emb", 1),
                    ("", 64),
                ]
            ),
            _count_entries(constants),
        )

    def test_type_definition_is_populated(self):
        """Actions receive the enclosing type definition as `type_definition`;
        constants in the nested subtype report the struct kind."""
        constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR,
            [ir_data.NumericConstant],
            _record_kind_and_constant,
            parameters={"constant_list": constants},
        )
        self.assertEqual(
            _count_entries(
                [
                    ("structure", 0),
                    ("structure", 8),
                    ("structure", 8),
                    ("structure", 8),
                    ("structure", 16),
                    ("structure", 24),
                    ("structure", 32),
                    ("structure", 16),
                    ("structure", 32),
                    ("structure", 320),
                    ("enumeration", 1),
                    ("enumeration", 1),
                    ("enumeration", 1),
                    ("external", 64),
                ]
            ),
            _count_entries(constants),
        )

    def test_keyword_args_dict_in_action(self):
        """An action accepting **kwargs gets `field` only inside a Field."""
        call_counts = {"populated": 0, "not": 0}

        def check_field_is_populated(node, **kwargs):
            del node  # Unused.
            self.assertTrue(kwargs["field"])
            call_counts["populated"] += 1

        def check_field_is_not_populated(node, **kwargs):
            del node  # Unused.
            # assertNotIn reports the offending kwargs on failure, unlike
            # assertFalse("field" in kwargs), which only reports "True".
            self.assertNotIn("field", kwargs)
            call_counts["not"] += 1

        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.Field, ir_data.Type], check_field_is_populated
        )
        # Type nodes inside fields: field1 (1), field2 (2), bar_field1 (1),
        # bar_field2 (3).
        self.assertEqual(7, call_counts["populated"])
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.Enum, ir_data.EnumValue], check_field_is_not_populated
        )
        self.assertEqual(2, call_counts["not"])

    def test_pass_only_to_sub_nodes(self):
        """Parameters returned by an incidental action on a Field apply only to
        that Field's descendants; constants outside any Field still see the
        initial `location` of None."""
        constants = []

        def pass_location_down(field):
            return {
                "location": (
                    int(field.location.start.constant.value),
                    int(field.location.size.constant.value),
                )
            }

        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR,
            [ir_data.NumericConstant],
            _record_location_parameter_and_constant,
            incidental_actions={ir_data.Field: pass_location_down},
            parameters={"constant_list": constants, "location": None},
        )
        self.assertEqual(
            _count_entries(
                [
                    ((0, 8), 0),
                    ((0, 8), 8),
                    ((8, 16), 8),
                    ((8, 16), 8),
                    ((8, 16), 16),
                    ((24, 32), 24),
                    ((24, 32), 32),
                    ((32, 320), 16),
                    ((32, 320), 32),
                    ((32, 320), 320),
                    (None, 1),
                    (None, 1),
                    (None, 1),
                    (None, 64),
                ]
            ),
            _count_entries(constants),
        )
if __name__ == "__main__":
    # Run this module's unittest test cases when the file is executed directly.
    unittest.main()